提交 46c58294 authored 作者: Frederic Bastien's avatar Frederic Bastien

flake8

上级 42a99bcc
...@@ -168,11 +168,13 @@ def local_dimshuffle_alloc(node): ...@@ -168,11 +168,13 @@ def local_dimshuffle_alloc(node):
@register_uncanonicalize @register_uncanonicalize
@gof.local_optimizer([DimShuffle]) @gof.local_optimizer([DimShuffle])
def local_dimshuffle_subtensor(node): def local_dimshuffle_subtensor(node):
""" """If a subtensor is inside a dimshuffle which only drop
If a subtensor is inside a dimshuffle which only drop broadcastable dimensions, broadcastable dimensions, scrap the dimshuffle and index the
scrap the dimshuffle and index the subtensor with 0 subtensor with 0
x[i:j, :, k:l].dimshuffle(0, 2) =>
x[i:j, 0, k:l] if x.broadcastable == (False, True, False)
x[i:j, :, k:l].dimshuffle(0, 2) => x[i:j, 0, k:l] if x.broadcastable == (False, True, False)
""" """
if isinstance(node.op, DimShuffle) and node.inputs[0].owner: if isinstance(node.op, DimShuffle) and node.inputs[0].owner:
# the dimshuffle can only drop dimensions (cannot reshape nor add 'x') # the dimshuffle can only drop dimensions (cannot reshape nor add 'x')
...@@ -190,7 +192,8 @@ def local_dimshuffle_subtensor(node): ...@@ -190,7 +192,8 @@ def local_dimshuffle_subtensor(node):
input_ = node.inputs[0] input_ = node.inputs[0]
if isinstance(input_.owner.op, Subtensor): if isinstance(input_.owner.op, Subtensor):
# the arguments missing from the dimshuffles must be dims that are broadcastable # the arguments missing from the dimshuffles must be dims
# that are broadcastable
broadcastable = input_.broadcastable broadcastable = input_.broadcastable
missing_dims = list(range(input_.ndim)) missing_dims = list(range(input_.ndim))
...@@ -202,7 +205,8 @@ def local_dimshuffle_subtensor(node): ...@@ -202,7 +205,8 @@ def local_dimshuffle_subtensor(node):
# create a new idx_list for a new Subtensor object # create a new idx_list for a new Subtensor object
# have to loop on idx_list and inputs # have to loop on idx_list and inputs
# inputs has the length of sum of non None elements of idx_list (check in slice!) # inputs has the length of sum of non None elements of idx_list
# (check in slice!).
# len(missing_dims) can be < len(idx_list), this happens if # len(missing_dims) can be < len(idx_list), this happens if
# tensor was indexed such as x[scalar, :, :], check that as well # tensor was indexed such as x[scalar, :, :], check that as well
new_idx_list = list(input_.owner.op.idx_list) new_idx_list = list(input_.owner.op.idx_list)
...@@ -219,9 +223,9 @@ def local_dimshuffle_subtensor(node): ...@@ -219,9 +223,9 @@ def local_dimshuffle_subtensor(node):
if getattr(idx, slice_attr) is not None: if getattr(idx, slice_attr) is not None:
new_inputs += [input_.owner.inputs[1 + j]] new_inputs += [input_.owner.inputs[1 + j]]
j += 1 j += 1
# if past_j == j indicates a slice(None, None, None), that's where # if past_j == j indicates a slice(None, None, None),
# we want to index with 0 if it is also at the same # that's where we want to index with 0 if it is also at
# spot of a missing dim # the same spot of a missing dim
if past_j == j and slice_i in missing_dims: if past_j == j and slice_i in missing_dims:
new_idx_list[j] = zero new_idx_list[j] = zero
new_inputs += [zero] new_inputs += [zero]
......
...@@ -160,9 +160,9 @@ def test_local_dimshuffle_alloc(): ...@@ -160,9 +160,9 @@ def test_local_dimshuffle_alloc():
g = FunctionGraph([x], [out]) g = FunctionGraph([x], [out])
reshape_dimshuffle(g) reshape_dimshuffle(g)
l=theano.gof.PerformLinker() l = theano.gof.PerformLinker()
l.accept(g) l.accept(g)
f=l.make_function() f = l.make_function()
assert f([3, 4]).ndim == 4 assert f([3, 4]).ndim == 4
...@@ -178,9 +178,9 @@ def test_local_dimshuffle_subtensor(): ...@@ -178,9 +178,9 @@ def test_local_dimshuffle_subtensor():
x = tensor.patternbroadcast(x, (False, True, False, False)) x = tensor.patternbroadcast(x, (False, True, False, False))
i = tensor.iscalar('i') i = tensor.iscalar('i')
out = x[:, :, 10:30, ::i].dimshuffle(0,2,3) out = x[:, :, 10:30, ::i].dimshuffle(0, 2, 3)
g = FunctionGraph([x,i], [out]) g = FunctionGraph([x, i], [out])
dimshuffle_subtensor(g) dimshuffle_subtensor(g)
topo = g.toposort() topo = g.toposort()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论