提交 f777fd27 authored 作者: p's avatar p

fixes the 'float64' in RNNs, wastes memory and breaks when warn_float64=raise

上级 5e5e5cc5
......@@ -5227,7 +5227,7 @@ def local_opt_alloc(node):
if i in node.op.axis]
if to_prod:
if isinstance(node.op, T.Sum):
val *= T.mul(*to_prod)
val *= T.mul(*to_prod).astype(str(val.dtype))
else:
val = val ** T.mul(*to_prod)
return [T.alloc(T.cast(val, dtype=node.outputs[0].dtype),
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论