提交 9beb6c2d authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Remove broken warning in AlgebraicCanonizer and refactor it tests

上级 77bcd689
...@@ -1024,28 +1024,10 @@ class AlgebraicCanonizer(LocalOptimizer): ...@@ -1024,28 +1024,10 @@ class AlgebraicCanonizer(LocalOptimizer):
new = fill_chain(new, node.inputs)[0] new = fill_chain(new, node.inputs)[0]
if new.type == out.type: if new.type == out.type:
# This happen with test
# aesara/tensor/tests/test_opt.py:T_local_switch_sink
new.tag.values_eq_approx = values_eq_approx_remove_inf_nan new.tag.values_eq_approx = values_eq_approx_remove_inf_nan
copy_stack_trace(out, new)
# We need to implement the copy over of the stacktrace.
# See issue #5104.
return [new] return [new]
else: else:
_logger.warning(
" ".join(
(
"CANONIZE FAILED: new, out = ",
new,
",",
out,
"types",
new.type,
",",
out.type,
)
)
)
return False return False
def __str__(self): def __str__(self):
......
...@@ -16,9 +16,15 @@ from aesara.compile.function import function ...@@ -16,9 +16,15 @@ from aesara.compile.function import function
from aesara.compile.mode import Mode, get_default_mode, get_mode from aesara.compile.mode import Mode, get_default_mode, get_mode
from aesara.compile.ops import DeepCopyOp, deep_copy_op from aesara.compile.ops import DeepCopyOp, deep_copy_op
from aesara.configdefaults import config from aesara.configdefaults import config
from aesara.graph.basic import Constant from aesara.graph.basic import Apply, Constant, equal_computations
from aesara.graph.fg import FunctionGraph from aesara.graph.fg import FunctionGraph
from aesara.graph.opt import LocalOptGroup, TopoOptimizer, check_stack_trace, out2in from aesara.graph.opt import (
LocalOptGroup,
TopoOptimizer,
check_stack_trace,
in2out,
out2in,
)
from aesara.graph.opt_utils import is_same_graph, optimize_graph from aesara.graph.opt_utils import is_same_graph, optimize_graph
from aesara.graph.optdb import OptimizationQuery from aesara.graph.optdb import OptimizationQuery
from aesara.misc.safe_asarray import _asarray from aesara.misc.safe_asarray import _asarray
...@@ -79,6 +85,7 @@ from aesara.tensor.math_opt import ( ...@@ -79,6 +85,7 @@ from aesara.tensor.math_opt import (
is_1pexp, is_1pexp,
local_grad_log_erfc_neg, local_grad_log_erfc_neg,
local_greedy_distributor, local_greedy_distributor,
local_mul_canonizer,
mul_canonizer, mul_canonizer,
parse_mul_tree, parse_mul_tree,
perform_sigm_times_exp, perform_sigm_times_exp,
...@@ -220,23 +227,31 @@ class TestGreedyDistribute: ...@@ -220,23 +227,31 @@ class TestGreedyDistribute:
assert np.all(r0 == r2) assert np.all(r0 == r2)
class TestAlgebraicCanonize: class TestAlgebraicCanonizer:
def test_muldiv(self):
x, y, z = matrices("xyz") x, y, z = matrices("xyz")
a, b, c, d = matrices("abcd")
# e = (2.0 * x) / (2.0 * y) @pytest.mark.parametrize(
# e = (2.0 * x) / (4.0 * y) "e, exp_g",
# e = x / (y / z) [
# e = (x * y) / x # ((2.0 * x) / (2.0 * y), None),
# e = (x / y) * (y / z) * (z / x) # ((2.0 * x) / (4.0 * y), None),
# e = (a / b) * (b / c) * (c / d) # (x / (y / z), None),
# e = (a * b) / (b * c) / (c * d) # ((x * y) / x, None),
# e = 2 * x / 2 # ((x / y) * (y / z) * (z / x), None),
# e = x / y / x # ((a / b) * (b / c) * (c / d), None),
# e = (x / x) * (y / y) # ((a * b) / (b * c) / (c * d), None),
e = (-1 * x) / y / (-2 * z) # (2 * x / 2, None),
g = FunctionGraph([x, y, z, a, b, c, d], [e]) # (x / y / x, None),
mul_canonizer.optimize(g) # ((x / x) * (y / y), None),
(
(-1 * x) / y / (-2 * z),
(at.as_tensor([[0.5]], dtype="floatX") * x) / (y * z),
),
],
)
def test_muldiv(self, e, exp_g):
g_opt = optimize_graph(e, custom_opt=mul_canonizer)
assert equal_computations([g_opt], [exp_g])
def test_elemwise_multiple_inputs_optimisation(self): def test_elemwise_multiple_inputs_optimisation(self):
# verify that the AlgebraicCanonizer merge sequential Elemwise({mul,add}) part 1 # verify that the AlgebraicCanonizer merge sequential Elemwise({mul,add}) part 1
...@@ -245,7 +260,6 @@ class TestAlgebraicCanonize: ...@@ -245,7 +260,6 @@ class TestAlgebraicCanonize:
# that are not implemented but are supposed to be. # that are not implemented but are supposed to be.
# #
# Test with and without DimShuffle # Test with and without DimShuffle
shp = (5, 5) shp = (5, 5)
fx, fy, fz = fmatrices("xyz") fx, fy, fz = fmatrices("xyz")
dx, dy, dz = dmatrices("xyz") dx, dy, dz = dmatrices("xyz")
...@@ -369,8 +383,7 @@ class TestAlgebraicCanonize: ...@@ -369,8 +383,7 @@ class TestAlgebraicCanonize:
assert out_dtype == out.dtype assert out_dtype == out.dtype
@pytest.mark.skip( @pytest.mark.skip(
reason="Current implementation of AlgebraicCanonizer does not " reason="Current implementation of AlgebraicCanonizer does not implement all cases."
"implement all cases. Skip the corresponding test."
) )
def test_elemwise_multiple_inputs_optimisation2(self): def test_elemwise_multiple_inputs_optimisation2(self):
# verify that the AlgebraicCanonizer merge sequential Elemwise({mul,add}) part 2. # verify that the AlgebraicCanonizer merge sequential Elemwise({mul,add}) part 2.
...@@ -951,6 +964,19 @@ class TestAlgebraicCanonize: ...@@ -951,6 +964,19 @@ class TestAlgebraicCanonize:
# at all. # at all.
assert not sio.getvalue() assert not sio.getvalue()
def test_mismatching_types(self):
a = at.as_tensor([[0.0]], dtype=np.float64)
b = tensor("float64", (None,)).dimshuffle("x", 0)
z = add(a, b)
# Construct a node with the wrong output `Type`
z = Apply(
z.owner.op, z.owner.inputs, [tensor("float64", (None, None))]
).outputs[0]
z_opt = optimize_graph(z, custom_opt=in2out(local_mul_canonizer, name="blah"))
# No rewrite was applied
assert z_opt is z
def test_local_merge_abs(): def test_local_merge_abs():
x, y, z = matrices("xyz") x, y, z = matrices("xyz")
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论