提交 6f1afdb7 authored 作者: Frederic Bastien's avatar Frederic Bastien 提交者: Caglar

Fix a test by generalising the opt local_log1p.

上级 77bfb370
...@@ -134,7 +134,7 @@ def scalarconsts_rest(inputs): ...@@ -134,7 +134,7 @@ def scalarconsts_rest(inputs):
nonconsts = [] nonconsts = []
for i in inputs: for i in inputs:
try: try:
v = get_scalar_constant_value(i) v = get_scalar_constant_value(i, only_process_constants=True)
consts.append(v) consts.append(v)
origconsts.append(i) origconsts.append(i)
except NotScalarConstantError: except NotScalarConstantError:
...@@ -5786,7 +5786,7 @@ def local_abs_merge(node): ...@@ -5786,7 +5786,7 @@ def local_abs_merge(node):
@gof.local_optimizer([T.log]) @gof.local_optimizer([T.log])
def local_log1p(node): def local_log1p(node):
# log(1+x) -> log1p(x) # log(1+x) -> log1p(x)
# log(1-x) -> log1p(-x)
if node.op == T.log: if node.op == T.log:
log_arg, = node.inputs log_arg, = node.inputs
if log_arg.owner and log_arg.owner.op == T.add: if log_arg.owner and log_arg.owner.op == T.add:
...@@ -5802,6 +5802,13 @@ def local_log1p(node): ...@@ -5802,6 +5802,13 @@ def local_log1p(node):
return _fill_chain(T.log1p(T.add(*nonconsts)), return _fill_chain(T.log1p(T.add(*nonconsts)),
scalar_inputs) scalar_inputs)
elif log_arg.owner and log_arg.owner.op == T.sub:
one = T.extract_constant(log_arg.owner.inputs[0],
only_process_constants=True)
if one != 1:
return
return [T.log1p(T.neg(log_arg.owner.inputs[1]))]
# TODO: in canonicalize, change log10 and log2 -> log # TODO: in canonicalize, change log10 and log2 -> log
@register_stabilize @register_stabilize
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论