提交 5967e2cb authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Use register_* decorators in basic_opt and math_opt

上级 881f3494
......@@ -622,6 +622,8 @@ def is_dimshuffle_useless(new_order, input):
return is_useless
@register_canonicalize
@register_specialize
@local_optimizer([DimShuffle])
def local_dimshuffle_lift(fgraph, node):
"""
......@@ -705,9 +707,6 @@ def local_useless_dimshuffle_in_reshape(fgraph, node):
return [ret]
register_canonicalize(local_dimshuffle_lift)
register_specialize(local_dimshuffle_lift)
######################
# Casting operations #
######################
......@@ -1633,6 +1632,7 @@ def local_elemwise_alloc(fgraph, node):
return ret
@register_canonicalize
@local_optimizer([Elemwise])
def local_fill_sink(fgraph, node):
"""
......@@ -1680,9 +1680,6 @@ def local_fill_sink(fgraph, node):
return replacements
register_canonicalize(local_fill_sink)
@register_specialize
@register_stabilize
# @register_canonicalize # We make full pass after the canonizer phase.
......
......@@ -1071,15 +1071,13 @@ local_mul_canonizer = AlgebraicCanonizer(
register_canonicalize(local_mul_canonizer, name="local_mul_canonizer")
@register_canonicalize
@local_optimizer([neg])
def local_neg_to_mul(fgraph, node):
if node.op == neg:
return [mul(np.array(-1, dtype=node.inputs[0].dtype), node.inputs[0])]
register_canonicalize(local_neg_to_mul)
@register_specialize
@local_optimizer([Sum, Prod])
def local_sum_prod_mul_by_scalar(fgraph, node):
......@@ -1779,6 +1777,7 @@ def local_neg_div_neg(fgraph, node):
return [true_div(new_num, denom)]
@register_canonicalize
@local_optimizer([mul])
def local_mul_zero(fgraph, node):
"""
......@@ -1800,9 +1799,8 @@ def local_mul_zero(fgraph, node):
return fill_chain(_asarray(0, dtype=otype.dtype), node.inputs)
register_canonicalize(local_mul_zero)
# TODO: Add this to the canonicalization to reduce redundancy.
@register_specialize
@local_optimizer([true_div])
def local_div_to_reciprocal(fgraph, node):
if node.op == true_div and np.all(
......@@ -1821,10 +1819,7 @@ def local_div_to_reciprocal(fgraph, node):
return False
# TODO: Add this to the canonicalization to reduce redundancy.
register_specialize(local_div_to_reciprocal)
@register_canonicalize
@local_optimizer([reciprocal])
def local_reciprocal_canon(fgraph, node):
if node.op == reciprocal:
......@@ -1833,9 +1828,7 @@ def local_reciprocal_canon(fgraph, node):
return False
register_canonicalize(local_reciprocal_canon)
@register_canonicalize
@local_optimizer([aet_pow])
def local_pow_canonicalize(fgraph, node):
if node.op == aet_pow:
......@@ -1848,9 +1841,6 @@ def local_pow_canonicalize(fgraph, node):
return False
register_canonicalize(local_pow_canonicalize)
@register_specialize
@local_optimizer([mul])
def local_mul_to_sqr(fgraph, node):
......@@ -1892,6 +1882,7 @@ def local_zero_div(fgraph, node):
return [ret]
@register_specialize
@local_optimizer([aet_pow])
def local_pow_specialize(fgraph, node):
# here, we are past the point of canonicalization, so we don't want
......@@ -1929,9 +1920,6 @@ def local_pow_specialize(fgraph, node):
return False
register_specialize(local_pow_specialize)
@register_specialize_device
@local_optimizer([aet_pow])
def local_pow_specialize_device(fgraph, node):
......@@ -1999,6 +1987,7 @@ def local_pow_specialize_device(fgraph, node):
return rval
@register_specialize
@local_optimizer([mul])
def local_mul_specialize(fgraph, node):
"""
......@@ -2074,9 +2063,7 @@ def local_mul_specialize(fgraph, node):
return [broadcast_like(1, node.outputs[0], fgraph)]
register_specialize(local_mul_specialize)
@register_specialize
@local_optimizer([add])
def local_add_specialize(fgraph, node):
def _fill_chain(v):
......@@ -2119,8 +2106,6 @@ def local_add_specialize(fgraph, node):
return False
register_specialize(local_add_specialize)
mul_canonizer = in2out(
LocalOptGroup(local_mul_canonizer, local_fill_sink, apply_all_opts=True),
name="mul_canonizer_groups",
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论