提交 e07dbb63 authored 作者: Frédéric Bastien's avatar Frédéric Bastien 提交者: GitHub

Merge pull request #4935 from olimastro/ccw4647

[WIP] Opt Alloc(DimShuffle(...)...) -> Alloc(...) #4647
......@@ -39,6 +39,7 @@ import logging
from theano import gof
from theano.tensor.elemwise import CAReduce
from theano.tensor import basic as T
from theano.tensor import DimShuffle
from theano.tensor.basic import (get_scalar_constant_value,
NotScalarConstantError)
......@@ -98,3 +99,50 @@ def local_max_to_min(node):
max.owner.op.axis)(neg.owner.inputs[0])]
return False
@register_uncanonicalize
@gof.local_optimizer([T.Alloc])
def local_alloc_dimshuffle(node):
"""
If a dimshuffle is inside an alloc and only adds dimension to the
left, remove it.
Alloc(DimShuffle(x), ...) - > Alloc(x, ...)
"""
if node.op == T.alloc:
input_ = node.inputs[0]
if input_.owner and isinstance(input_.owner.op, DimShuffle):
# check if it only adds dimension to the left
new_order = input_.owner.op.new_order
expected_new_order = ('x',) * (input_.ndim - input_.owner.inputs[0].ndim) + \
tuple(range(input_.owner.inputs[0].ndim))
if new_order != expected_new_order:
return False
return [T.alloc(input_.owner.inputs[0], *node.inputs[1:])]
return False
@register_uncanonicalize
@gof.local_optimizer([T.Reshape])
def local_reshape_dimshuffle(node):
"""
If a dimshuffle is inside a reshape and does not change the order
of dimensions, remove it.
Reshape(Dimshuffle(x), shp) -> Reshape(x, shp)
"""
if isinstance(node.op, T.Reshape):
input_ = node.inputs[0]
if input_.owner and isinstance(input_.owner.op, DimShuffle):
new_order = input_.owner.op.new_order
offset = 0
for dim in new_order:
if dim == 'x':
continue
elif dim != offset:
return False
else:
offset += 1
return [T.reshape(input_.owner.inputs[0], node.inputs[1])]
return False
......@@ -6,9 +6,15 @@ import numpy
import theano
from theano import function, config
from theano import scalar
from theano.gof import FunctionGraph
from theano.gof.opt import out2in
from theano.tensor.opt_uncanonicalize import (
local_alloc_dimshuffle,
local_reshape_dimshuffle
)
import theano.tensor as tensor
#from theano.tensor import matrix,max_and_argmax,MaaxAndArgmax,neg
from theano.tensor.elemwise import CAReduce, Elemwise
from theano.tensor.elemwise import CAReduce, Elemwise, DimShuffle
from theano.tests import unittest_tools as utt
......@@ -106,3 +112,36 @@ class T_min_max(unittest.TestCase):
assert len(topo) == 1
assert isinstance(topo[0].op, CAReduce) # max
f(data)
def test_local_alloc_dimshuffle():
alloc_dimshuffle = out2in(local_alloc_dimshuffle)
x = tensor.vector('x')
m = tensor.iscalar('m')
y = x.dimshuffle('x', 0)
out = tensor.alloc(y, m, 1, x.shape[0])
g = FunctionGraph([x, m], [out])
alloc_dimshuffle(g)
topo = g.toposort()
assert any([not isinstance(x, DimShuffle) for x in topo])
def test_local_reshape_dimshuffle():
reshape_dimshuffle = out2in(local_reshape_dimshuffle)
x = tensor.matrix('x')
y = x.dimshuffle('x', 0, 'x', 1)
out = tensor.reshape(y, (1, x.shape[0] * x.shape[1], 1))
g = FunctionGraph([x], [out])
reshape_dimshuffle(g)
topo = g.toposort()
assert any([not isinstance(x, DimShuffle) for x in topo])
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论