提交 3e574bba authored 作者: Frederic Bastien's avatar Frederic Bastien

Add local_useless_inc_subtensor.

Remove the IncSubtensor when we can infer all the shapes are the same.
上级 09a4d5ed
@@ -1688,6 +1688,48 @@ def local_upcast_elemwise_constant_inputs(node):
################## ##################
@register_canonicalize
@register_specialize
@gof.local_optimizer([IncSubtensor])
def local_useless_inc_subtensor(node):
    """Remove an IncSubtensor that overwrites its whole input with the
    new value, returning the new value directly.
    """
    if not isinstance(node.op, IncSubtensor):
        return
    if node.op.set_instead_of_inc is False:
        # A real increment is only a plain overwrite when the
        # destination is constant zero everywhere.
        try:
            if get_scalar_constant_value(node.inputs[0]) != 0:
                return
        except NotScalarConstantError:
            return
    dest, new_val = node.inputs[0], node.inputs[1]
    if (dest.ndim != new_val.ndim or
            dest.broadcastable != new_val.broadcastable):
        # NOTE(review): unclear whether mismatched ndim/broadcastable
        # patterns can occur here in practice; this opt does not
        # support them, so bail out.
        return
    # At this point we have a SetSubtensor, or an IncSubtensor on
    # zeros.  It is useless only if every index keeps the full axis.
    for entry in node.op.idx_list:
        if not (isinstance(entry, slice) and entry.start is None and
                entry.stop is None and entry.step is None):
            return
    assert len(node.inputs) == 2
    # IncSubtensor broadcasts new_val over dest based on run-time
    # shapes, so the rewrite is only valid when static shape analysis
    # proves the two shapes equal.
    if not hasattr(node.fgraph, 'shape_feature'):
        return
    if not node.fgraph.shape_feature.same_shape(dest, new_val):
        return
    # Same shape: a full-slice write is just the new value.
    return new_val
@register_canonicalize
@register_specialize
@gof.local_optimizer([Subtensor])
......
@@ -1571,6 +1571,41 @@ def test_log_add():
    #TODO: (write and) test that the optimization works with Sum in addition to working with Add.
def test_local_useless_inc_subtensor():
    """Check local_useless_inc_subtensor: a full-slice set_subtensor is
    replaced by its new value when the shapes are provably equal, and
    is kept (with its run-time checks) otherwise.
    """
    x = tensor.matrix('x')
    y = tensor.matrix('y')
    o = tensor.set_subtensor(x[::, ::], y)
    o_shape = tensor.set_subtensor(x[::, ::],
                                   tensor.specify_shape(y, x.shape))
    f_shape = theano.function([x, y], o_shape)
    f = theano.function([x, y], o)
    # With shape info the IncSubtensor must be optimized away.
    topo = f_shape.maker.fgraph.toposort()
    assert len(topo) == 5, topo
    assert not isinstance(topo[-1].op, tensor.IncSubtensor)
    out = f_shape([[2, 3]], [[3, 4]])
    assert (out == [[3, 4]]).all()
    # Without shape info, we must not apply the opt.
    topo = f.maker.fgraph.toposort()
    assert len(topo) == 1
    assert isinstance(topo[0].op, tensor.IncSubtensor)
    out = f([[2, 3]], [[3, 4]])
    assert (out == [[3, 4]]).all()
    # A run-time shape mismatch must still raise.
    # BUG FIX: the previous `try: f(...); assert False;
    # except (ValueError, AssertionError)` swallowed its own
    # `assert False` (AssertionError was in the caught tuple), so this
    # check could never fail.  Use a flag instead.
    raised = False
    try:
        f([[2, 3]], [[3, 4], [4, 5]])
    except (ValueError, AssertionError):
        raised = True
    assert raised, "expected a shape error from the kept IncSubtensor"
    # Broadcasting of the new value must be preserved.
    out = f([[2, 3], [3, 4]], [[5, 6]])
    assert (out == [[5, 6], [5, 6]]).all()
def test_local_useless_subtensor():
    x = tensor.matrix('x')
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论