提交 2fb76e8d authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Fix the fix of local_useless_subtensor opt. Add two new test cases.

上级 62b59828
......@@ -1258,10 +1258,9 @@ def local_useless_subtensor(node):
# is not a useless subtensor
return False
length_pos_data = sys.maxint
length_pos_shape_i = None
try:
length_pos = shape_of[node.inputs[0]][pos]
try:
length_pos_data = get_constant_value(length_pos)
......@@ -1269,18 +1268,27 @@ def local_useless_subtensor(node):
pass
if isinstance(idx.stop, theano.scalar.Scalar):
if isinstance(node.inputs[node_input_idx].owner.op,
T.ScalarFromTensor):
length_pos_shape_i = node.inputs[node_input_idx].owner.inputs[0]
else:
length_pos_shape_i = node.inputs[node_input_idx]
assert length_pos_shape_i.type == idx.stop
# length_pos is a tensor variable, but length_pos_shape_i
# is a scalar variable. We try to see if they represent
# the same underlying variable.
if (length_pos_shape_i.owner and
isinstance(length_pos_shape_i.owner.op,
T.ScalarFromTensor)):
length_pos_shape_i = length_pos_shape_i.owner.inputs[0]
elif (length_pos.owner and
isinstance(length_pos.owner.op,
T.TensorFromScalar)):
length_pos = length_pos.owner.inputs[0]
else:
# We did not find underlying variables of the same type
return False
assert length_pos_shape_i.type == length_pos.type
assert length_pos_shape_i.type.dtype == idx.stop.dtype
# We already know that start and step are not variables
# and so they don't appear in the input of the node
node_input_idx += 1
# Catch exception from shape_of
except Exception, e:
length_pos = None
if isinstance(idx.stop, int):
if idx.stop < length_pos_data:
......
......@@ -1210,6 +1210,7 @@ def test_local_useless_subtensor():
((slice(0,x.shape[1]),slice(0,x.shape[1]),), False),
((slice(0,x.shape[1]),2), False),
((slice(0,x.shape[1]),slice(x.shape[0]-x.shape[0],x.shape[1]),), False),
((slice(0,T.scalar_from_tensor(x.shape[0])),), True),
]):
f = function([x], tensor.exp(x).__getitem__(dims), mode=mode_opt)
#theano.printing.debugprint(f)
......@@ -1236,6 +1237,22 @@ def test_local_useless_subtensor():
assert any([isinstance(node.op, Subtensor) for node in prog])
f([[0,1,2],[3,4,5]]) # let debugmode test something
# Test scalar variable
s = scal.int32('s')
for idx, (dims, res) in enumerate([
((slice(0,s),), False),
]):
f = function([x, s], tensor.exp(x).__getitem__(dims), mode=mode_opt)
#theano.printing.debugprint(f)
prog=f.maker.env.toposort()
if res:
assert prog[0].op == tensor.exp, dims
assert len(prog)==1, dims
else:
assert any([isinstance(node.op, Subtensor) for node in prog])
f([[1,2,3],[4,5,6]], 1)
f([[1,2,3],[4,5,6]], 3)
class test_local_subtensor_lift(unittest.TestCase):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论