提交 d26dea02 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Refactor test_local_zero_div

* Use `pytest.mark.parameterize` * Remove reliance on `get_scalar_constant_value`
上级 c832967e
...@@ -3885,12 +3885,11 @@ class TestIntDivByOne: ...@@ -3885,12 +3885,11 @@ class TestIntDivByOne:
assert len(divs) == 0 assert len(divs) == 0
def test_local_zero_div(): @pytest.mark.parametrize("t", [scalar, ivector, ftensor4])
# Tests 0/x -> 0 @pytest.mark.parametrize("op", [int_div, true_div])
def test_local_zero_div(t, op):
for t in (scalar, ivector, ftensor4): """Test the canonicalization ``0/x -> 0``."""
x = t("x") x = t("x")
for op in (int_div, true_div):
y = op(0, x) y = op(0, x)
g = optimize(FunctionGraph([x], [y])) g = optimize(FunctionGraph([x], [y]))
# the division should be gone # the division should be gone
...@@ -3906,7 +3905,12 @@ def test_local_zero_div(): ...@@ -3906,7 +3905,12 @@ def test_local_zero_div():
assert output.ndim == y.ndim assert output.ndim == y.ndim
assert output.type == y.type assert output.type == y.type
# and the output should be zero # and the output should be zero
assert aet.get_scalar_constant_value(output) == 0 if output.owner and isinstance(output.owner.op, Alloc):
out_var = output.owner.inputs[0]
else:
out_var = output
assert out_var.data == 0
def test_local_sumsqr2dot(): def test_local_sumsqr2dot():
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论