提交 20bc7320 authored 作者: Frederic Bastien's avatar Frederic Bastien

Added another merge optimization for subtensor of subtensor: the case var[::-1][:int].

上级 560af0e6
......@@ -1115,6 +1115,7 @@ def local_subtensor_merge(node):
"""
1) var[int:][-1] -> var[-1]
2) var[::-1][int] -> var[-int-1]
3) var[::-1][:int] -> var[:-int-1:-1]
"""
if (isinstance(node.op, T.Subtensor) and
......@@ -1164,6 +1165,7 @@ def local_subtensor_merge(node):
# var[::-1][int] -> var[-int-1]
if (len(node.inputs) in [1,2] and
isinstance(node.op.idx_list[0], (int, scalar.basic.Scalar)) and
len(u.owner.op.idx_list)==1 and
isinstance(u.owner.op.idx_list[0], slice) and
u.owner.op.idx_list[0].start is None and
......@@ -1181,6 +1183,30 @@ def local_subtensor_merge(node):
return [u.owner.inputs[0][-idx-1]]
# var[::-1][:int] -> var[:-int-1:-1]
if (len(node.inputs) in [1,2] and
len(u.owner.op.idx_list)==1 and
isinstance(node.op.idx_list[0], slice) and
node.op.idx_list[0].start in [0, None] and
#node.op.idx_list[0].stop is None and
node.op.idx_list[0].step is None and
isinstance(u.owner.op.idx_list[0], slice) and
u.owner.op.idx_list[0].start is None and
u.owner.op.idx_list[0].stop is None and
u.owner.op.idx_list[0].step == -1
):
slice_idx = node.op.idx_list[0]
idx = slice_idx.stop
if len(node.inputs) == 1 and isinstance(idx, int):
pass
elif (len(node.inputs) == 2 and
isinstance (idx, scalar.basic.Scalar)):
idx = T.tensor_from_scalar(node.inputs[1])
else:
return False
return [u.owner.inputs[0][:-idx-1:-1]]
@register_canonicalize
@gof.local_optimizer([None])
......
......@@ -1346,7 +1346,7 @@ class test_local_subtensor_merge(unittest.TestCase):
self.assertRaises(IndexError, f2, x_val)
def test_scalar2(self):
# var[::-1][const] -> var[-1]
# var[::-1][iscalar] -> var[-iscalar-1]
x = TT.matrix('x')
y = TT.iscalar('y')
f = function([x,y], x[::-1][y], mode=mode_opt)
......@@ -1381,6 +1381,54 @@ class test_local_subtensor_merge(unittest.TestCase):
assert isinstance(topo[2].op, theano.compile.function_module.DeepCopyOp)
f([[0,1],[2,3]]) # let debugmode test something
def test_const3(self):
    # var[::-1][:const] -> var[:-const-1:-1]
    # (the original comment said "-> var[-1]", which is the rule for
    # var[::-1][const], not for the slice form tested here)
    x = TT.matrix('x')
    for idx in range(-5, 4):
        f = function([x], x[::-1][:idx], mode=mode_opt)
        topo = f.maker.env.toposort()
        # After the merge, exactly one Subtensor op must remain.
        assert len([t for t in topo
                    if isinstance(t.op, TT.Subtensor)]) == 1
        assert isinstance(topo[-1].op,
                          theano.compile.function_module.DeepCopyOp)
        x_val = [[0, 1], [2, 3]]
        f(x_val)  # let debugmode test something
def test_scalar3(self):
    # var[::-1][:iscalar] -> var[:-iscalar-1:-1]
    # (the original comment said "-> var[-1]", which describes the
    # scalar-index case, not this slice-with-scalar-stop case)
    x = TT.matrix('x')
    y = TT.iscalar('y')
    f = function([x, y], x[::-1][:y], mode=mode_opt)
    topo = f.maker.env.toposort()
    # After the merge, exactly one Subtensor op must remain.
    assert len([t for t in topo
                if isinstance(t.op, TT.Subtensor)]) == 1
    assert isinstance(topo[-1].op,
                      theano.compile.function_module.DeepCopyOp)
    x_val = [[0, 1], [2, 3]]
    for idx in range(-5, 5):
        f(x_val, idx)  # let debugmode test something
def test_dont_opt3(self):
    # Check that we do NOT merge var[::other_step][:const]:
    # the var[::-1][:const] optimization only applies when the
    # inner step is exactly -1, and here it is -2.
    # (the original comment described the wrong outer indexing,
    # "[const]" instead of "[:const]")
    x = TT.matrix('x')
    f = function([x], x[::-2][:0], mode=mode_opt)
    topo = f.maker.env.toposort()
    # Both Subtensor ops must survive, followed by the DeepCopyOp.
    assert len(topo) == 3
    assert isinstance(topo[0].op, TT.Subtensor)
    assert isinstance(topo[1].op, TT.Subtensor)
    assert isinstance(topo[2].op,
                      theano.compile.function_module.DeepCopyOp)
    f([[0, 1], [2, 3]])  # let debugmode test something
def test_local_fill_useless():
m = theano.config.mode
if m == 'FAST_COMPILE':
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论