提交 7700bfbc authored 作者: Frederic Bastien

Make opt local_useless_subtensor convert to subsample when we reverse.

上级 3e574bba
...@@ -1715,10 +1715,12 @@ def local_useless_inc_subtensor(node): ...@@ -1715,10 +1715,12 @@ def local_useless_inc_subtensor(node):
# If is this IncSubtensor useful? # If is this IncSubtensor useful?
# Check that we keep all the original data. # Check that we keep all the original data.
# Put the constant inputs in the slice.
idx_cst = theano.tensor.subtensor.get_idx_list(node.inputs[1:],
node.op.idx_list)
if all(isinstance(e, slice) and e.start is None and if all(isinstance(e, slice) and e.start is None and
e.stop is None and e.step is None e.stop is None and (e.step is None or T.extract_constant(e.step) == -1)
for e in node.op.idx_list): for e in idx_cst):
assert len(node.inputs) == 2
# IncSubtensor broadcast node.inputs[1] on node.inputs[0] # IncSubtensor broadcast node.inputs[1] on node.inputs[0]
# based on run time shapes, so we must check they are the same. # based on run time shapes, so we must check they are the same.
if not hasattr(node.fgraph, 'shape_feature'): if not hasattr(node.fgraph, 'shape_feature'):
...@@ -1726,8 +1728,12 @@ def local_useless_inc_subtensor(node): ...@@ -1726,8 +1728,12 @@ def local_useless_inc_subtensor(node):
if not node.fgraph.shape_feature.same_shape(node.inputs[0], if not node.fgraph.shape_feature.same_shape(node.inputs[0],
node.inputs[1]): node.inputs[1]):
return return
# They are the same shape, so we can remore this IncSubtensor # There is no reverse, so we don't need a replacement.
return node.inputs[1] if all(e.step is None
for e in node.op.idx_list):
# They are the same shape, so we can remore this IncSubtensor
return [node.inputs[1]]
return [Subtensor(node.op.idx_list)(*node.inputs[1:])]
@register_canonicalize @register_canonicalize
......
...@@ -1574,36 +1574,48 @@ def test_log_add(): ...@@ -1574,36 +1574,48 @@ def test_log_add():
def test_local_useless_inc_subtensor():
    """Check the `local_useless_inc_subtensor` optimization.

    The opt should fire only when shape information proves that the
    `set_subtensor` overwrites the whole input:
      * full slice (``::``) -> replaced by the assigned value directly;
      * full reversed slice (``::-1``) -> replaced by a `Subtensor`
        (subsample) of the assigned value;
      * any other stride (e.g. ``::2``) or missing shape info -> the
        `IncSubtensor` must be kept.
    """
    x = tensor.matrix('x')
    y = tensor.matrix('y')
    # Cover both the plain full slice and the reversed full slice.
    for sub in [slice(None), slice(None, None, -1)]:
        o = tensor.set_subtensor(x[::, sub], y)
        f = theano.function([x, y], o)
        o_shape = tensor.set_subtensor(x[::, sub],
                                       tensor.specify_shape(y, x.shape))
        f_shape = theano.function([x, y], o_shape)

        # Test with shape info: the IncSubtensor must be optimized away.
        topo = f_shape.maker.fgraph.toposort()
        assert not any(isinstance(n.op, tensor.IncSubtensor) for n in topo)
        out = f_shape([[2, 3]], [[3, 4]])
        # For the reversed slice the replacement is a Subtensor, so the
        # expected value is y sliced the same way.
        assert (out == numpy.asarray([[3, 4]])[::, sub]).all()

        # Test that without shape info, we don't apply the opt.
        topo = f.maker.fgraph.toposort()
        assert len(topo) == 1
        assert isinstance(topo[0].op, tensor.IncSubtensor)
        out = f([[2, 3]], [[3, 4]])
        assert (out == numpy.asarray([[3, 4]])[::, sub]).all()

        # Test that we don't remove shape error: mismatched y must still
        # raise at run time.
        try:
            f([[2, 3]], [[3, 4], [4, 5]])
            assert False
        except (ValueError, AssertionError):
            pass

        # Test that we don't remove broadcastability: y broadcasts over
        # the rows of x, so the IncSubtensor semantics must be kept.
        out = f([[2, 3], [3, 4]], [[5, 6]])
        assert (out == numpy.asarray([[5, 6], [5, 6]])[::, sub]).all()

    # Test that we do not optimize other strides even when sub and y
    # have the same shapes: x[::, ::2] does not cover all of x.
    sub = x[::, ::2]
    o_shape = tensor.set_subtensor(sub,
                                   tensor.specify_shape(y, sub.shape))
    f_shape = theano.function([x, y], o_shape)
    topo = f_shape.maker.fgraph.toposort()
    assert any(isinstance(n.op, tensor.IncSubtensor) for n in topo)
    out = f_shape([[2, 3, 6, 7]], [[8, 9]])
    # Only the even columns are overwritten.
    assert (out == numpy.asarray([[8, 3, 9, 7]])).all()
def test_local_useless_subtensor(): def test_local_useless_subtensor():
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论