提交 bd16aee9 authored 作者: Razvan Pascanu's avatar Razvan Pascanu

remove safe_to_cpu, and re-wrote safe_new

One thing that I've tried is to avoid importing cuda.
上级 6a6ac9b2
......@@ -30,24 +30,31 @@ import theano
# Logging function for sending warning or info
_logger = logging.getLogger('theano.scan_utils')
def safe_new(x):
if isinstance(x, numpy.ndarray):
x = tensor.as_tensor_variable(x)
if cuda.cuda_available and isinstance(x.type, cuda.CudaNdarrayType):
return tensor.TensorType(
broadcastable = x.type.broadcastable
, dtype = config.floatX)()
def safe_new(x, tag = ''):
"""
Internal function that constructs a new variable from x with the same
type, but with a different name ( old name + tag). This function is used
by gradient, or the R-op to construct new variables for the inputs of
the inner graph such that there is no interference between the original
graph and the newly constructed graph.
"""
if hasattr(x, 'name') and x.name is not None:
nw_name = x.name + tag
else:
return x.type()
def safe_to_cpu(x):
if isinstance(x, numpy.ndarray):
x = tensor.as_tensor_variable(x)
if cuda.cuda_available and isinstance(x.type, cuda.CudaNdarrayType):
return cuda.basic_ops.host_from_gpu(x)
nw_name = None
if isinstance(x.type, tensor.Constant):
return x.clone()
else:
return x
try:
x = tensor.as_tensor_variable(x)
except TypeError:
# This could happend for example for random states, and I really
# want to avoid the convoluted logic that checks for cuda
# ndarrays
pass
nw_x = x.type()
nw_x.name = nw_name
return nw_x
def traverse(out, x,x_copy, d):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论