提交 bfde042c authored 作者: Frédéric Bastien's avatar Frédéric Bastien 提交者: GitHub

Merge pull request #5098 from olimastro/ccw4647

New optimization: DimShuffle(Alloc(...)) -> Alloc(..., use dims of 1 for new dimensions) #5024
...@@ -110,7 +110,7 @@ def local_alloc_dimshuffle(node): ...@@ -110,7 +110,7 @@ def local_alloc_dimshuffle(node):
Alloc(DimShuffle(x), ...) - > Alloc(x, ...) Alloc(DimShuffle(x), ...) - > Alloc(x, ...)
""" """
if node.op == T.alloc: if isinstance(node.op, T.Alloc):
input_ = node.inputs[0] input_ = node.inputs[0]
if input_.owner and isinstance(input_.owner.op, DimShuffle): if input_.owner and isinstance(input_.owner.op, DimShuffle):
# check if it only adds dimension to the left # check if it only adds dimension to the left
...@@ -146,3 +146,30 @@ def local_reshape_dimshuffle(node): ...@@ -146,3 +146,30 @@ def local_reshape_dimshuffle(node):
offset += 1 offset += 1
return [T.reshape(input_.owner.inputs[0], node.inputs[1])] return [T.reshape(input_.owner.inputs[0], node.inputs[1])]
return False return False
@register_uncanonicalize
@gof.local_optimizer([T.DimShuffle])
def local_dimshuffle_alloc(node):
"""
If an alloc is inside a dimshuffle which only adds dimension to the left,
scrap the dimshuffle and adds 1 into the alloc
dimshuffle{x, 0, 1}(alloc([3 4], 3, 2) => alloc([3 4], 1, 3, 2)
"""
if isinstance(node.op, T.DimShuffle) and node.inputs[0].owner:
input_ = node.inputs[0]
if isinstance(input_.owner.op, T.Alloc):
# check if it only adds dimension to the left
new_order = node.op.new_order
expected_new_order = ('x',) * (len(new_order) - input_.ndim) + \
tuple(range(input_.ndim))
if new_order != expected_new_order:
return False
# count numbers of 'x'
nb_new_dims = len(new_order) - input_.ndim
new_shape_input = (1,) * nb_new_dims + tuple(input_.owner.inputs[1:])
return [T.alloc(input_.owner.inputs[0], *new_shape_input)]
return False
...@@ -10,7 +10,8 @@ from theano.gof import FunctionGraph ...@@ -10,7 +10,8 @@ from theano.gof import FunctionGraph
from theano.gof.opt import out2in from theano.gof.opt import out2in
from theano.tensor.opt_uncanonicalize import ( from theano.tensor.opt_uncanonicalize import (
local_alloc_dimshuffle, local_alloc_dimshuffle,
local_reshape_dimshuffle local_reshape_dimshuffle,
local_dimshuffle_alloc,
) )
import theano.tensor as tensor import theano.tensor as tensor
#from theano.tensor import matrix,max_and_argmax,MaaxAndArgmax,neg #from theano.tensor import matrix,max_and_argmax,MaaxAndArgmax,neg
...@@ -145,3 +146,25 @@ def test_local_reshape_dimshuffle(): ...@@ -145,3 +146,25 @@ def test_local_reshape_dimshuffle():
topo = g.toposort() topo = g.toposort()
assert any([not isinstance(x, DimShuffle) for x in topo]) assert any([not isinstance(x, DimShuffle) for x in topo])
def test_local_reshape_dimshuffle():
reshape_dimshuffle = out2in(local_dimshuffle_alloc)
x = tensor.vector('x')
out = tensor.alloc(x, 3, 2).dimshuffle('x', 'x', 0, 1)
g = FunctionGraph([x], [out])
reshape_dimshuffle(g)
l=theano.gof.PerformLinker()
l.accept(g)
f=l.make_function()
assert f([3, 4]).ndim == 4
topo = g.toposort()
assert any([not isinstance(x, DimShuffle) for x in topo])
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论