提交 9beb6c2d authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Remove broken warning in AlgebraicCanonizer and refactor it tests

上级 77bcd689
......@@ -1024,28 +1024,10 @@ class AlgebraicCanonizer(LocalOptimizer):
new = fill_chain(new, node.inputs)[0]
if new.type == out.type:
# This happen with test
# aesara/tensor/tests/test_opt.py:T_local_switch_sink
new.tag.values_eq_approx = values_eq_approx_remove_inf_nan
# We need to implement the copy over of the stacktrace.
# See issue #5104.
copy_stack_trace(out, new)
return [new]
else:
_logger.warning(
" ".join(
(
"CANONIZE FAILED: new, out = ",
new,
",",
out,
"types",
new.type,
",",
out.type,
)
)
)
return False
def __str__(self):
......
......@@ -16,9 +16,15 @@ from aesara.compile.function import function
from aesara.compile.mode import Mode, get_default_mode, get_mode
from aesara.compile.ops import DeepCopyOp, deep_copy_op
from aesara.configdefaults import config
from aesara.graph.basic import Constant
from aesara.graph.basic import Apply, Constant, equal_computations
from aesara.graph.fg import FunctionGraph
from aesara.graph.opt import LocalOptGroup, TopoOptimizer, check_stack_trace, out2in
from aesara.graph.opt import (
LocalOptGroup,
TopoOptimizer,
check_stack_trace,
in2out,
out2in,
)
from aesara.graph.opt_utils import is_same_graph, optimize_graph
from aesara.graph.optdb import OptimizationQuery
from aesara.misc.safe_asarray import _asarray
......@@ -79,6 +85,7 @@ from aesara.tensor.math_opt import (
is_1pexp,
local_grad_log_erfc_neg,
local_greedy_distributor,
local_mul_canonizer,
mul_canonizer,
parse_mul_tree,
perform_sigm_times_exp,
......@@ -220,23 +227,31 @@ class TestGreedyDistribute:
assert np.all(r0 == r2)
class TestAlgebraicCanonize:
def test_muldiv(self):
x, y, z = matrices("xyz")
a, b, c, d = matrices("abcd")
# e = (2.0 * x) / (2.0 * y)
# e = (2.0 * x) / (4.0 * y)
# e = x / (y / z)
# e = (x * y) / x
# e = (x / y) * (y / z) * (z / x)
# e = (a / b) * (b / c) * (c / d)
# e = (a * b) / (b * c) / (c * d)
# e = 2 * x / 2
# e = x / y / x
# e = (x / x) * (y / y)
e = (-1 * x) / y / (-2 * z)
g = FunctionGraph([x, y, z, a, b, c, d], [e])
mul_canonizer.optimize(g)
class TestAlgebraicCanonizer:
x, y, z = matrices("xyz")
@pytest.mark.parametrize(
"e, exp_g",
[
# ((2.0 * x) / (2.0 * y), None),
# ((2.0 * x) / (4.0 * y), None),
# (x / (y / z), None),
# ((x * y) / x, None),
# ((x / y) * (y / z) * (z / x), None),
# ((a / b) * (b / c) * (c / d), None),
# ((a * b) / (b * c) / (c * d), None),
# (2 * x / 2, None),
# (x / y / x, None),
# ((x / x) * (y / y), None),
(
(-1 * x) / y / (-2 * z),
(at.as_tensor([[0.5]], dtype="floatX") * x) / (y * z),
),
],
)
def test_muldiv(self, e, exp_g):
g_opt = optimize_graph(e, custom_opt=mul_canonizer)
assert equal_computations([g_opt], [exp_g])
def test_elemwise_multiple_inputs_optimisation(self):
# verify that the AlgebraicCanonizer merge sequential Elemwise({mul,add}) part 1
......@@ -245,7 +260,6 @@ class TestAlgebraicCanonize:
# that are not implemented but are supposed to be.
#
# Test with and without DimShuffle
shp = (5, 5)
fx, fy, fz = fmatrices("xyz")
dx, dy, dz = dmatrices("xyz")
......@@ -369,8 +383,7 @@ class TestAlgebraicCanonize:
assert out_dtype == out.dtype
@pytest.mark.skip(
reason="Current implementation of AlgebraicCanonizer does not "
"implement all cases. Skip the corresponding test."
reason="Current implementation of AlgebraicCanonizer does not implement all cases."
)
def test_elemwise_multiple_inputs_optimisation2(self):
# verify that the AlgebraicCanonizer merge sequential Elemwise({mul,add}) part 2.
......@@ -951,6 +964,19 @@ class TestAlgebraicCanonize:
# at all.
assert not sio.getvalue()
def test_mismatching_types(self):
a = at.as_tensor([[0.0]], dtype=np.float64)
b = tensor("float64", (None,)).dimshuffle("x", 0)
z = add(a, b)
# Construct a node with the wrong output `Type`
z = Apply(
z.owner.op, z.owner.inputs, [tensor("float64", (None, None))]
).outputs[0]
z_opt = optimize_graph(z, custom_opt=in2out(local_mul_canonizer, name="blah"))
# No rewrite was applied
assert z_opt is z
def test_local_merge_abs():
x, y, z = matrices("xyz")
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论