提交 560af0e6 authored 作者: Frederic Bastien's avatar Frederic Bastien

'Merge more subtensor of subtensor(the case var[::-1][int])'

上级 49e7b915
...@@ -1113,25 +1113,26 @@ def local_subtensor_lift(node): ...@@ -1113,25 +1113,26 @@ def local_subtensor_lift(node):
@gof.local_optimizer([]) @gof.local_optimizer([])
def local_subtensor_merge(node): def local_subtensor_merge(node):
""" """
var[int:][-1] -> var[-1] 1) var[int:][-1] -> var[-1]
2) var[::-1][int] -> var[-int-1]
The optimization is valid for any int be it constant or not.
""" """
if (isinstance(node.op, T.Subtensor) and if (isinstance(node.op, T.Subtensor) and
len(node.inputs)==1 and len(node.op.idx_list)==1):
len(node.op.idx_list)==1 and
node.op.idx_list[0]==-1):
u = node.inputs[0] u = node.inputs[0]
if not u.owner or len(u.clients) > 1: if (not u.owner or len(u.clients) > 1 or
not isinstance(u.owner.op, T.Subtensor)):
return False return False
if (isinstance(u.owner.op, T.Subtensor) and # var[int:][-1] -> var[-1]
len(u.owner.op.idx_list)==1 and if (len(node.inputs)==1 and
isinstance(u.owner.op.idx_list[0], slice) and node.op.idx_list[0]==-1 and
u.owner.op.idx_list[0].stop is None and len(u.owner.op.idx_list)==1 and
u.owner.op.idx_list[0].step is None isinstance(u.owner.op.idx_list[0], slice) and
): u.owner.op.idx_list[0].stop is None and
u.owner.op.idx_list[0].step is None
):
u_start = u.owner.op.idx_list[0].start u_start = u.owner.op.idx_list[0].start
if len(u.owner.inputs) == 1 and isinstance(u_start, int): if len(u.owner.inputs) == 1 and isinstance(u_start, int):
...@@ -1161,6 +1162,25 @@ def local_subtensor_merge(node): ...@@ -1161,6 +1162,25 @@ def local_subtensor_merge(node):
new_index = T.scalar_from_tensor(new_index) new_index = T.scalar_from_tensor(new_index)
return [u.owner.inputs[0][new_index]] return [u.owner.inputs[0][new_index]]
# var[::-1][int] -> var[-int-1]
if (len(node.inputs) in [1,2] and
len(u.owner.op.idx_list)==1 and
isinstance(u.owner.op.idx_list[0], slice) and
u.owner.op.idx_list[0].start is None and
u.owner.op.idx_list[0].stop is None and
u.owner.op.idx_list[0].step == -1
):
idx = node.op.idx_list[0]
if len(node.inputs) == 1 and isinstance(idx, int):
pass
elif (len(node.inputs) == 2 and
isinstance (idx, scalar.basic.Scalar)):
idx = T.tensor_from_scalar(node.inputs[1])
else:
return False
return [u.owner.inputs[0][-idx-1]]
@register_canonicalize @register_canonicalize
@gof.local_optimizer([None]) @gof.local_optimizer([None])
......
...@@ -1322,6 +1322,65 @@ class test_local_subtensor_merge(unittest.TestCase): ...@@ -1322,6 +1322,65 @@ class test_local_subtensor_merge(unittest.TestCase):
assert isinstance(topo[2].op, theano.compile.function_module.DeepCopyOp) assert isinstance(topo[2].op, theano.compile.function_module.DeepCopyOp)
f([[0,1],[2,3]]) # let debugmode test something f([[0,1],[2,3]]) # let debugmode test something
def test_const2(self):
# var[::-1][const] -> var[-1]
x = TT.matrix('x')
for idx in range(-5,4):
f = function([x], x[::-1][idx], mode=mode_opt)
#theano.printing.debugprint(f, print_type=True)
topo=f.maker.env.toposort()
#print [t for t in topo if isinstance(t.op, TT.Subtensor)]
assert len([t for t in topo if isinstance(t.op, TT.Subtensor)]) == 1
#print topo[-1].op
assert isinstance(topo[-1].op, theano.compile.function_module.DeepCopyOp)
x_val = [[0,1],[2,3]]
if idx<2 and idx>=-2:
# The first subtensor is non-empty, so it makes sense
f(x_val) # let debugmode test something
else:
# A non-empty subtensor of an empty one should be an IndexError
self.assertRaises(IndexError, f, x_val)
f2 = function([x], x[::-1][idx], mode=mode_opt.excluding('local_subtensor_merge'))
self.assertRaises(IndexError, f2, x_val)
def test_scalar2(self):
# var[::-1][const] -> var[-1]
x = TT.matrix('x')
y = TT.iscalar('y')
f = function([x,y], x[::-1][y], mode=mode_opt)
#theano.printing.debugprint(f, print_type=True)
topo=f.maker.env.toposort()
#print [t for t in topo if isinstance(t.op, TT.Subtensor)]
assert len([t for t in topo if isinstance(t.op, TT.Subtensor)]) == 1
#print topo[-1].op
assert isinstance(topo[-1].op, theano.compile.function_module.DeepCopyOp)
x_val = [[0,1],[2,3]]
for idx in range(-2,2):
f(x_val, idx) # let debugmode test something
for idx in range(2,5)+range(-5,-2):
self.assertRaises(IndexError, f, x_val, idx)
f = function([x,y], x[::-1][y], mode=mode_opt.excluding('local_subtensor_merge'))
self.assertRaises(IndexError, f, x_val, idx)
def test_dont_opt2(self):
# Test that we don't optimize some case
# var[::-1][const] should be optimized but not
# x[::other int][const]
x = TT.matrix('x')
f = function([x], x[::-2][0], mode=mode_opt)
#theano.printing.debugprint(f)
topo=f.maker.env.toposort()
assert len(topo)==3
assert isinstance(topo[0].op, TT.Subtensor)
assert isinstance(topo[1].op, TT.Subtensor)
assert isinstance(topo[2].op, theano.compile.function_module.DeepCopyOp)
f([[0,1],[2,3]]) # let debugmode test something
def test_local_fill_useless(): def test_local_fill_useless():
m = theano.config.mode m = theano.config.mode
if m == 'FAST_COMPILE': if m == 'FAST_COMPILE':
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论