提交 e07dbb63 authored 作者: Frédéric Bastien's avatar Frédéric Bastien 提交者: GitHub

Merge pull request #4935 from olimastro/ccw4647

[WIP] Opt Alloc(DimShuffle(...)...) -> Alloc(...) #4647
...@@ -39,6 +39,7 @@ import logging ...@@ -39,6 +39,7 @@ import logging
from theano import gof from theano import gof
from theano.tensor.elemwise import CAReduce from theano.tensor.elemwise import CAReduce
from theano.tensor import basic as T from theano.tensor import basic as T
from theano.tensor import DimShuffle
from theano.tensor.basic import (get_scalar_constant_value, from theano.tensor.basic import (get_scalar_constant_value,
NotScalarConstantError) NotScalarConstantError)
...@@ -98,3 +99,50 @@ def local_max_to_min(node): ...@@ -98,3 +99,50 @@ def local_max_to_min(node):
max.owner.op.axis)(neg.owner.inputs[0])] max.owner.op.axis)(neg.owner.inputs[0])]
return False return False
@register_uncanonicalize
@gof.local_optimizer([T.Alloc])
def local_alloc_dimshuffle(node):
    """
    If a dimshuffle is inside an alloc and only adds dimension to the
    left, remove it.

    Alloc(DimShuffle(x), ...) - > Alloc(x, ...)
    """
    if node.op != T.alloc:
        return False

    shuffled = node.inputs[0]
    if not (shuffled.owner and isinstance(shuffled.owner.op, DimShuffle)):
        return False

    inner = shuffled.owner.inputs[0]
    # The dimshuffle is removable only when it does nothing but prepend
    # broadcastable ('x') axes; any transposition or dropped axis would
    # change what the alloc broadcasts.
    n_prepended = shuffled.ndim - inner.ndim
    expected = ('x',) * n_prepended + tuple(range(inner.ndim))
    if shuffled.owner.op.new_order != expected:
        return False

    return [T.alloc(inner, *node.inputs[1:])]
@register_uncanonicalize
@gof.local_optimizer([T.Reshape])
def local_reshape_dimshuffle(node):
    """
    If a dimshuffle is inside a reshape and does not change the order
    of dimensions, remove it.

    Reshape(Dimshuffle(x), shp) -> Reshape(x, shp)
    """
    if not isinstance(node.op, T.Reshape):
        return False

    shuffled = node.inputs[0]
    if not (shuffled.owner and isinstance(shuffled.owner.op, DimShuffle)):
        return False

    # The dimshuffle may insert 'x' axes anywhere, but the real axes it
    # keeps must appear in their original order: the non-'x' entries of
    # new_order have to read 0, 1, 2, ...
    real_axes = [d for d in shuffled.owner.op.new_order if d != 'x']
    if real_axes != list(range(len(real_axes))):
        return False

    return [T.reshape(shuffled.owner.inputs[0], node.inputs[1])]
...@@ -6,9 +6,15 @@ import numpy ...@@ -6,9 +6,15 @@ import numpy
import theano import theano
from theano import function, config from theano import function, config
from theano import scalar from theano import scalar
from theano.gof import FunctionGraph
from theano.gof.opt import out2in
from theano.tensor.opt_uncanonicalize import (
local_alloc_dimshuffle,
local_reshape_dimshuffle
)
import theano.tensor as tensor import theano.tensor as tensor
#from theano.tensor import matrix,max_and_argmax,MaaxAndArgmax,neg #from theano.tensor import matrix,max_and_argmax,MaaxAndArgmax,neg
from theano.tensor.elemwise import CAReduce, Elemwise from theano.tensor.elemwise import CAReduce, Elemwise, DimShuffle
from theano.tests import unittest_tools as utt from theano.tests import unittest_tools as utt
...@@ -106,3 +112,36 @@ class T_min_max(unittest.TestCase): ...@@ -106,3 +112,36 @@ class T_min_max(unittest.TestCase):
assert len(topo) == 1 assert len(topo) == 1
assert isinstance(topo[0].op, CAReduce) # max assert isinstance(topo[0].op, CAReduce) # max
f(data) f(data)
def test_local_alloc_dimshuffle():
    """After local_alloc_dimshuffle runs, the left-padding DimShuffle
    feeding the Alloc must be gone from the graph."""
    alloc_dimshuffle = out2in(local_alloc_dimshuffle)

    x = tensor.vector('x')
    m = tensor.iscalar('m')

    y = x.dimshuffle('x', 0)
    out = tensor.alloc(y, m, 1, x.shape[0])

    g = FunctionGraph([x, m], [out])
    alloc_dimshuffle(g)

    topo = g.toposort()
    # Bug fix: the original assertion
    #   any([not isinstance(x, DimShuffle) for x in topo])
    # was vacuously true (any non-DimShuffle node satisfies it) and also
    # tested the Apply node itself instead of its `.op`, so it could never
    # detect a failing optimization. Assert that NO DimShuffle op survives.
    assert not any(isinstance(node.op, DimShuffle) for node in topo)
def test_local_reshape_dimshuffle():
    """After local_reshape_dimshuffle runs, the order-preserving
    DimShuffle feeding the Reshape must be gone from the graph."""
    reshape_dimshuffle = out2in(local_reshape_dimshuffle)

    x = tensor.matrix('x')

    y = x.dimshuffle('x', 0, 'x', 1)
    out = tensor.reshape(y, (1, x.shape[0] * x.shape[1], 1))

    g = FunctionGraph([x], [out])
    reshape_dimshuffle(g)

    topo = g.toposort()
    # Bug fix: the original assertion
    #   any([not isinstance(x, DimShuffle) for x in topo])
    # was vacuously true (any non-DimShuffle node satisfies it) and also
    # tested the Apply node itself instead of its `.op`, so it could never
    # detect a failing optimization. Assert that NO DimShuffle op survives.
    assert not any(isinstance(node.op, DimShuffle) for node in topo)
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论