提交 6c0263e4 authored 作者: Olivier Mastropietro's avatar Olivier Mastropietro

Import changes and Dimshuffle{}(Subtensor) => Subtensor implementation

上级 d9fe1b74
......@@ -39,7 +39,7 @@ import logging
from theano import gof
from theano.tensor.elemwise import CAReduce
from theano.tensor import basic as T
from theano.tensor import DimShuffle
from theano.tensor import DimShuffle, Subtensor
from theano.tensor.basic import (get_scalar_constant_value,
NotScalarConstantError)
......@@ -149,7 +149,7 @@ def local_reshape_dimshuffle(node):
@register_uncanonicalize
@gof.local_optimizer([T.DimShuffle])
@gof.local_optimizer([DimShuffle])
def local_dimshuffle_alloc(node):
"""
If an alloc is inside a dimshuffle which only adds dimension to the left,
......@@ -157,7 +157,7 @@ def local_dimshuffle_alloc(node):
dimshuffle{x, 0, 1}(alloc([3 4], 3, 2) => alloc([3 4], 1, 3, 2)
"""
if isinstance(node.op, T.DimShuffle) and node.inputs[0].owner:
if isinstance(node.op, DimShuffle) and node.inputs[0].owner:
input_ = node.inputs[0]
if isinstance(input_.owner.op, T.Alloc):
# check if it only adds dimension to the left
......@@ -173,3 +173,50 @@ def local_dimshuffle_alloc(node):
return [T.alloc(input_.owner.inputs[0], *new_shape_input)]
return False
@register_uncanonicalize
@gof.local_optimizer([DimShuffle])
def local_dimshuffle_subtensor(node):
    """
    If a subtensor is inside a dimshuffle which only drops broadcastable
    dimensions, scrap the dimshuffle and index those dimensions with 0
    inside the subtensor instead.

    x[i:j, :, k:l].dimshuffle(0, 2) => x[i:j, 0, k:l]
        if x.broadcastable == (False, True, False)

    Returns a one-element list with the replacement variable, or False
    when the optimization does not apply.
    """
    if isinstance(node.op, DimShuffle) and node.inputs[0].owner:
        # the dimshuffle can only drop dimensions (cannot reshape nor add 'x')
        if 'x' in node.op.new_order:
            return False
        new_order = node.op.new_order
        # the kept dimensions must stay in their original relative order
        # (no transposition); new_order may be empty (output is a scalar),
        # so guard before reading new_order[0]
        if len(new_order) > 1:
            past_dim = new_order[0]
            for dim in new_order[1:]:
                if not dim > past_dim:
                    return False
                past_dim = dim
        input_ = node.inputs[0]
        if isinstance(input_.owner.op, Subtensor):
            # the dimensions missing from the dimshuffle must be broadcastable
            broadcastable = input_.broadcastable
            # list(...) so .remove() works on Python 3, where range is lazy
            missing_dims = list(range(input_.ndim))
            for dim in new_order:
                missing_dims.remove(dim)
            if not all([broadcastable[i] for i in missing_dims]):
                return False
            idx_list = list(input_.owner.op.idx_list)
            # Conservative case: only rewrite when every idx_list entry is a
            # slice.  An integer entry would have dropped a dimension inside
            # the subtensor and shifted the mapping between output dimension
            # d and idx_list position d.
            if not all(isinstance(idx, slice) for idx in idx_list):
                return False
            full_slice = slice(None, None, None)
            # idx_list may be shorter than ndim: trailing dimensions
            # implicitly take the full slice (e.g. x[1:] on a 4d tensor)
            while len(idx_list) < input_.ndim:
                idx_list.append(full_slice)
            for dim in missing_dims:
                if idx_list[dim] != full_slice:
                    # a bounded slice carries its own scalar inputs; replacing
                    # it would require rebuilding the op's input list, so skip
                    return False
                # index the broadcastable dimension with the constant 0:
                # this drops it, exactly as the dimshuffle did
                idx_list[dim] = 0
            # Build a brand new Subtensor instead of mutating the existing
            # op in place: ops must stay immutable (they are hashed, merged
            # and possibly shared between apply nodes).  The scalar inputs
            # of the old node are unchanged, since we only replaced
            # input-free full slices.
            new_op = Subtensor(tuple(idx_list))
            return [new_op(*input_.owner.inputs)]
    return False
......@@ -12,6 +12,7 @@ from theano.tensor.opt_uncanonicalize import (
local_alloc_dimshuffle,
local_reshape_dimshuffle,
local_dimshuffle_alloc,
local_dimshuffle_subtensor,
)
import theano.tensor as tensor
#from theano.tensor import matrix,max_and_argmax,MaaxAndArgmax,neg
......@@ -148,8 +149,7 @@ def test_local_reshape_dimshuffle():
assert any([not isinstance(x, DimShuffle) for x in topo])
def test_local_reshape_dimshuffle():
def test_local_dimshuffle_alloc():
reshape_dimshuffle = out2in(local_dimshuffle_alloc)
......@@ -168,3 +168,21 @@ def test_local_reshape_dimshuffle():
topo = g.toposort()
assert any([not isinstance(x, DimShuffle) for x in topo])
def test_local_dimshuffle_subtensor():
    """Check that local_dimshuffle_subtensor can remove the DimShuffle
    that drops the broadcastable dimensions of a Subtensor result."""
    dimshuffle_subtensor = out2in(local_dimshuffle_subtensor)

    x = tensor.tensor4('x')
    x = tensor.patternbroadcast(x, (False, True, False, False))
    i = tensor.iscalar('i')
    # x[i, ...] drops dim 0; the result broadcastable pattern is
    # (True, False, False) and dimshuffle(1, 2) drops the leading
    # broadcastable dimension.
    out = x[i, :, 10:30, :-1].dimshuffle(1, 2)

    g = FunctionGraph([x, i], [out])
    dimshuffle_subtensor(g)

    topo = g.toposort()
    # NOTE: do not shadow the graph input `x` in the comprehension
    assert any([not isinstance(n, DimShuffle) for n in topo])
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论