提交 df1234f0 authored 作者: Arnaud Bergeron's avatar Arnaud Bergeron

Fix optimization that would return the wrong type in certain cases.

上级 a7a18ca9
......@@ -6035,20 +6035,24 @@ def local_log1p(node):
log_arg.owner.inputs, only_process_constants=True)
# scalar_inputs are potentially dimshuffled and fill'd scalars
if scalars and numpy.allclose(numpy.sum(scalars), 1):
if not nonconsts:
pass # leave for constant-merge
if len(nonconsts) == 1:
return _fill_chain(T.log1p(nonconsts[0]), scalar_inputs)
else:
return _fill_chain(T.log1p(T.add(*nonconsts)),
scalar_inputs)
if nonconsts:
if len(nonconsts) > 1:
ninp = T.add(*nonconsts)
else:
ninp = nonconsts[0]
if ninp.dtype != node.outputs[0].dtype:
ninp = ninp.astype(node.outputs[0].dtype)
return _fill_chain(T.log1p(ninp), scalar_inputs)
elif log_arg.owner and log_arg.owner.op == T.sub:
one = T.extract_constant(log_arg.owner.inputs[0],
only_process_constants=True)
if one != 1:
return
return [T.log1p(T.neg(log_arg.owner.inputs[1]))]
other = log_arg.owner.inputs[1]
if other.dtype != log_arg.dtype:
other = other.astype(log_arg.dtype)
return [T.log1p(T.neg(other))]
# TODO: in canonicalize, change log10 and log2 -> log
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论