提交 c15933f2 authored 作者: Razvan Pascanu's avatar Razvan Pascanu 提交者: Razvan Pascanu

allow providing dtype to safe_new

上级 ee947ae8
...@@ -33,7 +33,7 @@ from theano.tensor.basic import get_constant_value ...@@ -33,7 +33,7 @@ from theano.tensor.basic import get_constant_value
_logger = logging.getLogger('theano.scan_utils') _logger = logging.getLogger('theano.scan_utils')
def safe_new(x, tag=''): def safe_new(x, tag='', dtype=None):
""" """
Internal function that constructs a new variable from x with the same Internal function that constructs a new variable from x with the same
type, but with a different name (old name + tag). This function is used type, but with a different name (old name + tag). This function is used
...@@ -46,12 +46,18 @@ def safe_new(x, tag=''): ...@@ -46,12 +46,18 @@ def safe_new(x, tag=''):
else: else:
nw_name = None nw_name = None
if isinstance(x, theano.Constant): if isinstance(x, theano.Constant):
return x.clone() if dtype and x.dtype != dtype:
return tensor.cast(x.clone(), dtype=dtype)
else:
return x.clone()
# Note, as_tensor_variable will convert the Scalar into a # Note, as_tensor_variable will convert the Scalar into a
# TensorScalar that will require a ScalarFromTensor op, # TensorScalar that will require a ScalarFromTensor op,
# making the pushout optimization fail # making the pushout optimization fail
elif isinstance(x, scalar.ScalarVariable): elif isinstance(x, scalar.ScalarVariable):
nw_x = x.type() if dtype:
new_x = scalar.Scalar(dtype=dtype)()
else:
nw_x = x.type()
nw_x.name = nw_name nw_x.name = nw_name
return nw_x return nw_x
else: else:
...@@ -63,6 +69,8 @@ def safe_new(x, tag=''): ...@@ -63,6 +69,8 @@ def safe_new(x, tag=''):
# ndarrays # ndarrays
pass pass
nw_x = x.type() nw_x = x.type()
if dtype and nw_x.dtype != dtype:
nw_x = tensor.cast(nw_x, dtype=dtype)
nw_x.name = nw_name nw_x.name = nw_name
# Preserve test values so that the 'compute_test_value' option can be used. # Preserve test values so that the 'compute_test_value' option can be used.
# The test value is deep-copied to ensure there can be no interactions # The test value is deep-copied to ensure there can be no interactions
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论