提交 941ca01d authored 作者: Kelvin Xu's avatar Kelvin Xu 提交者: Kelvin Xu

added test, correct opt

上级 9a13989b
...@@ -1921,25 +1921,24 @@ def local_set_to_inc_subtensor(node): ...@@ -1921,25 +1921,24 @@ def local_set_to_inc_subtensor(node):
@register_canonicalize @register_canonicalize
@register_specialize @register_specialize
@gof.local_optimizer([Subtensor, AdvancedSubtensor1]) @gof.local_optimizer([Subtensor])
def local_useless_slice(node): def local_useless_slice(node):
""" """
Remove Subtensor/AdvancedSubtensor1 of the from X[0, :] -> X[0] Remove Subtensor of the form X[0, :] -> X[0]
""" """
if isinstance(node.op, Subtensor): if isinstance(node.op, Subtensor):
slices = get_idx_list(node.inputs, node.op.idx_list) slices = get_idx_list(node.inputs, node.op.idx_list)
last_slice = len(slices) last_slice = len(slices)
for s in slices[::-1]: for s in slices[::-1]:
# check if slice and then check slice indices # check if slice and then check slice indices
if isinstance(s, slice): if (isinstance(s, slice) and s.start is None and s.stop is None
if s.start is None and s.stop is None and\ and (s.step is None or T.extract_constant(s.step) == 1)):
(s.step is None or T.extract_constant(s.step) == 1):
last_slice -= 1 last_slice -= 1
else: else:
break break
# check if we removed something # check if we removed something
if last_slice < len(slices): if last_slice < len(slices):
subtens = Subtensor(make_constant(slices[:last_slice])) subtens = Subtensor(slices[:last_slice])
sl_ins = Subtensor.collapse(slices[:last_slice], sl_ins = Subtensor.collapse(slices[:last_slice],
lambda x: isinstance(x, T.Variable)) lambda x: isinstance(x, T.Variable))
out = subtens(node.inputs[0], *sl_ins) out = subtens(node.inputs[0], *sl_ins)
......
...@@ -1576,6 +1576,40 @@ def test_log_add(): ...@@ -1576,6 +1576,40 @@ def test_log_add():
#TODO: (write and) test that the optimization works with Sum in addition to working with Add. #TODO: (write and) test that the optimization works with Sum in addition to working with Add.
def test_local_useless_slice():
# test a simple matrix
x = tensor.matrix('x')
mode_unopt = compile.get_default_mode().excluding("local_useless_slice")
mode_opt = compile.get_default_mode().including("local_useless_slice")
# test with and without the useless slice
o = 2 * x[0, :]
f_unopt = theano.function([x], o, mode=mode_unopt)
f_opt = theano.function([x], o, mode=mode_opt)
test_inp = numpy.random.randint(-10, 10, (4, 4)).astype('float32')
assert all(f_opt(test_inp) == f_unopt(test_inp)),\
"The optimization caused a mismatch in the result"
# test to see if the slice is truely gone
apply_node = f_opt.maker.fgraph.toposort()[0]
subtens = apply_node.op
assert not any(isinstance(idx, slice) for idx in subtens.idx_list), "Slice should be gone"
# test a 4d tensor
z = tensor.tensor4('z')
o2 = z[1, :, :, 1]
o3 = z[0, :, :, :]
f_opt_check = theano.function([z], o2, mode=mode_opt)
f_opt_check_apply = theano.function([z], o3, mode=mode_opt)
# The optimization shouldn't apply here
apply_node = f_opt_check.maker.fgraph.toposort()[0]
subtens = apply_node.op
assert [isinstance(idx, slice) for idx in subtens.idx_list].count(True) == 2
# But it should here
apply_node = f_opt_check_apply.maker.fgraph.toposort()[0]
subtens = apply_node.op
assert not any(isinstance(idx, slice) for idx in subtens.idx_list)
def test_local_useless_inc_subtensor(): def test_local_useless_inc_subtensor():
x = tensor.matrix('x') x = tensor.matrix('x')
y = tensor.matrix('y') y = tensor.matrix('y')
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论