提交 e846ef90 authored 作者: Frederic Bastien's avatar Frederic Bastien

'optimize var[int1::][:int2] to only 1 subtensor.'

上级 addbc07c
......@@ -1113,9 +1113,13 @@ def local_subtensor_lift(node):
@gof.local_optimizer([])
def local_subtensor_merge(node):
"""
1) var[int:][-1] -> var[-1]
1) var[int:][-1] -> var[-1] # a little different for when the first subtensor is empty.
2) var[::-1][int] -> var[-int-1]
3) var[::-1][:int] -> var[:-int-1:-1]
4) var[int1::][:int2] -> var[int1:switch(idx1>=0,
idx1,
maximum(u.owner.inputs[0].shape[0]+idx1, 0)
) + idx2]
"""
if (isinstance(node.op, T.Subtensor) and
......@@ -1188,7 +1192,7 @@ def local_subtensor_merge(node):
len(u.owner.op.idx_list)==1 and
isinstance(node.op.idx_list[0], slice) and
node.op.idx_list[0].start in [0, None] and
#node.op.idx_list[0].stop is None and
isinstance(node.op.idx_list[0].stop, (int, scalar.basic.Scalar)) and
node.op.idx_list[0].step is None and
isinstance(u.owner.op.idx_list[0], slice) and
u.owner.op.idx_list[0].start is None and
......@@ -1207,6 +1211,30 @@ def local_subtensor_merge(node):
return [u.owner.inputs[0][:-idx-1:-1]]
# var[int1::][:int2]
if (len(node.inputs) in [1, 2] and
isinstance(node.op.idx_list[0], slice) and
node.op.idx_list[0].start in [0, None] and
isinstance(node.op.idx_list[0].stop,(int, scalar.basic.Scalar)) and
node.op.idx_list[0].step is None and
len(u.owner.op.idx_list)==1 and
isinstance(u.owner.op.idx_list[0], slice) and
isinstance(u.owner.op.idx_list[0].start,(int, scalar.basic.Scalar)) and
u.owner.op.idx_list[0].stop in [sys.maxint, None] and
u.owner.op.idx_list[0].step is None
):
idx1 = u.owner.op.idx_list[0].start
idx2 = node.op.idx_list[0].stop
if isinstance(idx1, scalar.basic.Scalar):
idx1 = T.tensor_from_scalar(u.owner.inputs[1])
if isinstance(idx2, scalar.basic.Scalar):
idx2 = T.tensor_from_scalar(node.inputs[1])
# The maximum ensures shape[0] + idx1 (idx1 negative here) does not go below 0.
idx2_neg = T.maximum(u.owner.inputs[0].shape[0]+idx1, 0)
new_idx2 = T.switch(idx1>=0, idx1, idx2_neg)+idx2
return [u.owner.inputs[0][idx1:new_idx2]]
@register_canonicalize
@gof.local_optimizer([None])
......
......@@ -1429,6 +1429,57 @@ class test_local_subtensor_merge(unittest.TestCase):
assert isinstance(topo[2].op, theano.compile.function_module.DeepCopyOp)
f([[0,1],[2,3]]) # let debugmode test something
def test_const4(self):
    # var[const1::][:const2] -- both slices constant; the canonicalizer
    # should merge the two chained Subtensor ops into a single one.
    x = TT.matrix('x')
    x_val = [[0, 1], [2, 3]]
    for start in range(-3, 3):
        for stop in range(-3, 3):
            f = function([x], x[start:][:stop], mode=mode_opt)
            nodes = f.maker.env.toposort()
            subtensors = [n for n in nodes if isinstance(n.op, TT.Subtensor)]
            # Exactly one Subtensor must remain after optimization.
            assert len(subtensors) == 1
            assert isinstance(
                nodes[-1].op, theano.compile.function_module.DeepCopyOp)
            # Run under the opt mode (DebugMode validates the rewrite).
            f(x_val)
def test_scalar4(self):
    # var[int1:][:int2] with symbolic (runtime-provided) start and stop:
    # the two chained Subtensor ops must be merged into a single one.
    x = TT.matrix('x')
    y = TT.iscalar('y')
    # Fix: this scalar was mistakenly also named 'y' (copy-paste bug),
    # which made debug output ambiguous; name it 'z' to match the variable.
    z = TT.iscalar('z')
    f = function([x, y, z], x[y:][:z], mode=mode_opt)
    topo = f.maker.env.toposort()
    # Exactly one Subtensor must remain after optimization.
    assert len([t for t in topo if isinstance(t.op, TT.Subtensor)]) == 1
    assert isinstance(topo[-1].op, theano.compile.function_module.DeepCopyOp)
    x_val = [[0, 1], [2, 3]]
    # Exercise all sign combinations of the indices; DebugMode (via
    # mode_opt) verifies the optimized graph against the unoptimized one.
    for idx1 in range(-5, 5):
        for idx2 in range(-5, 5):
            f(x_val, idx1, idx2)  # let debugmode test something
def test_dont_opt4(self):
    # Check that the merge optimization is NOT applied here.
    # var[int1:][:int2] is merged, but x[-2:0][:0] has a bounded stop in
    # the first slice, so both Subtensor nodes must be kept.
    # (Fix: the old comment described an unrelated case, x[::other int][const].)
    x = TT.matrix('x')
    f = function([x], x[-2:0][:0], mode=mode_opt)
    # Fix: debugprint was left uncommented and spammed the test output;
    # the sibling tests keep it commented out for ad-hoc debugging only.
    #theano.printing.debugprint(f)
    topo = f.maker.env.toposort()
    # Two Subtensor ops plus the final DeepCopyOp: nothing was merged.
    assert len(topo) == 3
    assert isinstance(topo[0].op, TT.Subtensor)
    assert isinstance(topo[1].op, TT.Subtensor)
    assert isinstance(topo[2].op, theano.compile.function_module.DeepCopyOp)
    f([[0, 1], [2, 3]])  # let debugmode test something
def test_local_fill_useless():
m = theano.config.mode
if m == 'FAST_COMPILE':
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论