提交 5bc35bc2 authored 作者: Arnaud Bergeron's avatar Arnaud Bergeron

Add a transfer() method to explicit data transfers in the graph easy.

上级 9cbdc34b
......@@ -427,7 +427,8 @@ TensorVariable
you'll want to call.
.. class:: _tensor_py_operators(object)
.. autoclass:: _tensor_py_operators
:members:
This mix-in class adds convenient attributes, methods, and support
to TensorVariable, TensorConstant and TensorSharedVariable for
......
......@@ -17,6 +17,8 @@ from theano.configparser import (
config, AddConfigVar, BoolParam, FloatParam, StrParam)
from . import nvcc_compiler
from theano.tensor.basic import register_transfer
# ignore_newtrees is to speed the optimization as this is the pattern
# we use for optimization. Otherwise, we can iterate 100s of time on
# the graph and apply only a few optimizations each time.
......@@ -327,6 +329,12 @@ if cuda_available:
from . import opt, dnn
from .rng_curand import CURAND_RandomStreams
def transfer(x, target):
if target == 'gpu':
return as_cuda_ndarray_variable(x)
register_transfer(transfer)
def use(device,
force=False,
......
......@@ -6,6 +6,8 @@ import theano
from theano.configparser import config, AddConfigVar, BoolParam
from theano.compile import optdb
from theano.tensor.basic import register_transfer
_logger_name = 'theano.sandbox.gpuarray'
_logger = logging.getLogger(_logger_name)
......@@ -23,8 +25,18 @@ except ImportError:
from .type import (GpuArrayType, GpuArrayVariable, GpuArrayConstant,
GpuArraySharedVariable, gpuarray_shared_constructor,
reg_context)
from .basic import as_gpuarray_variable
from . import opt, nerv
def transfer(x, target):
try:
get_context(target)
return as_gpuarray_variable(x, target)
except ContextNotDefined:
pass
register_transfer(transfer)
def init_dev(dev, name=None):
if pygpu.gpuarray.api_version() != (-10000, 0):
......
......@@ -2844,11 +2844,46 @@ class Alloc(gof.Op):
return False
return True
alloc = Alloc()
pprint.assign(alloc, printing.FunctionPrinter('alloc'))
def transfer(var, target):
"""
Return a version of `var` transferred to `target`.
`cpu` mean a TensorType (on the CPU). Other types may define
additional targets.
Parameters
----------
var : variable
A theano variable
target : str
The target of the transfer
"""
if target == 'cpu':
return as_tensor_variable(var)
else:
for trans in transfer._others:
res = trans(var, target)
if res is not None:
return res
raise ValueError("Can't transfer to target %s" % (target,))
transfer._others = []
def register_transfer(fn):
"""
Register a transfer function for alternative targets.
Parameters
----------
fn : callable
"""
transfer._others.append(fn)
"""Create a duplicate of `a` (with duplicated storage)"""
tensor_copy = elemwise.Elemwise(scal.identity)
pprint.assign(tensor_copy, printing.IgnorePrinter())
......
......@@ -29,7 +29,7 @@ class AsTensorError(TypeError):
pass
class _tensor_py_operators:
class _tensor_py_operators(object):
# UNARY
def __abs__(self):
return theano.tensor.basic.abs_(self)
......@@ -369,6 +369,19 @@ class _tensor_py_operators:
def diagonal(self, offset=0, axis1=0, axis2=1):
return theano.tensor.basic.diagonal(self, offset, axis1, axis2)
# Transfer the data to another device
def transfer(self, target):
"""
If `target` is `'cpu'` this will transfer to a TensorType (if
not already one). Other types may define additional targets.
Paramters
---------
target : str
The desired location of the output variable
"""
return theano.tensor.transfer(self, target)
# Elemwise
def arccos(self):
return theano.tensor.arccos(self)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论