提交 0bfc64e2 authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Merge pull request #4216 from f0k/opt-zero-div

Add graph optimization 0/x -> 0
......@@ -5269,6 +5269,17 @@ def local_intdiv_by_one(node):
return [node.inputs[0].astype(node.outputs[0].dtype)]
@register_canonicalize
@gof.local_optimizer([T.int_div, T.true_div])
def local_zero_div(node):
"""0 / x -> 0
"""
if isinstance(node.op, T.Elemwise) and isinstance(
node.op.scalar_op, (theano.scalar.IntDiv, theano.scalar.TrueDiv)):
if local_mul_canonizer.get_constant(node.inputs[0]) == 0:
return [broadcast_like(0, node.outputs[0], node.fgraph)]
@gof.local_optimizer([T.pow])
def local_pow_specialize(node):
# here, we are past the point of canonicalization, so we don't want
......
......@@ -6123,6 +6123,27 @@ class TestIntDivByOne(unittest.TestCase):
assert len(divs) == 0
def test_local_zero_div():
"""Tests 0/x -> 0"""
mode = theano.compile.mode.get_default_mode().including("local_zero_div")
for t in (T.scalar, T.ivector, T.ftensor4):
x = t('x')
for op in (T.int_div, T.true_div):
y = op(0, x)
g = optimize(FunctionGraph([x], [y]))
# the division should be gone
divs = [node for node in g.toposort()
if isinstance(node.op, T.elemwise.Elemwise) and
isinstance(node.op.scalar_op, type(op.scalar_op))]
assert len(divs) == 0
# the output type should match the unoptimized one
output = g.outputs[0]
assert output.ndim == y.ndim
assert output.type == y.type
# and the output should be zero
assert theano.tensor.get_scalar_constant_value(output) == 0
def test_local_expm1():
x = matrix('x')
u = T.scalar('u')
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论