提交 abf7bc89 authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Add a case to local_neg_div_neg

上级 16ff6ebb
......@@ -932,8 +932,11 @@ def local_neg_neg(node):
@register_specialize
@gof.local_optimizer([T.neg])
def local_neg_div_neg(node):
"""- (-a / b) -> a / b
Also performs - (c / b) -> ((-c) / b) when c is a scalar constant.
"""
if node.op == T.neg:
"""- (-a / b) -> a / b"""
if node.inputs[0].owner and node.inputs[0].owner.op == T.true_div:
frac = node.inputs[0]
num, denom = frac.owner.inputs
......@@ -942,6 +945,11 @@ def local_neg_div_neg(node):
# No other clients of the original division
new_num = num.owner.inputs[0]
return [T.true_div(new_num, denom)]
elif numpy.all(num.broadcastable) and isinstance(num, gof.Constant):
if len(frac.clients) == 1:
new_num = -num.data
return [T.true_div(new_num, denom)]
@gof.local_optimizer([T.mul])
def local_mul_zero(node):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论