提交 760405a8 authored 作者: Olivier Mastropietro's avatar Olivier Mastropietro

Fixed local_reshape_dimshuffle

上级 f4e0933c
......@@ -137,11 +137,12 @@ def local_reshape_dimshuffle(node):
if input_.owner and isinstance(input_.owner.op, DimShuffle):
new_order = input_.owner.op.new_order
offset = 0
for i, dim in enumerate(new_order):
for dim in new_order:
if dim == 'x':
offset += 1
continue
elif i != dim + offset:
elif dim != offset:
return False
return [T.reshape(input_.owner.inputs[0], tuple(node.inputs[1].owner.inputs))]
else:
offset += 1
return [T.reshape(input_.owner.inputs[0], node.inputs[1])]
return False
......@@ -7,7 +7,7 @@ import theano
from theano import function, config
from theano import scalar
from theano.gof import FunctionGraph
from theano.tensor.gof.opt import out2in
from theano.gof.opt import out2in
from theano.tensor.opt_uncanonicalize import (
local_alloc_dimshuffle,
local_reshape_dimshuffle
......@@ -118,11 +118,11 @@ def test_local_alloc_dimshuffle():
alloc_dimshuffle = out2in(local_alloc_dimshuffle)
x = T.vector('x')
m = T.iscalar('m')
x = tensor.vector('x')
m = tensor.iscalar('m')
y = x.dimshuffle('x', 0)
out = T.alloc(y, m, 1, x.shape[0])
out = tensor.alloc(y, m, 1, x.shape[0])
g = FunctionGraph([x, m], [out])
alloc_dimshuffle(g)
......@@ -135,10 +135,10 @@ def test_local_reshape_dimshuffle():
reshape_dimshuffle = out2in(local_reshape_dimshuffle)
x = T.matrix('x')
x = tensor.matrix('x')
y = x.dimshuffle('x', 0, 'x', 1)
out = T.reshape(y, (1, x.shape[0] * x.shape[1], 1))
out = tensor.reshape(y, (1, x.shape[0] * x.shape[1], 1))
g = FunctionGraph([x], [out])
reshape_dimshuffle(g)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论