提交 2d4b514d authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Fix subtensor merge when the first subtensor is empty. Test it.

上级 4357224c
...@@ -1127,14 +1127,41 @@ def local_subtensor_merge(node): ...@@ -1127,14 +1127,41 @@ def local_subtensor_merge(node):
return False return False
if (isinstance(u.owner.op, T.Subtensor) and if (isinstance(u.owner.op, T.Subtensor) and
len(u.owner.inputs) in [1,2] and
len(u.owner.op.idx_list)==1 and len(u.owner.op.idx_list)==1 and
isinstance(u.owner.op.idx_list[0], slice) and isinstance(u.owner.op.idx_list[0], slice) and
isinstance(u.owner.op.idx_list[0].start, (int, scalar.basic.Scalar)) and
u.owner.op.idx_list[0].stop is None and u.owner.op.idx_list[0].stop is None and
u.owner.op.idx_list[0].step is None u.owner.op.idx_list[0].step is None
): ):
return [u.owner.inputs[0][-1]] u_start = u.owner.op.idx_list[0].start
if len(u.owner.inputs) == 1 and isinstance(u_start, int):
print 'int'
start0 = u_start
elif (len(u.owner.inputs) == 2 and
isinstance (u_start, scalar.basic.Scalar)):
print 'scalar'
start0 = T.tensor_from_scalar(u.owner.inputs[1])
else:
return False
len0 = u.owner.inputs[0].shape[0]
# The following is equivalent to:
# if start0 <= -u.shape[0]:
# actual_start0 = 0
# elif start0 < 0:
# actual_start0 = start0 + u.shape[0]
# else:
# actual_start0 = start0
actual_start0 = (start0 > -len0) * (start0 + ((start0 < 0) * len0))
# if actual_start < u.shape[0]:
# new_index = -1
# else: # Will give an IndexError
# new_index = actual_start
new_index = -1 + (actual_start0 >= len0) * (actual_start0 + 1)
new_index = T.scalar_from_tensor(new_index)
return [u.owner.inputs[0][new_index]]
@register_canonicalize @register_canonicalize
......
...@@ -1249,29 +1249,43 @@ class test_local_subtensor_merge(unittest.TestCase): ...@@ -1249,29 +1249,43 @@ class test_local_subtensor_merge(unittest.TestCase):
def test_const(self): def test_const(self):
# var[const::][-1] -> var[-1] # var[const::][-1] -> var[-1]
x = TT.matrix('x') x = TT.matrix('x')
for idx in [-2,-1,1]: for idx in [-2,-1,1,2]:
f = function([x], x[idx::][-1], mode=mode_opt) f = function([x], x[idx::][-1], mode=mode_opt)
#theano.printing.debugprint(f) #theano.printing.debugprint(f, print_type=True)
topo=f.maker.env.toposort() topo=f.maker.env.toposort()
assert len(topo)==2 print [t for t in topo if isinstance(t.op, TT.Subtensor)]
assert isinstance(topo[0].op, TT.Subtensor) assert len([t for t in topo if isinstance(t.op, TT.Subtensor)]) == 1
assert isinstance(topo[1].op, theano.compile.function_module.DeepCopyOp) print topo[-1].op
f([[0,1],[2,3]]) # let debugmode test something assert isinstance(topo[-1].op, theano.compile.function_module.DeepCopyOp)
x_val = [[0,1],[2,3]]
if idx<2:
# The first subtensor is non-empty, so it makes sense
f(x_val) # let debugmode test something
else:
# A non-empty subtensor of an empty one should be an IndexError
self.assertRaises(IndexError, f, x_val)
def test_scalar(self): def test_scalar(self):
# var[int::][-1] -> var[-1] # var[int::][-1] -> var[-1]
x = TT.matrix('x') x = TT.matrix('x')
y = TT.iscalar('y') y = TT.iscalar('y')
f = function([x,y], x[y::][-1], mode=mode_opt) f = function([x,y], x[y::][-1], mode=mode_opt)
#theano.printing.debugprint(f) #theano.printing.debugprint(f, print_type=True)
topo=f.maker.env.toposort() topo=f.maker.env.toposort()
assert len(topo)==2 print [t for t in topo if isinstance(t.op, TT.Subtensor)]
assert isinstance(topo[0].op, TT.Subtensor) assert len([t for t in topo if isinstance(t.op, TT.Subtensor)]) == 1
assert isinstance(topo[1].op, theano.compile.function_module.DeepCopyOp) print topo[-1].op
assert isinstance(topo[-1].op, theano.compile.function_module.DeepCopyOp)
x_val = [[0,1],[2,3]]
for idx in range(-10,2): for idx in range(-10,2):
f([[0,1],[2,3]], idx) # let debugmode test something f(x_val, idx) # let debugmode test something
for idx in range(2,4):
self.assertRaises(IndexError, f, x_val, idx)
def test_dont_opt(self): def test_dont_opt(self):
# Test that we don't optimize some case # Test that we don't optimize some case
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论