提交 767b820d authored 作者: Frederic Bastien's avatar Frederic Bastien

Fix cases where the removed dimensions and the subtensor touched dimensions have a gap in between.

上级 5a1846a0
......@@ -238,6 +238,9 @@ def local_dimshuffle_subtensor(node):
for idx in range(len(input_.owner.op.idx_list),
new_inputs[0].ndim):
if (idx - subtensor_removed_dims) in missing_dims:
while len(new_idx_list) < idx:
new_idx_list.append(slice(None))
new_idx_list.append(zero)
new_inputs.append(zero)
return [Subtensor(new_idx_list)(*new_inputs)]
......
......@@ -174,7 +174,7 @@ def test_local_dimshuffle_subtensor():
dimshuffle_subtensor = out2in(local_dimshuffle_subtensor)
x = tensor.tensor4('x')
x = tensor.dtensor4('x')
x = tensor.patternbroadcast(x, (False, True, False, False))
i = tensor.iscalar('i')
......@@ -186,7 +186,8 @@ def test_local_dimshuffle_subtensor():
topo = g.toposort()
assert any([not isinstance(x, DimShuffle) for x in topo])
x = tensor.tensor(broadcastable=(False, True, False), dtype='floatX')
# Test dimshuffle remove dimensions the subtensor don't "see".
x = tensor.tensor(broadcastable=(False, True, False), dtype='float64')
out = x[i].dimshuffle(1)
g = FunctionGraph([x, i], [out])
......@@ -194,3 +195,15 @@ def test_local_dimshuffle_subtensor():
topo = g.toposort()
assert any([not isinstance(x, DimShuffle) for x in topo])
# Test dimshuffle remove dimensions the subtensor don't "see" but
# have in between dimensions.
x = tensor.tensor(broadcastable=(False, True, False, True),
dtype='float64')
out = x[i].dimshuffle(1)
f = theano.function([x, i], out)
topo = f.maker.fgraph.toposort()
assert any([not isinstance(x, DimShuffle) for x in topo])
assert f(numpy.random.rand(5, 1, 4, 1), 2).shape == (4,)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论