提交 7173f901 authored 作者: Frederic Bastien's avatar Frederic Bastien

Fix test and remove upcast to float64.

上级 56e5b9c0
...@@ -318,10 +318,10 @@ class T_OpFromGraph(unittest_tools.InferShapeTester): ...@@ -318,10 +318,10 @@ class T_OpFromGraph(unittest_tools.InferShapeTester):
@theano.configparser.change_flags(compute_test_value='raise') @theano.configparser.change_flags(compute_test_value='raise')
def test_compute_test_value(self): def test_compute_test_value(self):
x = T.scalar('x') x = T.scalar('x')
x.tag.test_value = np.array(1.) x.tag.test_value = np.array(1., dtype=config.floatX)
op = OpFromGraph([x], [x ** 3]) op = OpFromGraph([x], [x ** 3])
y = T.scalar('y') y = T.scalar('y')
y.tag.test_value = np.array(1.) y.tag.test_value = np.array(1., dtype=config.floatX)
f = op(y) f = op(y)
grad_f = T.grad(f, y) grad_f = T.grad(f, y)
assert grad_f.tag.test_value is not None assert grad_f.tag.test_value is not None
...@@ -5718,6 +5718,11 @@ def local_opt_alloc(node): ...@@ -5718,6 +5718,11 @@ def local_opt_alloc(node):
assert val.size == 1 assert val.size == 1
# check which type of op # check which type of op
size = T.mul(*shapes) size = T.mul(*shapes)
if input.dtype == "float32":
# shapes are ints and normally int64.
# We don't want to have a float64 upcast here
# if input is a float32.
size = size.astype(input.dtype)
if isinstance(node.op, T.Sum): if isinstance(node.op, T.Sum):
val = val.reshape(1)[0] * size val = val.reshape(1)[0] * size
else: else:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论