提交 65c61c15 authored 作者: abergeron's avatar abergeron

Merge pull request #2647 from kelvinxu/remove_useless_slice

CCW remove useless slice
......@@ -1919,6 +1919,33 @@ def local_set_to_inc_subtensor(node):
return [advanced_inc_subtensor1(node.inputs[0], other, node.inputs[2])]
@register_canonicalize
@register_specialize
@gof.local_optimizer([Subtensor])
def local_useless_slice(node):
"""
Remove Subtensor of the form X[0, :] -> X[0]
"""
if isinstance(node.op, Subtensor):
slices = get_idx_list(node.inputs, node.op.idx_list)
last_slice = len(slices)
for s in slices[::-1]:
# check if slice and then check slice indices
if (isinstance(s, slice) and s.start is None and s.stop is None
and (s.step is None or T.extract_constant(s.step) == 1)):
last_slice -= 1
else:
break
# check if we removed something
if last_slice < len(slices):
subtens = Subtensor(slices[:last_slice])
sl_ins = Subtensor.collapse(slices[:last_slice],
lambda x: isinstance(x, T.Variable))
out = subtens(node.inputs[0], *sl_ins)
return [out]
@register_canonicalize
@register_specialize
@gof.local_optimizer([Subtensor, AdvancedSubtensor1])
......
......@@ -1576,6 +1576,40 @@ def test_log_add():
#TODO: (write and) test that the optimization works with Sum in addition to working with Add.
def test_local_useless_slice():
# test a simple matrix
x = tensor.matrix('x')
mode_unopt = compile.get_default_mode().excluding("local_useless_slice")
mode_opt = compile.get_default_mode().including("local_useless_slice")
# test with and without the useless slice
o = 2 * x[0, :]
f_unopt = theano.function([x], o, mode=mode_unopt)
f_opt = theano.function([x], o, mode=mode_opt)
test_inp = numpy.random.randint(-10, 10, (4, 4)).astype('float32')
assert all(f_opt(test_inp) == f_unopt(test_inp)),\
"The optimization caused a mismatch in the result"
# test to see if the slice is truely gone
apply_node = f_opt.maker.fgraph.toposort()[0]
subtens = apply_node.op
assert not any(isinstance(idx, slice) for idx in subtens.idx_list), "Slice should be gone"
# test a 4d tensor
z = tensor.tensor4('z')
o2 = z[1, :, :, 1]
o3 = z[0, :, :, :]
f_opt_check = theano.function([z], o2, mode=mode_opt)
f_opt_check_apply = theano.function([z], o3, mode=mode_opt)
# The optimization shouldn't apply here
apply_node = f_opt_check.maker.fgraph.toposort()[0]
subtens = apply_node.op
assert [isinstance(idx, slice) for idx in subtens.idx_list].count(True) == 2
# But it should here
apply_node = f_opt_check_apply.maker.fgraph.toposort()[0]
subtens = apply_node.op
assert not any(isinstance(idx, slice) for idx in subtens.idx_list)
def test_local_useless_inc_subtensor():
x = tensor.matrix('x')
y = tensor.matrix('y')
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论