提交 087c502c authored 作者: Razvan Pascanu's avatar Razvan Pascanu

use astype which works better

上级 4838cea1
...@@ -1348,8 +1348,7 @@ class Scan(PureOp): ...@@ -1348,8 +1348,7 @@ class Scan(PureOp):
dC_dXtm1s.append(dC_dXts[opos].type()) dC_dXtm1s.append(dC_dXts[opos].type())
if x.dtype != dC_dXts[opos].dtype: if x.dtype != dC_dXts[opos].dtype:
dC_dinps_t[pos + self.n_seqs] = \ dC_dinps_t[pos + self.n_seqs] = \
tensor.cast(x, x.astype(dC_dXts[opos].dtype)
dtype=dC_dXts[opos].dtype)
else: else:
dC_dXtm1s.append(x.type()) dC_dXtm1s.append(x.type())
for dx, dC_dXtm1 in enumerate(dC_dXtm1s): for dx, dC_dXtm1 in enumerate(dC_dXtm1s):
......
...@@ -47,7 +47,7 @@ def safe_new(x, tag='', dtype=None): ...@@ -47,7 +47,7 @@ def safe_new(x, tag='', dtype=None):
nw_name = None nw_name = None
if isinstance(x, theano.Constant): if isinstance(x, theano.Constant):
if dtype and x.dtype != dtype: if dtype and x.dtype != dtype:
return tensor.cast(x.clone(), dtype=dtype) return x.clone().astype(dtype)
else: else:
return x.clone() return x.clone()
# Note, as_tensor_variable will convert the Scalar into a # Note, as_tensor_variable will convert the Scalar into a
...@@ -70,7 +70,7 @@ def safe_new(x, tag='', dtype=None): ...@@ -70,7 +70,7 @@ def safe_new(x, tag='', dtype=None):
pass pass
nw_x = x.type() nw_x = x.type()
if dtype and nw_x.dtype != dtype: if dtype and nw_x.dtype != dtype:
nw_x = tensor.cast(nw_x, dtype=dtype) nw_x = nw_x.astype(dtype)
nw_x.name = nw_name nw_x.name = nw_name
# Preserve test values so that the 'compute_test_value' option can be used. # Preserve test values so that the 'compute_test_value' option can be used.
# The test value is deep-copied to ensure there can be no interactions # The test value is deep-copied to ensure there can be no interactions
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论