提交 5967e2cb authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Use register_* decorators in basic_opt and math_opt

上级 881f3494
...@@ -622,6 +622,8 @@ def is_dimshuffle_useless(new_order, input): ...@@ -622,6 +622,8 @@ def is_dimshuffle_useless(new_order, input):
return is_useless return is_useless
@register_canonicalize
@register_specialize
@local_optimizer([DimShuffle]) @local_optimizer([DimShuffle])
def local_dimshuffle_lift(fgraph, node): def local_dimshuffle_lift(fgraph, node):
""" """
...@@ -705,9 +707,6 @@ def local_useless_dimshuffle_in_reshape(fgraph, node): ...@@ -705,9 +707,6 @@ def local_useless_dimshuffle_in_reshape(fgraph, node):
return [ret] return [ret]
register_canonicalize(local_dimshuffle_lift)
register_specialize(local_dimshuffle_lift)
###################### ######################
# Casting operations # # Casting operations #
###################### ######################
...@@ -1633,6 +1632,7 @@ def local_elemwise_alloc(fgraph, node): ...@@ -1633,6 +1632,7 @@ def local_elemwise_alloc(fgraph, node):
return ret return ret
@register_canonicalize
@local_optimizer([Elemwise]) @local_optimizer([Elemwise])
def local_fill_sink(fgraph, node): def local_fill_sink(fgraph, node):
""" """
...@@ -1680,9 +1680,6 @@ def local_fill_sink(fgraph, node): ...@@ -1680,9 +1680,6 @@ def local_fill_sink(fgraph, node):
return replacements return replacements
register_canonicalize(local_fill_sink)
@register_specialize @register_specialize
@register_stabilize @register_stabilize
# @register_canonicalize # We make full pass after the canonizer phase. # @register_canonicalize # We make full pass after the canonizer phase.
......
...@@ -1071,15 +1071,13 @@ local_mul_canonizer = AlgebraicCanonizer( ...@@ -1071,15 +1071,13 @@ local_mul_canonizer = AlgebraicCanonizer(
register_canonicalize(local_mul_canonizer, name="local_mul_canonizer") register_canonicalize(local_mul_canonizer, name="local_mul_canonizer")
@register_canonicalize
@local_optimizer([neg]) @local_optimizer([neg])
def local_neg_to_mul(fgraph, node): def local_neg_to_mul(fgraph, node):
if node.op == neg: if node.op == neg:
return [mul(np.array(-1, dtype=node.inputs[0].dtype), node.inputs[0])] return [mul(np.array(-1, dtype=node.inputs[0].dtype), node.inputs[0])]
register_canonicalize(local_neg_to_mul)
@register_specialize @register_specialize
@local_optimizer([Sum, Prod]) @local_optimizer([Sum, Prod])
def local_sum_prod_mul_by_scalar(fgraph, node): def local_sum_prod_mul_by_scalar(fgraph, node):
...@@ -1779,6 +1777,7 @@ def local_neg_div_neg(fgraph, node): ...@@ -1779,6 +1777,7 @@ def local_neg_div_neg(fgraph, node):
return [true_div(new_num, denom)] return [true_div(new_num, denom)]
@register_canonicalize
@local_optimizer([mul]) @local_optimizer([mul])
def local_mul_zero(fgraph, node): def local_mul_zero(fgraph, node):
""" """
...@@ -1800,9 +1799,8 @@ def local_mul_zero(fgraph, node): ...@@ -1800,9 +1799,8 @@ def local_mul_zero(fgraph, node):
return fill_chain(_asarray(0, dtype=otype.dtype), node.inputs) return fill_chain(_asarray(0, dtype=otype.dtype), node.inputs)
register_canonicalize(local_mul_zero) # TODO: Add this to the canonicalization to reduce redundancy.
@register_specialize
@local_optimizer([true_div]) @local_optimizer([true_div])
def local_div_to_reciprocal(fgraph, node): def local_div_to_reciprocal(fgraph, node):
if node.op == true_div and np.all( if node.op == true_div and np.all(
...@@ -1821,10 +1819,7 @@ def local_div_to_reciprocal(fgraph, node): ...@@ -1821,10 +1819,7 @@ def local_div_to_reciprocal(fgraph, node):
return False return False
# TODO: Add this to the canonicalization to reduce redundancy. @register_canonicalize
register_specialize(local_div_to_reciprocal)
@local_optimizer([reciprocal]) @local_optimizer([reciprocal])
def local_reciprocal_canon(fgraph, node): def local_reciprocal_canon(fgraph, node):
if node.op == reciprocal: if node.op == reciprocal:
...@@ -1833,9 +1828,7 @@ def local_reciprocal_canon(fgraph, node): ...@@ -1833,9 +1828,7 @@ def local_reciprocal_canon(fgraph, node):
return False return False
register_canonicalize(local_reciprocal_canon) @register_canonicalize
@local_optimizer([aet_pow]) @local_optimizer([aet_pow])
def local_pow_canonicalize(fgraph, node): def local_pow_canonicalize(fgraph, node):
if node.op == aet_pow: if node.op == aet_pow:
...@@ -1848,9 +1841,6 @@ def local_pow_canonicalize(fgraph, node): ...@@ -1848,9 +1841,6 @@ def local_pow_canonicalize(fgraph, node):
return False return False
register_canonicalize(local_pow_canonicalize)
@register_specialize @register_specialize
@local_optimizer([mul]) @local_optimizer([mul])
def local_mul_to_sqr(fgraph, node): def local_mul_to_sqr(fgraph, node):
...@@ -1892,6 +1882,7 @@ def local_zero_div(fgraph, node): ...@@ -1892,6 +1882,7 @@ def local_zero_div(fgraph, node):
return [ret] return [ret]
@register_specialize
@local_optimizer([aet_pow]) @local_optimizer([aet_pow])
def local_pow_specialize(fgraph, node): def local_pow_specialize(fgraph, node):
# here, we are past the point of canonicalization, so we don't want # here, we are past the point of canonicalization, so we don't want
...@@ -1929,9 +1920,6 @@ def local_pow_specialize(fgraph, node): ...@@ -1929,9 +1920,6 @@ def local_pow_specialize(fgraph, node):
return False return False
register_specialize(local_pow_specialize)
@register_specialize_device @register_specialize_device
@local_optimizer([aet_pow]) @local_optimizer([aet_pow])
def local_pow_specialize_device(fgraph, node): def local_pow_specialize_device(fgraph, node):
...@@ -1999,6 +1987,7 @@ def local_pow_specialize_device(fgraph, node): ...@@ -1999,6 +1987,7 @@ def local_pow_specialize_device(fgraph, node):
return rval return rval
@register_specialize
@local_optimizer([mul]) @local_optimizer([mul])
def local_mul_specialize(fgraph, node): def local_mul_specialize(fgraph, node):
""" """
...@@ -2074,9 +2063,7 @@ def local_mul_specialize(fgraph, node): ...@@ -2074,9 +2063,7 @@ def local_mul_specialize(fgraph, node):
return [broadcast_like(1, node.outputs[0], fgraph)] return [broadcast_like(1, node.outputs[0], fgraph)]
register_specialize(local_mul_specialize) @register_specialize
@local_optimizer([add]) @local_optimizer([add])
def local_add_specialize(fgraph, node): def local_add_specialize(fgraph, node):
def _fill_chain(v): def _fill_chain(v):
...@@ -2119,8 +2106,6 @@ def local_add_specialize(fgraph, node): ...@@ -2119,8 +2106,6 @@ def local_add_specialize(fgraph, node):
return False return False
register_specialize(local_add_specialize)
mul_canonizer = in2out( mul_canonizer = in2out(
LocalOptGroup(local_mul_canonizer, local_fill_sink, apply_all_opts=True), LocalOptGroup(local_mul_canonizer, local_fill_sink, apply_all_opts=True),
name="mul_canonizer_groups", name="mul_canonizer_groups",
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论