提交 ae58bdf8 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Refactor TestMergeOptimizer

These changes remove the tests' reliance on string output and make the expected results much more explicit.
上级 3e7f00e5
...@@ -2,7 +2,8 @@ import pytest ...@@ -2,7 +2,8 @@ import pytest
from aesara.assert_op import assert_op from aesara.assert_op import assert_op
from aesara.configdefaults import config from aesara.configdefaults import config
from aesara.graph.basic import Apply, Constant from aesara.graph.basic import Apply, Constant, equal_computations
from aesara.graph.features import Feature
from aesara.graph.fg import FunctionGraph from aesara.graph.fg import FunctionGraph
from aesara.graph.op import Op from aesara.graph.op import Op
from aesara.graph.opt import ( from aesara.graph.opt import (
...@@ -13,14 +14,13 @@ from aesara.graph.opt import ( ...@@ -13,14 +14,13 @@ from aesara.graph.opt import (
OpSub, OpSub,
PatternSub, PatternSub,
TopoOptimizer, TopoOptimizer,
aesara,
local_optimizer, local_optimizer,
logging, logging,
pre_constant_merge, pre_constant_merge,
pre_greedy_local_optimizer, pre_greedy_local_optimizer,
) )
from aesara.tensor.basic_opt import constant_folding from aesara.tensor.basic_opt import constant_folding
from aesara.tensor.math import dot from aesara.tensor.math import Dot, add, dot
from aesara.tensor.subtensor import AdvancedSubtensor from aesara.tensor.subtensor import AdvancedSubtensor
from aesara.tensor.type import matrix, values_eq_approx_always_true from aesara.tensor.type import matrix, values_eq_approx_always_true
from aesara.tensor.type_other import MakeSlice, SliceConstant, slicetype from aesara.tensor.type_other import MakeSlice, SliceConstant, slicetype
...@@ -39,6 +39,13 @@ from tests.graph.utils import ( ...@@ -39,6 +39,13 @@ from tests.graph.utils import (
) )
class AssertNoChanges(Feature):
"""A `Feature` that raises an error when nodes are changed in a graph."""
def on_change_input(self, fgraph, node, i, r, new_r, reason=None):
raise AssertionError()
def inputs(): def inputs():
x = MyVariable("x") x = MyVariable("x")
y = MyVariable("y") y = MyVariable("y")
...@@ -263,72 +270,60 @@ class TestMergeOptimizer: ...@@ -263,72 +270,60 @@ class TestMergeOptimizer:
def test_straightforward(self): def test_straightforward(self):
x, y, z = inputs() x, y, z = inputs()
e = op1(op2(x, y), op2(x, y), op2(x, z)) e = op1(op2(x, y), op2(x, y), op2(x, z))
g = FunctionGraph([x, y, z], [e]) g = FunctionGraph([x, y, z], [e], clone=False)
MergeOptimizer().optimize(g) MergeOptimizer().optimize(g)
assert str(g) == "FunctionGraph(Op1(*1 -> Op2(x, y), *1, Op2(x, z)))" out_var = g.outputs[0]
var_1, var_2, var_3 = out_var.owner.inputs
assert var_1 is var_2
assert var_1 is not var_3
def test_constant_merging(self): def test_constant_merging(self):
x = MyVariable("x") x = MyVariable("x")
y = Constant(MyType(), 2, name="y") y = Constant(MyType(), 2, name="y")
z = Constant(MyType(), 2, name="z") z = Constant(MyType(), 2, name="z")
e = op1(op2(x, y), op2(x, y), op2(x, z)) e = op1(op2(x, y), op2(x, y), op2(x, z))
g = FunctionGraph([x, y, z], [e]) g = FunctionGraph([x, y, z], [e], clone=False)
MergeOptimizer().optimize(g) MergeOptimizer().optimize(g)
strg = str(g) out_var = g.outputs[0]
assert ( var_1, var_2, var_3 = out_var.owner.inputs
strg == "FunctionGraph(Op1(*1 -> Op2(x, y), *1, *1))" assert var_1 is var_2
or strg == "FunctionGraph(Op1(*1 -> Op2(x, z), *1, *1))" assert var_2 is var_3
)
def test_deep_merge(self): def test_deep_merge(self):
x, y, z = inputs() x, y, z = inputs()
e = op1(op3(op2(x, y), z), op4(op3(op2(x, y), z))) e = op1(op3(op2(x, y), z), op4(op3(op2(x, y), z)))
g = FunctionGraph([x, y, z], [e]) g = FunctionGraph([x, y, z], [e], clone=False)
MergeOptimizer().optimize(g) MergeOptimizer().optimize(g)
assert str(g) == "FunctionGraph(Op1(*1 -> Op3(Op2(x, y), z), Op4(*1)))" out_var = g.outputs[0]
var_1, var_2 = out_var.owner.inputs
assert var_2.owner.inputs[0] is var_1
def test_no_merge(self): def test_no_merge(self):
x, y, z = inputs() x, y, z = inputs()
e = op1(op3(op2(x, y)), op3(op2(y, x))) e = op1(op3(op2(x, y)), op3(op2(y, x)))
g = FunctionGraph([x, y, z], [e]) g = FunctionGraph([x, y, z], [e])
g.attach_feature(AssertNoChanges())
MergeOptimizer().optimize(g) MergeOptimizer().optimize(g)
assert str(g) == "FunctionGraph(Op1(Op3(Op2(x, y)), Op3(Op2(y, x))))"
def test_merge_outputs(self): def test_merge_outputs(self):
x, y, z = inputs() x, y, z = inputs()
e1 = op3(op2(x, y)) e1 = op3(op2(x, y))
e2 = op3(op2(x, y)) e2 = op3(op2(x, y))
g = FunctionGraph([x, y, z], [e1, e2]) g = FunctionGraph([x, y, z], [e1, e2], clone=False)
MergeOptimizer().optimize(g) MergeOptimizer().optimize(g)
assert str(g) == "FunctionGraph(*1 -> Op3(Op2(x, y)), *1)" assert g.outputs[0] is g.outputs[1]
def test_multiple_merges(self):
x, y, z = inputs()
e1 = op1(x, y)
e2 = op2(op3(x), y, z)
e = op1(e1, op4(e2, e1), op1(e2))
g = FunctionGraph([x, y, z], [e])
MergeOptimizer().optimize(g)
strg = str(g)
# note: graph.as_string can only produce the following two possibilities, but if
# the implementation was to change there are 6 other acceptable answers.
assert (
strg
== "FunctionGraph(Op1(*1 -> Op1(x, y), Op4(*2 -> Op2(Op3(x), y, z), *1), Op1(*2)))"
or strg
== "FunctionGraph(Op1(*2 -> Op1(x, y), Op4(*1 -> Op2(Op3(x), y, z), *2), Op1(*1)))"
)
def test_identical_constant_args(self): def test_identical_constant_args(self):
x = MyVariable("x") x = MyVariable("x")
y = Constant(MyType(), 2, name="y") y = Constant(MyType(), 2, name="y")
z = Constant(MyType(), 2, name="z") z = Constant(MyType(), 2, name="z")
with config.change_flags(compute_test_value="off"):
e1 = op1(y, z) e1 = op1(y, z)
g = FunctionGraph([x, y, z], [e1]) g = FunctionGraph([x, y, z], [e1], clone=False)
MergeOptimizer().optimize(g) MergeOptimizer().optimize(g)
strg = str(g)
assert strg == "FunctionGraph(Op1(y, y))" or strg == "FunctionGraph(Op1(z, z))" assert g.outputs[0].owner.op == op1
input_1 = g.outputs[0].owner.inputs[0]
assert input_1 is g.outputs[0].owner.inputs[1]
@pytest.mark.skip(reason="This was disabled for some unknown reason") @pytest.mark.skip(reason="This was disabled for some unknown reason")
def test_one_assert_merge(self): def test_one_assert_merge(self):
...@@ -336,21 +331,18 @@ class TestMergeOptimizer: ...@@ -336,21 +331,18 @@ class TestMergeOptimizer:
x1 = matrix("x1") x1 = matrix("x1")
x2 = matrix("x2") x2 = matrix("x2")
e = dot(x1, x2) + dot(assert_op(x1, (x1 > x2).all()), x2) e = dot(x1, x2) + dot(assert_op(x1, (x1 > x2).all()), x2)
g = FunctionGraph([x1, x2], [e]) g = FunctionGraph([x1, x2], [e], clone=False)
MergeOptimizer().optimize(g) MergeOptimizer().optimize(g)
strg = aesara.printing.debugprint(g, file="str")
strref = """Elemwise{add,no_inplace} [id A] '' 4 assert g.outputs[0].owner.op == add
|dot [id B] '' 3 add_inputs = g.outputs[0].owner.inputs
| |Assert{msg='Aesara Assert failed!'} [id C] '' 2 assert isinstance(add_inputs[0].owner.op, Dot)
| | |x1 [id D] # Confirm that the `Assert`s are correct
| | |All [id E] '' 1 assert_var = add_inputs[0].owner.inputs[0]
| | |Elemwise{gt,no_inplace} [id F] '' 0 assert_ref = assert_op(x1, (x1 > x2).all())
| | |x1 [id D] assert equal_computations([assert_var], [assert_ref])
| | |x2 [id G] # Confirm the merge
| |x2 [id G] assert add_inputs[0] is add_inputs[1]
|dot [id B] '' 3
"""
assert strg == strref, (strg, strref)
def test_both_assert_merge_identical(self): def test_both_assert_merge_identical(self):
# Merge two nodes, both have assert on the same node # Merge two nodes, both have assert on the same node
...@@ -360,24 +352,20 @@ class TestMergeOptimizer: ...@@ -360,24 +352,20 @@ class TestMergeOptimizer:
e = dot(assert_op(x1, (x1 > x2).all()), x2) + dot( e = dot(assert_op(x1, (x1 > x2).all()), x2) + dot(
assert_op(x1, (x1 > x2).all()), x2 assert_op(x1, (x1 > x2).all()), x2
) )
g = FunctionGraph([x1, x2], [e]) g = FunctionGraph([x1, x2], [e], clone=False)
MergeOptimizer().optimize(g) MergeOptimizer().optimize(g)
strg = aesara.printing.debugprint(g, file="str")
strref = """Elemwise{add,no_inplace} [id A] '' 4
|dot [id B] '' 3
| |Assert{msg='Aesara Assert failed!'} [id C] '' 2
| | |x1 [id D]
| | |All [id E] '' 1
| | |Elemwise{gt,no_inplace} [id F] '' 0
| | |x1 [id D]
| | |x2 [id G]
| |x2 [id G]
|dot [id B] '' 3
"""
# print(strg)
assert strg == strref, (strg, strref)
@pytest.mark.skip(reason="This was disabled for some unknown reason") assert g.outputs[0].owner.op == add
add_inputs = g.outputs[0].owner.inputs
assert isinstance(add_inputs[0].owner.op, Dot)
# Confirm that the `Assert`s are correct
assert_var = add_inputs[0].owner.inputs[0]
assert_ref = assert_op(x1, (x1 > x2).all())
assert equal_computations([assert_var], [assert_ref])
# Confirm the merge
assert add_inputs[0] is add_inputs[1]
@pytest.mark.skip(reason="Advanced `Assert` merging is disabled")
def test_both_assert_merge_1(self): def test_both_assert_merge_1(self):
# Merge two nodes, both have assert on the same node # Merge two nodes, both have assert on the same node
# with different conditions. # with different conditions.
...@@ -387,43 +375,20 @@ class TestMergeOptimizer: ...@@ -387,43 +375,20 @@ class TestMergeOptimizer:
e = dot(assert_op(x1, (x1 > x3).all()), x2) + dot( e = dot(assert_op(x1, (x1 > x3).all()), x2) + dot(
assert_op(x1, (x1 > x2).all()), x2 assert_op(x1, (x1 > x2).all()), x2
) )
g = FunctionGraph([x1, x2, x3], [e]) g = FunctionGraph([x1, x2, x3], [e], clone=False)
MergeOptimizer().optimize(g) MergeOptimizer().optimize(g)
strg = aesara.printing.debugprint(g, file="str")
strref1 = """Elemwise{add,no_inplace} [id A] '' 6
|dot [id B] '' 5
| |Assert{msg='Aesara Assert failed!'} [id C] '' 4
| | |x1 [id D]
| | |All [id E] '' 3
| | | |Elemwise{gt,no_inplace} [id F] '' 1
| | | |x1 [id D]
| | | |x3 [id G]
| | |All [id H] '' 2
| | |Elemwise{gt,no_inplace} [id I] '' 0
| | |x1 [id D]
| | |x2 [id J]
| |x2 [id J]
|dot [id B] '' 5
"""
strref2 = """Elemwise{add,no_inplace} [id A] '' 6
|dot [id B] '' 5
| |Assert{msg='Aesara Assert failed!'} [id C] '' 4
| | |x1 [id D]
| | |All [id E] '' 3
| | | |Elemwise{gt,no_inplace} [id F] '' 1
| | | |x1 [id D]
| | | |x2 [id G]
| | |All [id H] '' 2
| | |Elemwise{gt,no_inplace} [id I] '' 0
| | |x1 [id D]
| | |x3 [id J]
| |x2 [id G]
|dot [id B] '' 5
"""
# print(strg)
assert strg == strref1 or strg == strref2, (strg, strref1, strref2)
@pytest.mark.skip(reason="This was disabled for some unknown reason") assert g.outputs[0].owner.op == add
add_inputs = g.outputs[0].owner.inputs
assert isinstance(add_inputs[0].owner.op, Dot)
# Confirm that the `Assert`s are correct
assert_var = add_inputs[0].owner.inputs[0]
assert_ref = assert_op(x1, (x1 > x3).all(), (x1 > x2).all())
assert equal_computations([assert_var], [assert_ref])
# Confirm the merge
assert add_inputs[0] is add_inputs[1]
@pytest.mark.skip(reason="Advanced `Assert` merging is disabled")
def test_both_assert_merge_2(self): def test_both_assert_merge_2(self):
# Merge two nodes, both have assert on different node # Merge two nodes, both have assert on different node
x1 = matrix("x1") x1 = matrix("x1")
...@@ -432,29 +397,22 @@ class TestMergeOptimizer: ...@@ -432,29 +397,22 @@ class TestMergeOptimizer:
e = dot(assert_op(x1, (x1 > x3).all()), x2) + dot( e = dot(assert_op(x1, (x1 > x3).all()), x2) + dot(
x1, assert_op(x2, (x2 > x3).all()) x1, assert_op(x2, (x2 > x3).all())
) )
g = FunctionGraph([x1, x2, x3], [e]) g = FunctionGraph([x1, x2, x3], [e], clone=False)
MergeOptimizer().optimize(g) MergeOptimizer().optimize(g)
strg = aesara.printing.debugprint(g, file="str")
strref = """Elemwise{add,no_inplace} [id A] '' 7
|dot [id B] '' 6
| |Assert{msg='Aesara Assert failed!'} [id C] '' 5
| | |x1 [id D]
| | |All [id E] '' 3
| | |Elemwise{gt,no_inplace} [id F] '' 1
| | |x1 [id D]
| | |x3 [id G]
| |Assert{msg='Aesara Assert failed!'} [id H] '' 4
| |x2 [id I]
| |All [id J] '' 2
| |Elemwise{gt,no_inplace} [id K] '' 0
| |x2 [id I]
| |x3 [id G]
|dot [id B] '' 6
"""
# print(strg)
assert strg == strref, (strg, strref)
@pytest.mark.skip(reason="This was disabled for some unknown reason") assert g.outputs[0].owner.op == add
add_inputs = g.outputs[0].owner.inputs
assert isinstance(add_inputs[0].owner.op, Dot)
# Confirm that the `Assert`s are correct
assert_var_1, assert_var_2 = add_inputs[0].owner.inputs
assert_ref_1 = assert_op(x1, (x1 > x3).all())
assert equal_computations([assert_var_1], [assert_ref_1])
assert_ref_2 = assert_op(x2, (x2 > x3).all())
assert equal_computations([assert_var_2], [assert_ref_2])
# Confirm the merge
assert add_inputs[0] is add_inputs[1]
@pytest.mark.skip(reason="Advanced `Assert` merging is disabled")
def test_both_assert_merge_2_reverse(self): def test_both_assert_merge_2_reverse(self):
# Test case "test_both_assert_merge_2" but in reverse order # Test case "test_both_assert_merge_2" but in reverse order
x1 = matrix("x1") x1 = matrix("x1")
...@@ -463,27 +421,20 @@ class TestMergeOptimizer: ...@@ -463,27 +421,20 @@ class TestMergeOptimizer:
e = dot(x1, assert_op(x2, (x2 > x3).all())) + dot( e = dot(x1, assert_op(x2, (x2 > x3).all())) + dot(
assert_op(x1, (x1 > x3).all()), x2 assert_op(x1, (x1 > x3).all()), x2
) )
g = FunctionGraph([x1, x2, x3], [e]) g = FunctionGraph([x1, x2, x3], [e], clone=False)
MergeOptimizer().optimize(g) MergeOptimizer().optimize(g)
strg = aesara.printing.debugprint(g, file="str")
strref = """Elemwise{add,no_inplace} [id A] '' 7 assert g.outputs[0].owner.op == add
|dot [id B] '' 6 add_inputs = g.outputs[0].owner.inputs
| |Assert{msg='Aesara Assert failed!'} [id C] '' 5 assert isinstance(add_inputs[0].owner.op, Dot)
| | |x1 [id D] # Confirm that the `Assert`s are correct
| | |All [id E] '' 3 assert_var_1, assert_var_2 = add_inputs[0].owner.inputs
| | |Elemwise{gt,no_inplace} [id F] '' 1 assert_ref_1 = assert_op(x2, (x2 > x3).all())
| | |x1 [id D] assert equal_computations([assert_var_1], [assert_ref_1])
| | |x3 [id G] assert_ref_2 = assert_op(x1, (x1 > x3).all())
| |Assert{msg='Aesara Assert failed!'} [id H] '' 4 assert equal_computations([assert_var_2], [assert_ref_2])
| |x2 [id I] # Confirm the merge
| |All [id J] '' 2 assert add_inputs[0] is add_inputs[1]
| |Elemwise{gt,no_inplace} [id K] '' 0
| |x2 [id I]
| |x3 [id G]
|dot [id B] '' 6
"""
print(strg)
assert strg == strref, (strg, strref)
def test_merge_noinput(self): def test_merge_noinput(self):
# Check that identical Apply nodes without inputs will be merged # Check that identical Apply nodes without inputs will be merged
...@@ -491,10 +442,11 @@ class TestMergeOptimizer: ...@@ -491,10 +442,11 @@ class TestMergeOptimizer:
y = NoInputOp(param=0)() y = NoInputOp(param=0)()
z = NoInputOp(param=1)() z = NoInputOp(param=1)()
fg = FunctionGraph([], [x, y, z]) fg = FunctionGraph([], [x, y, z], clone=False)
MergeOptimizer().optimize(fg) MergeOptimizer().optimize(fg)
no_input_ops = [n for n in fg.apply_nodes if isinstance(n.op, NoInputOp)]
assert len(no_input_ops) == 2, fg.apply_nodes assert fg.outputs[0] is fg.outputs[1]
assert fg.outputs[0] is not fg.outputs[2]
class TestEquilibrium: class TestEquilibrium:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论