Commit 6315fdfa authored by Pascal Lamblin, committed by GitHub

Merge pull request #5350 from nouiz/local_dimshuffle_subtensor

Remove opt warning from local_dimshuffle_subtensor
@@ -168,11 +168,13 @@ def local_dimshuffle_alloc(node):
@register_uncanonicalize
@gof.local_optimizer([DimShuffle])
def local_dimshuffle_subtensor(node):
"""
If a subtensor is inside a dimshuffle which only drop broadcastable dimensions,
scrap the dimshuffle and index the subtensor with 0
"""If a subtensor is inside a dimshuffle which only drop
broadcastable dimensions, scrap the dimshuffle and index the
subtensor with 0
x[i:j, :, k:l].dimshuffle(0, 2) =>
x[i:j, 0, k:l] if x.broadcastable == (False, True, False)
x[i:j, :, k:l].dimshuffle(0, 2) => x[i:j, 0, k:l] if x.broadcastable == (False, True, False)
"""
if isinstance(node.op, DimShuffle) and node.inputs[0].owner:
# the dimshuffle can only drop dimensions (cannot reshape nor add 'x')
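As a rough illustration of the rewrite the docstring describes (a sketch only, not part of this patch; it assumes the usual Theano tensor API and a compilation mode in which the uncanonicalize optimizations run):

import theano
from theano import tensor

# dim 1 of x is broadcastable, so dimshuffle(0, 2) only drops a
# broadcastable dimension and the rewrite may replace it with a 0 index
x = tensor.patternbroadcast(tensor.dtensor3('x'), (False, True, False))
i, j, k, l = [tensor.iscalar(n) for n in 'ijkl']
out = x[i:j, :, k:l].dimshuffle(0, 2)  # candidate for x[i:j, 0, k:l]
f = theano.function([x, i, j, k, l], out)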
@@ -190,7 +192,8 @@ def local_dimshuffle_subtensor(node):
input_ = node.inputs[0]
if isinstance(input_.owner.op, Subtensor):
# the arguments missing from the dimshuffles must be dims
# that are broadcastable
broadcastable = input_.broadcastable
missing_dims = list(range(input_.ndim))
@@ -202,7 +205,8 @@ def local_dimshuffle_subtensor(node):
# create a new idx_list for a new Subtensor object
# have to loop on idx_list and inputs
# inputs has the length of sum of non None elements of idx_list
# (check in slice!).
# len(missing_dims) can be < len(idx_list), this happens if
# tensor was indexed such as x[scalar, :, :], check that as well
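# Illustrative example (not from the patch): for x[i:j, :, k] the
# Subtensor op roughly has
#     idx_list == [slice(start, stop, None), slice(None), scalar_index]
#     inputs   == [x, i, j, k]
# i.e. one graph input per non-None scalar entry, so `j` below walks
# the inputs while the loop walks idx_list.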
new_idx_list = list(input_.owner.op.idx_list)
@@ -211,6 +215,7 @@ def local_dimshuffle_subtensor(node):
slice_attr_list = ['start', 'stop', 'step']
j = 0
slice_i = -1
subtensor_removed_dims = 0
for idx in input_.owner.op.idx_list:
if isinstance(idx, slice):
past_j = j
@@ -219,14 +224,24 @@ def local_dimshuffle_subtensor(node):
if getattr(idx, slice_attr) is not None:
new_inputs += [input_.owner.inputs[1 + j]]
j += 1
# past_j == j indicates a slice(None, None, None), which consumed
# no scalar inputs; that's where we want to index with 0 if it is
# also at the same spot of a missing dim
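# Illustrative example (not from the patch): in x[i:j, :, k:l] the
# middle slice(None) consumes no scalar inputs, so past_j == j when it
# is reached; it is also slice_i == 1 and dim 1 of x, so if dim 1 is
# broadcastable and dropped by the dimshuffle it becomes a 0 index here.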
if past_j == j and slice_i in missing_dims:
new_idx_list[j] = zero
new_inputs += [zero]
else:
new_inputs += [input_.owner.inputs[1 + j]]
j += 1
subtensor_removed_dims += 1
# Verify the trailing dimensions the subtensor didn't look at.
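# Illustrative example (not from the patch): for a 4d x with
# broadcastable == (False, True, False, True), x[i] only indexes dim 0,
# so idx_list has length 1 and dims 1, 2 and 3 are trailing.
# x[i].dimshuffle(1) drops the broadcastable dims, and the loop below
# pads new_idx_list so the result is equivalent to x[i, 0, :, 0].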
for idx in range(len(input_.owner.op.idx_list),
new_inputs[0].ndim):
if (idx - subtensor_removed_dims) in missing_dims:
while len(new_idx_list) < idx:
new_idx_list.append(slice(None))
new_idx_list.append(zero)
new_inputs.append(zero)
return [Subtensor(new_idx_list)(*new_inputs)]
return False
@@ -160,9 +160,9 @@ def test_local_dimshuffle_alloc():
g = FunctionGraph([x], [out])
reshape_dimshuffle(g)
l = theano.gof.PerformLinker()
l.accept(g)
f = l.make_function()
assert f([3, 4]).ndim == 4
@@ -174,14 +174,36 @@ def test_local_dimshuffle_subtensor():
dimshuffle_subtensor = out2in(local_dimshuffle_subtensor)
x = tensor.dtensor4('x')
x = tensor.patternbroadcast(x, (False, True, False, False))
i = tensor.iscalar('i')
out = x[:, :, 10:30, ::i].dimshuffle(0, 2, 3)
g = FunctionGraph([x, i], [out])
dimshuffle_subtensor(g)
topo = g.toposort()
assert any([not isinstance(x, DimShuffle) for x in topo])
# Test that the dimshuffle removes dimensions the subtensor doesn't "see".
x = tensor.tensor(broadcastable=(False, True, False), dtype='float64')
out = x[i].dimshuffle(1)
g = FunctionGraph([x, i], [out])
dimshuffle_subtensor(g)
topo = g.toposort()
assert any([not isinstance(x, DimShuffle) for x in topo])
# Test that the dimshuffle removes dimensions the subtensor doesn't
# "see", with a kept dimension in between.
x = tensor.tensor(broadcastable=(False, True, False, True),
dtype='float64')
out = x[i].dimshuffle(1)
f = theano.function([x, i], out)
topo = f.maker.fgraph.toposort()
assert any([not isinstance(x, DimShuffle) for x in topo])
assert f(numpy.random.rand(5, 1, 4, 1), 2).shape == (4,)