提交 2fb76e8d authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Fix the fix of local_useless_subtensor opt. Add two new test cases.

上级 62b59828
...@@ -472,7 +472,7 @@ class MakeVector(T.Op): ...@@ -472,7 +472,7 @@ class MakeVector(T.Op):
def R_op(self, inputs, eval_points): def R_op(self, inputs, eval_points):
if None in eval_points: if None in eval_points:
return [None] return [None]
return self.make_node(*eval_points).outputs return self.make_node(*eval_points).outputs
make_vector = MakeVector() make_vector = MakeVector()
...@@ -1258,29 +1258,37 @@ def local_useless_subtensor(node): ...@@ -1258,29 +1258,37 @@ def local_useless_subtensor(node):
# is not a useless subtensor # is not a useless subtensor
return False return False
length_pos_data = sys.maxint length_pos_data = sys.maxint
length_pos_shape_i = None length_pos_shape_i = None
length_pos = shape_of[node.inputs[0]][pos]
try: try:
length_pos = shape_of[node.inputs[0]][pos] length_pos_data = get_constant_value(length_pos)
try: except TypeError:
length_pos_data = get_constant_value(length_pos) pass
except TypeError:
pass
if isinstance(idx.stop, theano.scalar.Scalar): if isinstance(idx.stop, theano.scalar.Scalar):
if isinstance(node.inputs[node_input_idx].owner.op, length_pos_shape_i = node.inputs[node_input_idx]
T.ScalarFromTensor): # length_pos is a tensor variable, but length_pos_shape_i
length_pos_shape_i = node.inputs[node_input_idx].owner.inputs[0] # is a scalar variable. We try to see if they represent
else: # the same underlying variable.
length_pos_shape_i = node.inputs[node_input_idx] if (length_pos_shape_i.owner and
assert length_pos_shape_i.type == idx.stop isinstance(length_pos_shape_i.owner.op,
# We already know that start and step are not variables T.ScalarFromTensor)):
# and so they don't appear in the input of the node length_pos_shape_i = length_pos_shape_i.owner.inputs[0]
node_input_idx += 1 elif (length_pos.owner and
# Catch exception from shape_of isinstance(length_pos.owner.op,
except Exception, e: T.TensorFromScalar)):
length_pos = None length_pos = length_pos.owner.inputs[0]
else:
# We did not find underlying variables of the same type
return False
assert length_pos_shape_i.type == length_pos.type
assert length_pos_shape_i.type.dtype == idx.stop.dtype
# We already know that start and step are not variables
# and so they don't appear in the input of the node
node_input_idx += 1
if isinstance(idx.stop, int): if isinstance(idx.stop, int):
if idx.stop < length_pos_data: if idx.stop < length_pos_data:
......
...@@ -1210,6 +1210,7 @@ def test_local_useless_subtensor(): ...@@ -1210,6 +1210,7 @@ def test_local_useless_subtensor():
((slice(0,x.shape[1]),slice(0,x.shape[1]),), False), ((slice(0,x.shape[1]),slice(0,x.shape[1]),), False),
((slice(0,x.shape[1]),2), False), ((slice(0,x.shape[1]),2), False),
((slice(0,x.shape[1]),slice(x.shape[0]-x.shape[0],x.shape[1]),), False), ((slice(0,x.shape[1]),slice(x.shape[0]-x.shape[0],x.shape[1]),), False),
((slice(0,T.scalar_from_tensor(x.shape[0])),), True),
]): ]):
f = function([x], tensor.exp(x).__getitem__(dims), mode=mode_opt) f = function([x], tensor.exp(x).__getitem__(dims), mode=mode_opt)
#theano.printing.debugprint(f) #theano.printing.debugprint(f)
...@@ -1236,6 +1237,22 @@ def test_local_useless_subtensor(): ...@@ -1236,6 +1237,22 @@ def test_local_useless_subtensor():
assert any([isinstance(node.op, Subtensor) for node in prog]) assert any([isinstance(node.op, Subtensor) for node in prog])
f([[0,1,2],[3,4,5]]) # let debugmode test something f([[0,1,2],[3,4,5]]) # let debugmode test something
# Test scalar variable
s = scal.int32('s')
for idx, (dims, res) in enumerate([
((slice(0,s),), False),
]):
f = function([x, s], tensor.exp(x).__getitem__(dims), mode=mode_opt)
#theano.printing.debugprint(f)
prog=f.maker.env.toposort()
if res:
assert prog[0].op == tensor.exp, dims
assert len(prog)==1, dims
else:
assert any([isinstance(node.op, Subtensor) for node in prog])
f([[1,2,3],[4,5,6]], 1)
f([[1,2,3],[4,5,6]], 3)
class test_local_subtensor_lift(unittest.TestCase): class test_local_subtensor_lift(unittest.TestCase):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论