提交 df1234f0 authored 作者: Arnaud Bergeron's avatar Arnaud Bergeron

Fix optimization that would return the wrong type in certain cases.

上级 a7a18ca9
...@@ -6035,20 +6035,24 @@ def local_log1p(node): ...@@ -6035,20 +6035,24 @@ def local_log1p(node):
log_arg.owner.inputs, only_process_constants=True) log_arg.owner.inputs, only_process_constants=True)
# scalar_inputs are potentially dimshuffled and fill'd scalars # scalar_inputs are potentially dimshuffled and fill'd scalars
if scalars and numpy.allclose(numpy.sum(scalars), 1): if scalars and numpy.allclose(numpy.sum(scalars), 1):
if not nonconsts: if nonconsts:
pass # leave for constant-merge if len(nonconsts) > 1:
if len(nonconsts) == 1: ninp = T.add(*nonconsts)
return _fill_chain(T.log1p(nonconsts[0]), scalar_inputs)
else: else:
return _fill_chain(T.log1p(T.add(*nonconsts)), ninp = nonconsts[0]
scalar_inputs) if ninp.dtype != node.outputs[0].dtype:
ninp = ninp.astype(node.outputs[0].dtype)
return _fill_chain(T.log1p(ninp), scalar_inputs)
elif log_arg.owner and log_arg.owner.op == T.sub: elif log_arg.owner and log_arg.owner.op == T.sub:
one = T.extract_constant(log_arg.owner.inputs[0], one = T.extract_constant(log_arg.owner.inputs[0],
only_process_constants=True) only_process_constants=True)
if one != 1: if one != 1:
return return
return [T.log1p(T.neg(log_arg.owner.inputs[1]))] other = log_arg.owner.inputs[1]
if other.dtype != log_arg.dtype:
other = other.astype(log_arg.dtype)
return [T.log1p(T.neg(other))]
# TODO: in canonicalize, change log10 and log2 -> log # TODO: in canonicalize, change log10 and log2 -> log
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论