提交 cbae058c authored 作者: Olivier Mastropietro's avatar Olivier Mastropietro

Opts had errors in their return value and implementation of tests

上级 c571066f
...@@ -114,10 +114,10 @@ def local_alloc_dimshuffle(node): ...@@ -114,10 +114,10 @@ def local_alloc_dimshuffle(node):
# check if it only adds dimension to the left # check if it only adds dimension to the left
new_order = input_.owner.op.new_order new_order = input_.owner.op.new_order
expected_new_order = ('x',) * (input_.ndim - input_.owner.inputs[0].ndim) + \ expected_new_order = ('x',) * (input_.ndim - input_.owner.inputs[0].ndim) + \
tuple(range(input_.owner.inputs.ndim)) tuple(range(input_.owner.inputs[0].ndim))
if new_order != expected_new_order: if new_order != expected_new_order:
return False return False
return input_.owner.inputs return [T.alloc(input_.owner.inputs[0], *node.inputs[1:])]
return False return False
...@@ -128,10 +128,10 @@ def local_reshape_dimshuffle(node): ...@@ -128,10 +128,10 @@ def local_reshape_dimshuffle(node):
If a dimshuffle is inside a reshape and does not change the order If a dimshuffle is inside a reshape and does not change the order
of dimensions, remove it. of dimensions, remove it.
""" """
if node.op == T.reshape: if isinstance(node.op, T.Reshape):
input_ = node.inputs[0] input_ = node.inputs[0]
if input_.owner and isinstance(input_.owner.op, DimShuffle): if input_.owner and isinstance(input_.owner.op, DimShuffle):
new_order = input_owner.op.new_order new_order = input_.owner.op.new_order
offset = 0 offset = 0
for i, dim in enumerate(new_order): for i, dim in enumerate(new_order):
if dim == 'x': if dim == 'x':
...@@ -139,5 +139,5 @@ def local_reshape_dimshuffle(node): ...@@ -139,5 +139,5 @@ def local_reshape_dimshuffle(node):
continue continue
elif i != dim + offset: elif i != dim + offset:
return False return False
return input_.owner.inputs return [T.reshape(input_.owner.inputs[0], tuple(node.inputs[1].owner.inputs))]
return False return False
...@@ -6,9 +6,15 @@ import numpy ...@@ -6,9 +6,15 @@ import numpy
import theano import theano
from theano import function, config from theano import function, config
from theano import scalar from theano import scalar
from theano.gof import FunctionGraph
from theano.tensor.opt import out2in
from theano.tensor.opt_uncanonicalize import (
local_alloc_dimshuffle,
local_reshape_dimshuffle
)
import theano.tensor as tensor import theano.tensor as tensor
#from theano.tensor import matrix,max_and_argmax,MaaxAndArgmax,neg #from theano.tensor import matrix,max_and_argmax,MaaxAndArgmax,neg
from theano.tensor.elemwise import CAReduce, Elemwise from theano.tensor.elemwise import CAReduce, Elemwise, DimShuffle
from theano.tests import unittest_tools as utt from theano.tests import unittest_tools as utt
...@@ -106,3 +112,36 @@ class T_min_max(unittest.TestCase): ...@@ -106,3 +112,36 @@ class T_min_max(unittest.TestCase):
assert len(topo) == 1 assert len(topo) == 1
assert isinstance(topo[0].op, CAReduce) # max assert isinstance(topo[0].op, CAReduce) # max
f(data) f(data)
def test_local_alloc_dimshuffle():
alloc_dimshuffle = out2in(local_alloc_dimshuffle)
x = T.vector('x')
m = T.iscalar('m')
out = x.dimshuffle('x', 0)
out = T.alloc(x, m, 1, x.shape[0])
g = FunctionGraph(out)
alloc_dimshuffle(g)
topo = g.toposort()
assert any([not isinstance(x, DimShuffle) for x in topo])
def test_local_reshape_dimshuffle():
reshape_dimshuffle = out2in(local_reshape_dimshuffle)
x = T.matrix('x')
y = x.dimshuffle('x', 0, 'x', 1)
out = T.reshape(y, (1, x.shape[0] * x.shape[1], 1))
g = FunctionGraph(out)
reshape_dimshuffle(g)
topo = g.toposort()
assert any([not isinstance(x, DimShuffle) for x in topo])
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论