提交 cbae058c authored 作者: Olivier Mastropietro's avatar Olivier Mastropietro

Opts had errors in their return value and implementation of tests

上级 c571066f
......@@ -114,10 +114,10 @@ def local_alloc_dimshuffle(node):
# check if it only adds dimension to the left
new_order = input_.owner.op.new_order
expected_new_order = ('x',) * (input_.ndim - input_.owner.inputs[0].ndim) + \
tuple(range(input_.owner.inputs.ndim))
tuple(range(input_.owner.inputs[0].ndim))
if new_order != expected_new_order:
return False
return input_.owner.inputs
return [T.alloc(input_.owner.inputs[0], *node.inputs[1:])]
return False
......@@ -128,10 +128,10 @@ def local_reshape_dimshuffle(node):
If a dimshuffle is inside a reshape and does not change the order
of dimensions, remove it.
"""
if node.op == T.reshape:
if isinstance(node.op, T.Reshape):
input_ = node.inputs[0]
if input_.owner and isinstance(input_.owner.op, DimShuffle):
new_order = input_owner.op.new_order
new_order = input_.owner.op.new_order
offset = 0
for i, dim in enumerate(new_order):
if dim == 'x':
......@@ -139,5 +139,5 @@ def local_reshape_dimshuffle(node):
continue
elif i != dim + offset:
return False
return input_.owner.inputs
return [T.reshape(input_.owner.inputs[0], tuple(node.inputs[1].owner.inputs))]
return False
......@@ -6,9 +6,15 @@ import numpy
import theano
from theano import function, config
from theano import scalar
from theano.gof import FunctionGraph
from theano.tensor.opt import out2in
from theano.tensor.opt_uncanonicalize import (
local_alloc_dimshuffle,
local_reshape_dimshuffle
)
import theano.tensor as tensor
#from theano.tensor import matrix,max_and_argmax,MaaxAndArgmax,neg
from theano.tensor.elemwise import CAReduce, Elemwise
from theano.tensor.elemwise import CAReduce, Elemwise, DimShuffle
from theano.tests import unittest_tools as utt
......@@ -106,3 +112,36 @@ class T_min_max(unittest.TestCase):
assert len(topo) == 1
assert isinstance(topo[0].op, CAReduce) # max
f(data)
def test_local_alloc_dimshuffle():
alloc_dimshuffle = out2in(local_alloc_dimshuffle)
x = T.vector('x')
m = T.iscalar('m')
out = x.dimshuffle('x', 0)
out = T.alloc(x, m, 1, x.shape[0])
g = FunctionGraph(out)
alloc_dimshuffle(g)
topo = g.toposort()
assert any([not isinstance(x, DimShuffle) for x in topo])
def test_local_reshape_dimshuffle():
reshape_dimshuffle = out2in(local_reshape_dimshuffle)
x = T.matrix('x')
y = x.dimshuffle('x', 0, 'x', 1)
out = T.reshape(y, (1, x.shape[0] * x.shape[1], 1))
g = FunctionGraph(out)
reshape_dimshuffle(g)
topo = g.toposort()
assert any([not isinstance(x, DimShuffle) for x in topo])
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论