提交 d26dea02 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Refactor test_local_zero_div

* Use `pytest.mark.parameterize` * Remove reliance on `get_scalar_constant_value`
上级 c832967e
...@@ -3885,28 +3885,32 @@ class TestIntDivByOne: ...@@ -3885,28 +3885,32 @@ class TestIntDivByOne:
assert len(divs) == 0 assert len(divs) == 0
def test_local_zero_div(): @pytest.mark.parametrize("t", [scalar, ivector, ftensor4])
# Tests 0/x -> 0 @pytest.mark.parametrize("op", [int_div, true_div])
def test_local_zero_div(t, op):
for t in (scalar, ivector, ftensor4): """Test the canonicalization ``0/x -> 0``."""
x = t("x") x = t("x")
for op in (int_div, true_div): y = op(0, x)
y = op(0, x) g = optimize(FunctionGraph([x], [y]))
g = optimize(FunctionGraph([x], [y])) # the division should be gone
# the division should be gone divs = [
divs = [ node
node for node in g.toposort()
for node in g.toposort() if isinstance(node.op, Elemwise)
if isinstance(node.op, Elemwise) and isinstance(node.op.scalar_op, type(op.scalar_op))
and isinstance(node.op.scalar_op, type(op.scalar_op)) ]
] assert len(divs) == 0
assert len(divs) == 0 # the output type should match the unoptimized one
# the output type should match the unoptimized one output = g.outputs[0]
output = g.outputs[0] assert output.ndim == y.ndim
assert output.ndim == y.ndim assert output.type == y.type
assert output.type == y.type # and the output should be zero
# and the output should be zero if output.owner and isinstance(output.owner.op, Alloc):
assert aet.get_scalar_constant_value(output) == 0 out_var = output.owner.inputs[0]
else:
out_var = output
assert out_var.data == 0
def test_local_sumsqr2dot(): def test_local_sumsqr2dot():
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论