提交 944d023d authored 作者: Frédéric Bastien's avatar Frédéric Bastien 提交者: GitHub

Merge pull request #5153 from olimastro/ccw4647

Dimshuffle{0,2}(Subtensor[i:j, :, k:l]) => Subtensor[i:j, 0, k:l] #4647
...@@ -39,7 +39,7 @@ import logging ...@@ -39,7 +39,7 @@ import logging
from theano import gof from theano import gof
from theano.tensor.elemwise import CAReduce from theano.tensor.elemwise import CAReduce
from theano.tensor import basic as T from theano.tensor import basic as T
from theano.tensor import DimShuffle from theano.tensor import DimShuffle, Subtensor
from theano.tensor.opt import register_uncanonicalize from theano.tensor.opt import register_uncanonicalize
from theano import scalar as scal from theano import scalar as scal
...@@ -138,7 +138,7 @@ def local_reshape_dimshuffle(node): ...@@ -138,7 +138,7 @@ def local_reshape_dimshuffle(node):
@register_uncanonicalize @register_uncanonicalize
@gof.local_optimizer([T.DimShuffle]) @gof.local_optimizer([DimShuffle])
def local_dimshuffle_alloc(node): def local_dimshuffle_alloc(node):
""" """
If an alloc is inside a dimshuffle which only adds dimension to the left, If an alloc is inside a dimshuffle which only adds dimension to the left,
...@@ -146,7 +146,7 @@ def local_dimshuffle_alloc(node): ...@@ -146,7 +146,7 @@ def local_dimshuffle_alloc(node):
dimshuffle{x, 0, 1}(alloc([3 4], 3, 2) => alloc([3 4], 1, 3, 2) dimshuffle{x, 0, 1}(alloc([3 4], 3, 2) => alloc([3 4], 1, 3, 2)
""" """
if isinstance(node.op, T.DimShuffle) and node.inputs[0].owner: if isinstance(node.op, DimShuffle) and node.inputs[0].owner:
input_ = node.inputs[0] input_ = node.inputs[0]
if isinstance(input_.owner.op, T.Alloc): if isinstance(input_.owner.op, T.Alloc):
# check if it only adds dimension to the left # check if it only adds dimension to the left
...@@ -162,3 +162,70 @@ def local_dimshuffle_alloc(node): ...@@ -162,3 +162,70 @@ def local_dimshuffle_alloc(node):
return [T.alloc(input_.owner.inputs[0], *new_shape_input)] return [T.alloc(input_.owner.inputs[0], *new_shape_input)]
return False return False
@register_uncanonicalize
@gof.local_optimizer([DimShuffle])
def local_dimshuffle_subtensor(node):
"""
If a subtensor is inside a dimshuffle which only drop broadcastable dimensions,
scrap the dimshuffle and index the subtensor with 0
x[i:j, :, k:l].dimshuffle(0, 2) => x[i:j, 0, k:l] if x.broadcastable == (False, True, False)
"""
if isinstance(node.op, DimShuffle) and node.inputs[0].owner:
# the dimshuffle can only drop dimensions (cannot reshape nor add 'x')
if 'x' in node.op.new_order:
return False
new_order = node.op.new_order
# new order could be empty
if len(new_order) > 1:
past_dim = new_order[0]
for dim in new_order[1:]:
if not dim > past_dim:
return False
else:
past_dim = dim
input_ = node.inputs[0]
if isinstance(input_.owner.op, Subtensor):
# the arguments missing from the dimshuffles must be dims that are broadcastable
broadcastable = input_.broadcastable
missing_dims = list(range(input_.ndim))
for dim in new_order:
missing_dims.remove(dim)
if not all([broadcastable[i] for i in missing_dims]):
return False
# create a new idx_list for a new Subtensor object
# have to loop on idx_list and inputs
# inputs has the length of sum of non None elements of idx_list (check in slice!)
# len(missing_dims) can be < len(idx_list), this happens if
# tensor was indexed such as x[scalar, :, :], check that as well
new_idx_list = list(input_.owner.op.idx_list)
new_inputs = [input_.owner.inputs[0]]
zero = T.constant(0)
slice_attr_list = ['start', 'stop', 'step']
j = 0
slice_i = -1
for idx in input_.owner.op.idx_list:
if isinstance(idx, slice):
past_j = j
slice_i += 1
for slice_attr in slice_attr_list:
if getattr(idx, slice_attr) is not None:
new_inputs += [input_.owner.inputs[1 + j]]
j += 1
# if past_j == j indicates a slice(None, None, None), that's where
# we want to index with 0 if it is also at the same
# spot of a missing dim
if past_j == j and slice_i in missing_dims:
new_idx_list[j] = zero
new_inputs += [zero]
else:
new_inputs += [input_.owner.inputs[1 + j]]
j += 1
return [Subtensor(new_idx_list)(*new_inputs)]
return False
...@@ -12,6 +12,7 @@ from theano.tensor.opt_uncanonicalize import ( ...@@ -12,6 +12,7 @@ from theano.tensor.opt_uncanonicalize import (
local_alloc_dimshuffle, local_alloc_dimshuffle,
local_reshape_dimshuffle, local_reshape_dimshuffle,
local_dimshuffle_alloc, local_dimshuffle_alloc,
local_dimshuffle_subtensor,
) )
import theano.tensor as tensor import theano.tensor as tensor
#from theano.tensor import matrix,max_and_argmax,MaaxAndArgmax,neg #from theano.tensor import matrix,max_and_argmax,MaaxAndArgmax,neg
...@@ -148,8 +149,7 @@ def test_local_reshape_dimshuffle(): ...@@ -148,8 +149,7 @@ def test_local_reshape_dimshuffle():
assert any([not isinstance(x, DimShuffle) for x in topo]) assert any([not isinstance(x, DimShuffle) for x in topo])
def test_local_dimshuffle_alloc():
def test_local_reshape_dimshuffle():
reshape_dimshuffle = out2in(local_dimshuffle_alloc) reshape_dimshuffle = out2in(local_dimshuffle_alloc)
...@@ -168,3 +168,20 @@ def test_local_reshape_dimshuffle(): ...@@ -168,3 +168,20 @@ def test_local_reshape_dimshuffle():
topo = g.toposort() topo = g.toposort()
assert any([not isinstance(x, DimShuffle) for x in topo]) assert any([not isinstance(x, DimShuffle) for x in topo])
def test_local_dimshuffle_subtensor():
dimshuffle_subtensor = out2in(local_dimshuffle_subtensor)
x = tensor.tensor4('x')
x = tensor.patternbroadcast(x, (False, True, False, False))
i = tensor.iscalar('i')
out = x[:, :, 10:30, ::i].dimshuffle(0,2,3)
g = FunctionGraph([x,i], [out])
dimshuffle_subtensor(g)
topo = g.toposort()
assert any([not isinstance(x, DimShuffle) for x in topo])
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论