提交 033c51c1 authored 作者: Brandon T. Willard's avatar Brandon T. Willard

Add a real test for pre_greedy_local_optimizer

These tests actually confirm that `pre_greedy_local_optimizer` traverses a graph and applies optimizations, and that it avoids terms in a given `FunctionGraph`.
上级 b58fa026
...@@ -587,9 +587,44 @@ def test_pre_constant_merge(): ...@@ -587,9 +587,44 @@ def test_pre_constant_merge():
def test_pre_greedy_local_optimizer(): def test_pre_greedy_local_optimizer():
empty_fgraph = FunctionGraph([], []) empty_fgraph = FunctionGraph([], [])
x = MyVariable("x")
y = MyVariable("y")
c1 = Constant(MyType(), 1, "c1")
c2 = Constant(MyType(), 2, "c2")
o1 = op2(c1, c2)
o3 = op1(c1, y)
o2 = op1(o1, c2, x, o3, o1)
assert o2.owner.inputs[0].owner is not None
assert o2.owner.inputs[4].owner is not None
# This should fold `o1`, because it has only `Constant` arguments, and
# replace it with the `Constant` result
cst = pre_greedy_local_optimizer(empty_fgraph, [constant_folding], o2)
assert cst.owner.inputs[0].owner is None
assert cst.owner.inputs[1] is c2
assert cst.owner.inputs[2] is x
assert cst.owner.inputs[3] is o3
assert cst.owner.inputs[4] is cst.owner.inputs[0]
# We're going to do it again, except this time `o1` is
# in the `fgraph`, so it shouldn't be folded
fg = FunctionGraph([], [o1], clone=False)
o2 = op1(o1, c2, x, o3, o1)
cst = pre_greedy_local_optimizer(fg, [constant_folding], o2)
assert cst.owner.inputs[0] is o1
assert cst.owner.inputs[4] is cst.owner.inputs[0]
# What exactly is this supposed to test?
ms = MakeSlice()(1) ms = MakeSlice()(1)
cst = pre_greedy_local_optimizer(empty_fgraph, [constant_folding], ms) cst = pre_greedy_local_optimizer(empty_fgraph, [constant_folding], ms)
assert isinstance(cst, SliceConstant) assert isinstance(cst, SliceConstant)
# Make sure constant of slice signature is hashable. # Make sure constant of slice signature is hashable.
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论