提交 3e00b5cd authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Refactor local_add_specialize and test_local_add_specialize

These changes "flatten" the nested conditions that lead to the replacement logic. They also clarify the tests that are labeled as being for `local_add_specialize` but do not actually exercise that rewrite. Also, the unrelated tests implicitly relied on the canonicalizations built into `get_scalar_constant_value`; now, the actual canonicalization rewrites must be enabled explicitly, in anticipation of `get_scalar_constant_value`'s replacement.
上级 35f52cff
...@@ -2062,13 +2062,15 @@ def local_mul_specialize(fgraph, node): ...@@ -2062,13 +2062,15 @@ def local_mul_specialize(fgraph, node):
@register_specialize @register_specialize
@local_optimizer([add]) @local_optimizer([add])
def local_add_specialize(fgraph, node): def local_add_specialize(fgraph, node):
def _fill_chain(v): """Remove zeros from ``add``s.
out = fill_chain(v, node.inputs)
return out
TODO: This should be a canonicalization, no?
"""
# here, we are past the point of canonicalization, so we don't want # here, we are past the point of canonicalization, so we don't want
# to put in un-necessary fills. # to put in un-necessary fills.
if node.op == add: if node.op != add:
return False
new_inputs = [] new_inputs = []
for inp in node.inputs: for inp in node.inputs:
try: try:
...@@ -2079,27 +2081,31 @@ def local_add_specialize(fgraph, node): ...@@ -2079,27 +2081,31 @@ def local_add_specialize(fgraph, node):
continue continue
new_inputs.append(inp) new_inputs.append(inp)
if len(new_inputs) < len(node.inputs): if len(new_inputs) == len(node.inputs):
dtype = node.outputs[0].type.dtype return False
node_output = node.outputs[0]
dtype = node_output.type.dtype
if len(new_inputs) == 0: if len(new_inputs) == 0:
# we got rid of the entire expression! # we got rid of the entire expression!
ndim = node.outputs[0].type.ndim ndim = node_output.type.ndim
# Reuse call to constant for cache() # Reuse call to constant for cache()
cst = constant(np.zeros((1,) * ndim, dtype=dtype)) cst = constant(np.zeros((1,) * ndim, dtype=dtype))
assert cst.type.broadcastable == (True,) * ndim assert cst.type.broadcastable == (True,) * ndim
return _fill_chain(cst) return fill_chain(cst, node.inputs)
if len(new_inputs) == 1: if len(new_inputs) == 1:
ret = _fill_chain(new_inputs[0]) ret = fill_chain(new_inputs[0], node.inputs)
else: else:
ret = _fill_chain(add(*new_inputs)) ret = fill_chain(add(*new_inputs), node.inputs)
# The dtype should not be changed. It can happen if the input # The dtype should not be changed. It can happen if the input
# that was forcing upcasting was equal to 0. # that was forcing upcasting was equal to 0.
if ret[0].dtype != dtype: if ret[0].dtype != dtype:
ret = [cast(ret[0], dtype)] ret = [cast(ret[0], dtype)]
return ret return ret
else:
return False
mul_canonizer = in2out( mul_canonizer = in2out(
......
...@@ -77,7 +77,6 @@ from aesara.tensor.math import tan, tanh, true_div, xor ...@@ -77,7 +77,6 @@ from aesara.tensor.math import tan, tanh, true_div, xor
from aesara.tensor.math_opt import ( from aesara.tensor.math_opt import (
compute_mul, compute_mul,
is_1pexp, is_1pexp,
local_add_specialize,
local_grad_log_erfc_neg, local_grad_log_erfc_neg,
local_greedy_distributor, local_greedy_distributor,
mul_canonizer, mul_canonizer,
...@@ -3781,24 +3780,41 @@ class TestLocalSumProdDimshuffle: ...@@ -3781,24 +3780,41 @@ class TestLocalSumProdDimshuffle:
# test_local_sum_divprod_dimshuffle ((a * b) / (c * d)) # test_local_sum_divprod_dimshuffle ((a * b) / (c * d))
def test_local_add_specialize(): def test_local_useless_adds():
default_mode = get_default_mode()
# Test for all zeros
a = scalar()
s = add(aet.zeros_like(a))
mode_with_opt = default_mode.including("canonicalization", "local_useless_fill")
f = function([a], s, mode=mode_with_opt)
assert not any(node.op == add for node in f.maker.fgraph.apply_nodes)
# test of non-zero dimension # test of non-zero dimension
a = vector() a = vector()
s = add(aet.zeros_like(a)) s = add(aet.zeros_like(a))
assert local_add_specialize.transform(None, s.owner) mode_with_opt = default_mode.including("canonicalization", "local_useless_elemwise")
f = function([a], s, mode=mode_with_opt)
assert not any(node.op == add for node in f.maker.fgraph.apply_nodes)
# test of 0-d # test of 0-d
a = scalar() a = scalar()
s = add(aet.zeros_like(a)) s = add(aet.zeros_like(a))
assert local_add_specialize.transform(None, s.owner) mode_with_opt = default_mode.including(
"canonicalization", "local_useless_fill", "local_useless_elemwise"
)
f = function([a], s, mode=mode_with_opt)
assert not any(node.op == add for node in f.maker.fgraph.apply_nodes)
# Test when the 0 input is forcing upcasting # Test when the 0 input is forcing upcasting
a = aet.constant(0, dtype="int64") a = aet.constant(0, dtype="int64")
b = aet.constant(1, dtype="int32") b = aet.constant(1, dtype="int32")
s = a + b s = a + b
transformed = local_add_specialize.transform(None, s.owner) mode_with_opt = default_mode.including("canonicalization", "local_add_canonizer")
assert transformed f = function([], s, mode=mode_with_opt)
assert transformed[0].type == s.type transformed = f.maker.fgraph.outputs[0]
assert not any(node.op == add for node in f.maker.fgraph.apply_nodes)
assert transformed.type == s.type
def test_local_div_to_reciprocal(): def test_local_div_to_reciprocal():
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论