提交 b58fa026 authored 作者: Brandon T. Willard's avatar Brandon T. Willard

Add real tests for pre_constant_merge

上级 4afee7a7
......@@ -542,16 +542,55 @@ class TestEquilibrium:
assert str(g) == "FunctionGraph(Op1(x, y))"
def test_pre_constant_merge_slice():
def test_pre_constant_merge():
empty_fgraph = FunctionGraph([], [])
x = MyVariable("x")
y = MyVariable("y")
c1 = Constant(MyType(), 1, "c1")
c2 = Constant(MyType(), 1, "c1")
o1 = op2(c1, x)
o2 = op1(o1, y, c2)
assert c1 is not c2
res = pre_constant_merge(empty_fgraph, [o2])
assert [o2] == res
assert o2.owner.inputs[2] is c1
o2 = op1(o1, y, c2)
fg = FunctionGraph([x, y], [o2], clone=False)
assert o2.owner in fg.apply_nodes
res = pre_constant_merge(fg, [o2])
assert res == [o2]
assert o2.owner.inputs[2] is c2
# What is this supposed to test?
ms = MakeSlice()(1)
pre_constant_merge([ms])
res = pre_constant_merge(empty_fgraph, [ms])
assert res == [ms]
const_slice = SliceConstant(type=slicetype, data=slice(1, None, 2))
assert isinstance(const_slice, Constant)
adv = AdvancedSubtensor()(tt.matrix(), [2, 3], const_slice)
pre_constant_merge(adv)
fgraph = FunctionGraph([], [])
cst = pre_greedy_local_optimizer(fgraph, [constant_folding], ms)
res = pre_constant_merge(empty_fgraph, adv)
assert res == [adv]
def test_pre_greedy_local_optimizer():
empty_fgraph = FunctionGraph([], [])
ms = MakeSlice()(1)
cst = pre_greedy_local_optimizer(empty_fgraph, [constant_folding], ms)
assert isinstance(cst, SliceConstant)
# Make sure constant of slice signature is hashable.
hash(cst.signature())
assert isinstance(hash(cst.signature()), int)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论