提交 08f3bf01 authored 作者: Arnaud Bergeron's avatar Arnaud Bergeron

Only cast the input to log1p if its type differs from the original input type.

上级 66a32849
...@@ -6040,7 +6040,7 @@ def local_log1p(node): ...@@ -6040,7 +6040,7 @@ def local_log1p(node):
ninp = T.add(*nonconsts) ninp = T.add(*nonconsts)
else: else:
ninp = nonconsts[0] ninp = nonconsts[0]
if ninp.dtype != node.outputs[0].dtype: if ninp.dtype != log_arg.type.dtype:
ninp = ninp.astype(node.outputs[0].dtype) ninp = ninp.astype(node.outputs[0].dtype)
return _fill_chain(T.log1p(ninp), scalar_inputs) return _fill_chain(T.log1p(ninp), scalar_inputs)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论