Commit 6315fdfa authored by Pascal Lamblin, committed by GitHub

Merge pull request #5350 from nouiz/local_dimshuffle_subtensor

Remove opt warning from local_dimshuffle_subtensor
...@@ -168,11 +168,13 @@ def local_dimshuffle_alloc(node): ...@@ -168,11 +168,13 @@ def local_dimshuffle_alloc(node):
@register_uncanonicalize @register_uncanonicalize
@gof.local_optimizer([DimShuffle]) @gof.local_optimizer([DimShuffle])
def local_dimshuffle_subtensor(node): def local_dimshuffle_subtensor(node):
""" """If a subtensor is inside a dimshuffle which only drop
If a subtensor is inside a dimshuffle which only drop broadcastable dimensions, broadcastable dimensions, scrap the dimshuffle and index the
scrap the dimshuffle and index the subtensor with 0 subtensor with 0
x[i:j, :, k:l].dimshuffle(0, 2) =>
x[i:j, 0, k:l] if x.broadcastable == (False, True, False)
x[i:j, :, k:l].dimshuffle(0, 2) => x[i:j, 0, k:l] if x.broadcastable == (False, True, False)
""" """
if isinstance(node.op, DimShuffle) and node.inputs[0].owner: if isinstance(node.op, DimShuffle) and node.inputs[0].owner:
# the dimshuffle can only drop dimensions (cannot reshape nor add 'x') # the dimshuffle can only drop dimensions (cannot reshape nor add 'x')
...@@ -190,7 +192,8 @@ def local_dimshuffle_subtensor(node): ...@@ -190,7 +192,8 @@ def local_dimshuffle_subtensor(node):
input_ = node.inputs[0] input_ = node.inputs[0]
if isinstance(input_.owner.op, Subtensor): if isinstance(input_.owner.op, Subtensor):
# the arguments missing from the dimshuffles must be dims that are broadcastable # the arguments missing from the dimshuffles must be dims
# that are broadcastable
broadcastable = input_.broadcastable broadcastable = input_.broadcastable
missing_dims = list(range(input_.ndim)) missing_dims = list(range(input_.ndim))
...@@ -202,7 +205,8 @@ def local_dimshuffle_subtensor(node): ...@@ -202,7 +205,8 @@ def local_dimshuffle_subtensor(node):
# create a new idx_list for a new Subtensor object # create a new idx_list for a new Subtensor object
# have to loop on idx_list and inputs # have to loop on idx_list and inputs
# inputs has the length of sum of non None elements of idx_list (check in slice!) # inputs has the length of sum of non None elements of idx_list
# (check in slice!).
# len(missing_dims) can be < len(idx_list), this happens if # len(missing_dims) can be < len(idx_list), this happens if
# tensor was indexed such as x[scalar, :, :], check that as well # tensor was indexed such as x[scalar, :, :], check that as well
new_idx_list = list(input_.owner.op.idx_list) new_idx_list = list(input_.owner.op.idx_list)
...@@ -211,6 +215,7 @@ def local_dimshuffle_subtensor(node): ...@@ -211,6 +215,7 @@ def local_dimshuffle_subtensor(node):
slice_attr_list = ['start', 'stop', 'step'] slice_attr_list = ['start', 'stop', 'step']
j = 0 j = 0
slice_i = -1 slice_i = -1
subtensor_removed_dims = 0
for idx in input_.owner.op.idx_list: for idx in input_.owner.op.idx_list:
if isinstance(idx, slice): if isinstance(idx, slice):
past_j = j past_j = j
...@@ -219,14 +224,24 @@ def local_dimshuffle_subtensor(node): ...@@ -219,14 +224,24 @@ def local_dimshuffle_subtensor(node):
if getattr(idx, slice_attr) is not None: if getattr(idx, slice_attr) is not None:
new_inputs += [input_.owner.inputs[1 + j]] new_inputs += [input_.owner.inputs[1 + j]]
j += 1 j += 1
# if past_j == j indicates a slice(None, None, None), that's where # if past_j == j indicates a slice(None, None, None),
# we want to index with 0 if it is also at the same # that's where we want to index with 0 if it is also at
# spot of a missing dim # the same spot of a missing dim
if past_j == j and slice_i in missing_dims: if past_j == j and slice_i in missing_dims:
new_idx_list[j] = zero new_idx_list[j] = zero
new_inputs += [zero] new_inputs += [zero]
else: else:
new_inputs += [input_.owner.inputs[1 + j]] new_inputs += [input_.owner.inputs[1 + j]]
j += 1 j += 1
subtensor_removed_dims += 1
# Verify the trailing dimensions the subtensor didn't look at.
for idx in range(len(input_.owner.op.idx_list),
new_inputs[0].ndim):
if (idx - subtensor_removed_dims) in missing_dims:
while len(new_idx_list) < idx:
new_idx_list.append(slice(None))
new_idx_list.append(zero)
new_inputs.append(zero)
return [Subtensor(new_idx_list)(*new_inputs)] return [Subtensor(new_idx_list)(*new_inputs)]
return False return False
...@@ -160,9 +160,9 @@ def test_local_dimshuffle_alloc(): ...@@ -160,9 +160,9 @@ def test_local_dimshuffle_alloc():
g = FunctionGraph([x], [out]) g = FunctionGraph([x], [out])
reshape_dimshuffle(g) reshape_dimshuffle(g)
l=theano.gof.PerformLinker() l = theano.gof.PerformLinker()
l.accept(g) l.accept(g)
f=l.make_function() f = l.make_function()
assert f([3, 4]).ndim == 4 assert f([3, 4]).ndim == 4
...@@ -174,14 +174,36 @@ def test_local_dimshuffle_subtensor(): ...@@ -174,14 +174,36 @@ def test_local_dimshuffle_subtensor():
dimshuffle_subtensor = out2in(local_dimshuffle_subtensor) dimshuffle_subtensor = out2in(local_dimshuffle_subtensor)
x = tensor.tensor4('x') x = tensor.dtensor4('x')
x = tensor.patternbroadcast(x, (False, True, False, False)) x = tensor.patternbroadcast(x, (False, True, False, False))
i = tensor.iscalar('i') i = tensor.iscalar('i')
out = x[:, :, 10:30, ::i].dimshuffle(0,2,3) out = x[:, :, 10:30, ::i].dimshuffle(0, 2, 3)
g = FunctionGraph([x,i], [out]) g = FunctionGraph([x, i], [out])
dimshuffle_subtensor(g) dimshuffle_subtensor(g)
topo = g.toposort() topo = g.toposort()
assert any([not isinstance(x, DimShuffle) for x in topo]) assert any([not isinstance(x, DimShuffle) for x in topo])
# Test dimshuffle remove dimensions the subtensor don't "see".
x = tensor.tensor(broadcastable=(False, True, False), dtype='float64')
out = x[i].dimshuffle(1)
g = FunctionGraph([x, i], [out])
dimshuffle_subtensor(g)
topo = g.toposort()
assert any([not isinstance(x, DimShuffle) for x in topo])
# Test dimshuffle remove dimensions the subtensor don't "see" but
# have in between dimensions.
x = tensor.tensor(broadcastable=(False, True, False, True),
dtype='float64')
out = x[i].dimshuffle(1)
f = theano.function([x, i], out)
topo = f.maker.fgraph.toposort()
assert any([not isinstance(x, DimShuffle) for x in topo])
assert f(numpy.random.rand(5, 1, 4, 1), 2).shape == (4,)
Markdown formatting is supported
0%
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment