提交 a91644f3 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Remove graph string checking from TestLocalCanonicalizeAlloc.test_useless_alloc_with_shape_one

上级 5e5429e5
......@@ -1438,57 +1438,26 @@ class TestLocalCanonicalizeAlloc:
[isinstance(node.op, Alloc) for node in f.maker.fgraph.toposort()]
)
def test_useless_alloc_with_shape_one(self):
"""
TODO FIXME: Remove/replace the string output comparisons.
"""
alloc_lift = out2in(local_canonicalize_alloc)
x = shared(self.rng.standard_normal((2,)))
y = shared(self.rng.standard_normal())
z = shared(self.rng.standard_normal((1, 1)))
w = shared(self.rng.standard_normal((1, 1)))
alloc_x = at.alloc(x, 1, 3, 2)
alloc_y = at.alloc(y, 1, 1)
alloc_z = at.alloc(z, 1, 1, 2)
alloc_w = at.alloc(w, 1, 2)
g = FunctionGraph([x, y, z, w], [alloc_x, alloc_y, alloc_z, alloc_w])
assert str(g) == (
"FunctionGraph(Alloc(<TensorType(float64, vector)>, "
"TensorConstant{1}, "
"TensorConstant{3}, "
"TensorConstant{2}), "
"Alloc(<TensorType(float64, scalar)>, "
"TensorConstant{1}, "
"TensorConstant{1}), "
"Alloc(<TensorType(float64, matrix)>, "
"TensorConstant{1}, "
"TensorConstant{1}, "
"TensorConstant{2}), "
"Alloc(<TensorType(float64, matrix)>, "
"TensorConstant{1}, "
"TensorConstant{2}))"
)
@pytest.mark.parametrize(
"x, has_alloc",
[
(at.alloc(np.ones((2,)), 1, 3, 2), True),
(at.alloc(np.array(1.0), 1, 1), False),
(at.alloc(np.ones((1, 1)), 1, 1, 2), True),
(at.alloc(np.ones((1, 1)), 1, 2), True),
],
)
def test_useless_alloc_with_shape_one(self, x, has_alloc):
g = FunctionGraph(outputs=[x])
assert any(isinstance(node.op, Alloc) for node in g.toposort())
alloc_lift = out2in(local_canonicalize_alloc)
alloc_lift.optimize(g)
assert str(g) == (
"FunctionGraph(InplaceDimShuffle{x,0,1}"
"(Alloc(<TensorType(float64, vector)>, "
"TensorConstant{3}, "
"TensorConstant{2})), "
"InplaceDimShuffle{x,x}"
"(<TensorType(float64, scalar)>), "
"InplaceDimShuffle{x,0,1}"
"(Alloc(<TensorType(float64, matrix)>, "
"TensorConstant{1}, "
"TensorConstant{2})), "
"Alloc(<TensorType(float64, matrix)>, "
"TensorConstant{1}, "
"TensorConstant{2}))"
)
# Check stacktrace was copied over correctly after opt was applied
assert check_stack_trace(g, ops_to_check="all")
if has_alloc:
assert any(isinstance(node.op, Alloc) for node in g.toposort())
else:
assert not any(isinstance(node.op, Alloc) for node in g.toposort())
class TestLocalUselessIncSubtensorAlloc:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论