提交 d26dea02 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Refactor test_local_zero_div

* Use `pytest.mark.parameterize` * Remove reliance on `get_scalar_constant_value`
上级 c832967e
......@@ -3885,28 +3885,32 @@ class TestIntDivByOne:
assert len(divs) == 0
def test_local_zero_div():
# Tests 0/x -> 0
for t in (scalar, ivector, ftensor4):
x = t("x")
for op in (int_div, true_div):
y = op(0, x)
g = optimize(FunctionGraph([x], [y]))
# the division should be gone
divs = [
node
for node in g.toposort()
if isinstance(node.op, Elemwise)
and isinstance(node.op.scalar_op, type(op.scalar_op))
]
assert len(divs) == 0
# the output type should match the unoptimized one
output = g.outputs[0]
assert output.ndim == y.ndim
assert output.type == y.type
# and the output should be zero
assert aet.get_scalar_constant_value(output) == 0
@pytest.mark.parametrize("t", [scalar, ivector, ftensor4])
@pytest.mark.parametrize("op", [int_div, true_div])
def test_local_zero_div(t, op):
"""Test the canonicalization ``0/x -> 0``."""
x = t("x")
y = op(0, x)
g = optimize(FunctionGraph([x], [y]))
# the division should be gone
divs = [
node
for node in g.toposort()
if isinstance(node.op, Elemwise)
and isinstance(node.op.scalar_op, type(op.scalar_op))
]
assert len(divs) == 0
# the output type should match the unoptimized one
output = g.outputs[0]
assert output.ndim == y.ndim
assert output.type == y.type
# and the output should be zero
if output.owner and isinstance(output.owner.op, Alloc):
out_var = output.owner.inputs[0]
else:
out_var = output
assert out_var.data == 0
def test_local_sumsqr2dot():
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论