提交 ae58bdf8 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Refactor TestMergeOptimizer

These changes remove the tests' reliance on string output and make the expected results much more explicit.
上级 3e7f00e5
......@@ -2,7 +2,8 @@ import pytest
from aesara.assert_op import assert_op
from aesara.configdefaults import config
from aesara.graph.basic import Apply, Constant
from aesara.graph.basic import Apply, Constant, equal_computations
from aesara.graph.features import Feature
from aesara.graph.fg import FunctionGraph
from aesara.graph.op import Op
from aesara.graph.opt import (
......@@ -13,14 +14,13 @@ from aesara.graph.opt import (
OpSub,
PatternSub,
TopoOptimizer,
aesara,
local_optimizer,
logging,
pre_constant_merge,
pre_greedy_local_optimizer,
)
from aesara.tensor.basic_opt import constant_folding
from aesara.tensor.math import dot
from aesara.tensor.math import Dot, add, dot
from aesara.tensor.subtensor import AdvancedSubtensor
from aesara.tensor.type import matrix, values_eq_approx_always_true
from aesara.tensor.type_other import MakeSlice, SliceConstant, slicetype
......@@ -39,6 +39,13 @@ from tests.graph.utils import (
)
class AssertNoChanges(Feature):
"""A `Feature` that raises an error when nodes are changed in a graph."""
def on_change_input(self, fgraph, node, i, r, new_r, reason=None):
raise AssertionError()
def inputs():
x = MyVariable("x")
y = MyVariable("y")
......@@ -263,72 +270,60 @@ class TestMergeOptimizer:
def test_straightforward(self):
x, y, z = inputs()
e = op1(op2(x, y), op2(x, y), op2(x, z))
g = FunctionGraph([x, y, z], [e])
g = FunctionGraph([x, y, z], [e], clone=False)
MergeOptimizer().optimize(g)
assert str(g) == "FunctionGraph(Op1(*1 -> Op2(x, y), *1, Op2(x, z)))"
out_var = g.outputs[0]
var_1, var_2, var_3 = out_var.owner.inputs
assert var_1 is var_2
assert var_1 is not var_3
def test_constant_merging(self):
x = MyVariable("x")
y = Constant(MyType(), 2, name="y")
z = Constant(MyType(), 2, name="z")
e = op1(op2(x, y), op2(x, y), op2(x, z))
g = FunctionGraph([x, y, z], [e])
g = FunctionGraph([x, y, z], [e], clone=False)
MergeOptimizer().optimize(g)
strg = str(g)
assert (
strg == "FunctionGraph(Op1(*1 -> Op2(x, y), *1, *1))"
or strg == "FunctionGraph(Op1(*1 -> Op2(x, z), *1, *1))"
)
out_var = g.outputs[0]
var_1, var_2, var_3 = out_var.owner.inputs
assert var_1 is var_2
assert var_2 is var_3
def test_deep_merge(self):
x, y, z = inputs()
e = op1(op3(op2(x, y), z), op4(op3(op2(x, y), z)))
g = FunctionGraph([x, y, z], [e])
g = FunctionGraph([x, y, z], [e], clone=False)
MergeOptimizer().optimize(g)
assert str(g) == "FunctionGraph(Op1(*1 -> Op3(Op2(x, y), z), Op4(*1)))"
out_var = g.outputs[0]
var_1, var_2 = out_var.owner.inputs
assert var_2.owner.inputs[0] is var_1
def test_no_merge(self):
x, y, z = inputs()
e = op1(op3(op2(x, y)), op3(op2(y, x)))
g = FunctionGraph([x, y, z], [e])
g.attach_feature(AssertNoChanges())
MergeOptimizer().optimize(g)
assert str(g) == "FunctionGraph(Op1(Op3(Op2(x, y)), Op3(Op2(y, x))))"
def test_merge_outputs(self):
x, y, z = inputs()
e1 = op3(op2(x, y))
e2 = op3(op2(x, y))
g = FunctionGraph([x, y, z], [e1, e2])
g = FunctionGraph([x, y, z], [e1, e2], clone=False)
MergeOptimizer().optimize(g)
assert str(g) == "FunctionGraph(*1 -> Op3(Op2(x, y)), *1)"
def test_multiple_merges(self):
x, y, z = inputs()
e1 = op1(x, y)
e2 = op2(op3(x), y, z)
e = op1(e1, op4(e2, e1), op1(e2))
g = FunctionGraph([x, y, z], [e])
MergeOptimizer().optimize(g)
strg = str(g)
# note: graph.as_string can only produce the following two possibilities, but if
# the implementation was to change there are 6 other acceptable answers.
assert (
strg
== "FunctionGraph(Op1(*1 -> Op1(x, y), Op4(*2 -> Op2(Op3(x), y, z), *1), Op1(*2)))"
or strg
== "FunctionGraph(Op1(*2 -> Op1(x, y), Op4(*1 -> Op2(Op3(x), y, z), *2), Op1(*1)))"
)
assert g.outputs[0] is g.outputs[1]
def test_identical_constant_args(self):
x = MyVariable("x")
y = Constant(MyType(), 2, name="y")
z = Constant(MyType(), 2, name="z")
with config.change_flags(compute_test_value="off"):
e1 = op1(y, z)
g = FunctionGraph([x, y, z], [e1])
e1 = op1(y, z)
g = FunctionGraph([x, y, z], [e1], clone=False)
MergeOptimizer().optimize(g)
strg = str(g)
assert strg == "FunctionGraph(Op1(y, y))" or strg == "FunctionGraph(Op1(z, z))"
assert g.outputs[0].owner.op == op1
input_1 = g.outputs[0].owner.inputs[0]
assert input_1 is g.outputs[0].owner.inputs[1]
@pytest.mark.skip(reason="This was disabled for some unknown reason")
def test_one_assert_merge(self):
......@@ -336,21 +331,18 @@ class TestMergeOptimizer:
x1 = matrix("x1")
x2 = matrix("x2")
e = dot(x1, x2) + dot(assert_op(x1, (x1 > x2).all()), x2)
g = FunctionGraph([x1, x2], [e])
g = FunctionGraph([x1, x2], [e], clone=False)
MergeOptimizer().optimize(g)
strg = aesara.printing.debugprint(g, file="str")
strref = """Elemwise{add,no_inplace} [id A] '' 4
|dot [id B] '' 3
| |Assert{msg='Aesara Assert failed!'} [id C] '' 2
| | |x1 [id D]
| | |All [id E] '' 1
| | |Elemwise{gt,no_inplace} [id F] '' 0
| | |x1 [id D]
| | |x2 [id G]
| |x2 [id G]
|dot [id B] '' 3
"""
assert strg == strref, (strg, strref)
assert g.outputs[0].owner.op == add
add_inputs = g.outputs[0].owner.inputs
assert isinstance(add_inputs[0].owner.op, Dot)
# Confirm that the `Assert`s are correct
assert_var = add_inputs[0].owner.inputs[0]
assert_ref = assert_op(x1, (x1 > x2).all())
assert equal_computations([assert_var], [assert_ref])
# Confirm the merge
assert add_inputs[0] is add_inputs[1]
def test_both_assert_merge_identical(self):
# Merge two nodes, both have assert on the same node
......@@ -360,24 +352,20 @@ class TestMergeOptimizer:
e = dot(assert_op(x1, (x1 > x2).all()), x2) + dot(
assert_op(x1, (x1 > x2).all()), x2
)
g = FunctionGraph([x1, x2], [e])
g = FunctionGraph([x1, x2], [e], clone=False)
MergeOptimizer().optimize(g)
strg = aesara.printing.debugprint(g, file="str")
strref = """Elemwise{add,no_inplace} [id A] '' 4
|dot [id B] '' 3
| |Assert{msg='Aesara Assert failed!'} [id C] '' 2
| | |x1 [id D]
| | |All [id E] '' 1
| | |Elemwise{gt,no_inplace} [id F] '' 0
| | |x1 [id D]
| | |x2 [id G]
| |x2 [id G]
|dot [id B] '' 3
"""
# print(strg)
assert strg == strref, (strg, strref)
@pytest.mark.skip(reason="This was disabled for some unknown reason")
assert g.outputs[0].owner.op == add
add_inputs = g.outputs[0].owner.inputs
assert isinstance(add_inputs[0].owner.op, Dot)
# Confirm that the `Assert`s are correct
assert_var = add_inputs[0].owner.inputs[0]
assert_ref = assert_op(x1, (x1 > x2).all())
assert equal_computations([assert_var], [assert_ref])
# Confirm the merge
assert add_inputs[0] is add_inputs[1]
@pytest.mark.skip(reason="Advanced `Assert` merging is disabled")
def test_both_assert_merge_1(self):
# Merge two nodes, both have assert on the same node
# with different conditions.
......@@ -387,43 +375,20 @@ class TestMergeOptimizer:
e = dot(assert_op(x1, (x1 > x3).all()), x2) + dot(
assert_op(x1, (x1 > x2).all()), x2
)
g = FunctionGraph([x1, x2, x3], [e])
g = FunctionGraph([x1, x2, x3], [e], clone=False)
MergeOptimizer().optimize(g)
strg = aesara.printing.debugprint(g, file="str")
strref1 = """Elemwise{add,no_inplace} [id A] '' 6
|dot [id B] '' 5
| |Assert{msg='Aesara Assert failed!'} [id C] '' 4
| | |x1 [id D]
| | |All [id E] '' 3
| | | |Elemwise{gt,no_inplace} [id F] '' 1
| | | |x1 [id D]
| | | |x3 [id G]
| | |All [id H] '' 2
| | |Elemwise{gt,no_inplace} [id I] '' 0
| | |x1 [id D]
| | |x2 [id J]
| |x2 [id J]
|dot [id B] '' 5
"""
strref2 = """Elemwise{add,no_inplace} [id A] '' 6
|dot [id B] '' 5
| |Assert{msg='Aesara Assert failed!'} [id C] '' 4
| | |x1 [id D]
| | |All [id E] '' 3
| | | |Elemwise{gt,no_inplace} [id F] '' 1
| | | |x1 [id D]
| | | |x2 [id G]
| | |All [id H] '' 2
| | |Elemwise{gt,no_inplace} [id I] '' 0
| | |x1 [id D]
| | |x3 [id J]
| |x2 [id G]
|dot [id B] '' 5
"""
# print(strg)
assert strg == strref1 or strg == strref2, (strg, strref1, strref2)
@pytest.mark.skip(reason="This was disabled for some unknown reason")
assert g.outputs[0].owner.op == add
add_inputs = g.outputs[0].owner.inputs
assert isinstance(add_inputs[0].owner.op, Dot)
# Confirm that the `Assert`s are correct
assert_var = add_inputs[0].owner.inputs[0]
assert_ref = assert_op(x1, (x1 > x3).all(), (x1 > x2).all())
assert equal_computations([assert_var], [assert_ref])
# Confirm the merge
assert add_inputs[0] is add_inputs[1]
@pytest.mark.skip(reason="Advanced `Assert` merging is disabled")
def test_both_assert_merge_2(self):
# Merge two nodes, both have assert on different node
x1 = matrix("x1")
......@@ -432,29 +397,22 @@ class TestMergeOptimizer:
e = dot(assert_op(x1, (x1 > x3).all()), x2) + dot(
x1, assert_op(x2, (x2 > x3).all())
)
g = FunctionGraph([x1, x2, x3], [e])
g = FunctionGraph([x1, x2, x3], [e], clone=False)
MergeOptimizer().optimize(g)
strg = aesara.printing.debugprint(g, file="str")
strref = """Elemwise{add,no_inplace} [id A] '' 7
|dot [id B] '' 6
| |Assert{msg='Aesara Assert failed!'} [id C] '' 5
| | |x1 [id D]
| | |All [id E] '' 3
| | |Elemwise{gt,no_inplace} [id F] '' 1
| | |x1 [id D]
| | |x3 [id G]
| |Assert{msg='Aesara Assert failed!'} [id H] '' 4
| |x2 [id I]
| |All [id J] '' 2
| |Elemwise{gt,no_inplace} [id K] '' 0
| |x2 [id I]
| |x3 [id G]
|dot [id B] '' 6
"""
# print(strg)
assert strg == strref, (strg, strref)
@pytest.mark.skip(reason="This was disabled for some unknown reason")
assert g.outputs[0].owner.op == add
add_inputs = g.outputs[0].owner.inputs
assert isinstance(add_inputs[0].owner.op, Dot)
# Confirm that the `Assert`s are correct
assert_var_1, assert_var_2 = add_inputs[0].owner.inputs
assert_ref_1 = assert_op(x1, (x1 > x3).all())
assert equal_computations([assert_var_1], [assert_ref_1])
assert_ref_2 = assert_op(x2, (x2 > x3).all())
assert equal_computations([assert_var_2], [assert_ref_2])
# Confirm the merge
assert add_inputs[0] is add_inputs[1]
@pytest.mark.skip(reason="Advanced `Assert` merging is disabled")
def test_both_assert_merge_2_reverse(self):
# Test case "test_both_assert_merge_2" but in reverse order
x1 = matrix("x1")
......@@ -463,27 +421,20 @@ class TestMergeOptimizer:
e = dot(x1, assert_op(x2, (x2 > x3).all())) + dot(
assert_op(x1, (x1 > x3).all()), x2
)
g = FunctionGraph([x1, x2, x3], [e])
g = FunctionGraph([x1, x2, x3], [e], clone=False)
MergeOptimizer().optimize(g)
strg = aesara.printing.debugprint(g, file="str")
strref = """Elemwise{add,no_inplace} [id A] '' 7
|dot [id B] '' 6
| |Assert{msg='Aesara Assert failed!'} [id C] '' 5
| | |x1 [id D]
| | |All [id E] '' 3
| | |Elemwise{gt,no_inplace} [id F] '' 1
| | |x1 [id D]
| | |x3 [id G]
| |Assert{msg='Aesara Assert failed!'} [id H] '' 4
| |x2 [id I]
| |All [id J] '' 2
| |Elemwise{gt,no_inplace} [id K] '' 0
| |x2 [id I]
| |x3 [id G]
|dot [id B] '' 6
"""
print(strg)
assert strg == strref, (strg, strref)
assert g.outputs[0].owner.op == add
add_inputs = g.outputs[0].owner.inputs
assert isinstance(add_inputs[0].owner.op, Dot)
# Confirm that the `Assert`s are correct
assert_var_1, assert_var_2 = add_inputs[0].owner.inputs
assert_ref_1 = assert_op(x2, (x2 > x3).all())
assert equal_computations([assert_var_1], [assert_ref_1])
assert_ref_2 = assert_op(x1, (x1 > x3).all())
assert equal_computations([assert_var_2], [assert_ref_2])
# Confirm the merge
assert add_inputs[0] is add_inputs[1]
def test_merge_noinput(self):
# Check that identical Apply nodes without inputs will be merged
......@@ -491,10 +442,11 @@ class TestMergeOptimizer:
y = NoInputOp(param=0)()
z = NoInputOp(param=1)()
fg = FunctionGraph([], [x, y, z])
fg = FunctionGraph([], [x, y, z], clone=False)
MergeOptimizer().optimize(fg)
no_input_ops = [n for n in fg.apply_nodes if isinstance(n.op, NoInputOp)]
assert len(no_input_ops) == 2, fg.apply_nodes
assert fg.outputs[0] is fg.outputs[1]
assert fg.outputs[0] is not fg.outputs[2]
class TestEquilibrium:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论