提交 939076b5 authored 作者: Frederic Bastien's avatar Frederic Bastien

Be sure to use the good dtype and not a possibly modified one from get_scalar_constant_value.

上级 e3b3ecc1
...@@ -5210,7 +5210,7 @@ def local_opt_alloc(node): ...@@ -5210,7 +5210,7 @@ def local_opt_alloc(node):
val = get_scalar_constant_value(input) val = get_scalar_constant_value(input)
assert val.size == 1 assert val.size == 1
# check which type of op # check which type of op
casted = T.mul(*shapes).astype(str(val.dtype)) casted = T.mul(*shapes).astype(str(input.dtype))
if isinstance(node.op, T.Sum): if isinstance(node.op, T.Sum):
val = val.reshape(1)[0] * casted val = val.reshape(1)[0] * casted
else: else:
...@@ -5227,7 +5227,7 @@ def local_opt_alloc(node): ...@@ -5227,7 +5227,7 @@ def local_opt_alloc(node):
to_prod = [shapes[i] for i in xrange(len(shapes)) to_prod = [shapes[i] for i in xrange(len(shapes))
if i in node.op.axis] if i in node.op.axis]
if to_prod: if to_prod:
casted = T.mul(*to_prod).astype(str(val.dtype)) casted = T.mul(*to_prod).astype(str(input.dtype))
if isinstance(node.op, T.Sum): if isinstance(node.op, T.Sum):
val *= casted val *= casted
else: else:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论