提交 4071d3f8 authored 作者: Frederic Bastien's avatar Frederic Bastien

'Fix the optimization local_useless_subtensor and add test to it and generalize…

'Fix the optimization local_useless_subtensor and add test to it and generalize it to support for Theano Variable in the slice.stop.'
上级 7e587b08
...@@ -1096,18 +1096,37 @@ def local_useless_subtensor(node): ...@@ -1096,18 +1096,37 @@ def local_useless_subtensor(node):
Remove Subtensor if it take the full input Remove Subtensor if it take the full input
""" """
if isinstance(node.op, T.Subtensor): if isinstance(node.op, T.Subtensor):
shape_of = node.env.shape_feature.shape_of
node_input_idx = 1
for pos, idx in enumerate(node.op.idx_list): for pos, idx in enumerate(node.op.idx_list):
length_pos_data = sys.maxint
length_pos_shape_i = None
try: try:
length_pos = shape_i(node.inputs[0])[pos] length_pos = shape_of[node.inputs[0]][pos]
except: if isinstance(length_pos, theano.tensor.basic.TensorConstant):
length_pos_data = length_pos.data
else:
length_pos_shape_i = node.inputs[node_input_idx].owner.inputs[0]
except Exception, e:
length_pos = None length_pos = None
if ( isinstance(idx,slice) and if ( isinstance(idx,slice) and
idx.start in [0,None] and idx.start in [0,None] and
idx.step in [1,None] and idx.step in [1,None] and
idx.stop in [sys.maxint, None, length_pos]): (idx.stop in [sys.maxint, None, length_pos_data] or
(isinstance(idx.stop, int) and idx.stop>=length_pos_data) or
(isinstance(idx.stop, theano.scalar.Scalar) and
length_pos==length_pos_shape_i)
)):
pass pass
else: else:
return False return False
if isinstance(idx, slice):
node_input_idx += sum([isinstance(idx.start, theano.scalar.Scalar),
isinstance(idx.stop, theano.scalar.Scalar),
isinstance(idx.step, theano.scalar.Scalar)])
if isinstance(idx, theano.scalar.Scalar):
node_input_idx += 1
return [node.inputs[0]] return [node.inputs[0]]
......
...@@ -1147,12 +1147,76 @@ def test_log_add(): ...@@ -1147,12 +1147,76 @@ def test_log_add():
def test_local_useless_subtensor(): def test_local_useless_subtensor():
x = TT.matrix('x') x = TT.matrix('x')
f = function([x], TT.exp(x)[0:], mode=mode_opt)
# Test default
for dims in [(slice(0,None),),
(slice(0,None),slice(0,None)),
]:
f = function([x], TT.exp(x).__getitem__(dims), mode=mode_opt)
#theano.printing.debugprint(f)
prog=f.maker.env.toposort() prog=f.maker.env.toposort()
assert prog[0].op == TT.exp assert prog[0].op == TT.exp
assert len(prog)==1 assert len(prog)==1
f([[0,1],[2,3]]) # let debugmode test something f([[0,1,2],[3,4,5]]) # let debugmode test something
x_c = specify_shape(x, (2,3))
# Test constant
for dims, res in [((slice(0,2),), True),
((slice(0,2),slice(0,None)), True),
((slice(0,2),slice(0,3)), True),
((slice(0,None),slice(0,3)), True),
((slice(0,3),slice(0,13)), True),
((slice(0,3),slice(0,2)), False),
((slice(0,1),slice(0,None)), False),
((slice(0,1),1), False),
]:
f = function([x], TT.exp(x_c).__getitem__(dims), mode=mode_opt)
#theano.printing.debugprint(f)
prog=f.maker.env.toposort()
if res:
assert isinstance(prog[0].op, theano.tensor.basic.SpecifyShape), dims
assert prog[1].op == TT.exp, dims
assert len(prog)==2, dims
else:
assert any([isinstance(node.op, Subtensor) for node in prog])
f([[0,1,2],[3,4,5]]) # let debugmode test something
# Test Variable
for idx, (dims, res) in enumerate([
((slice(0,x.shape[0]),), True),
((slice(0,x.shape[1]),), False),
((slice(0,x.shape[0]),slice(0,x.shape[1]),), True),
((slice(0,x.shape[0]),slice(0,x.shape[0]),), False),
((slice(0,x.shape[1]),slice(0,x.shape[0]),), False),
((slice(0,x.shape[1]),slice(0,x.shape[1]),), False),
((slice(0,x.shape[1]),2), False),
((slice(0,x.shape[1]),slice(x.shape[0]-x.shape[0],x.shape[1]),), False),
]):
f = function([x], TT.exp(x).__getitem__(dims), mode=mode_opt)
#theano.printing.debugprint(f)
prog=f.maker.env.toposort()
if res:
assert prog[0].op == TT.exp, dims
assert len(prog)==1, dims
else:
assert any([isinstance(node.op, Subtensor) for node in prog])
f([[0,1,2],[3,4,5]]) # let debugmode test something
# Test mix Variable and Constant
# Currently not supported
for idx, (dims, res) in enumerate([
((slice(0,x.shape[0]),slice(0,3)), False),
((slice(0,3),slice(0,x.shape[1])), False),
]):
f = function([x], TT.exp(x_c).__getitem__(dims), mode=mode_opt)
#theano.printing.debugprint(f)
prog=f.maker.env.toposort()
if res:
assert prog[0].op == TT.exp, dims
assert len(prog)==1, dims
else:
assert any([isinstance(node.op, Subtensor) for node in prog])
f([[0,1,2],[3,4,5]]) # let debugmode test something
class test_local_subtensor_lift(unittest.TestCase): class test_local_subtensor_lift(unittest.TestCase):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论