提交 760405a8 authored 作者: Olivier Mastropietro's avatar Olivier Mastropietro

Fixed local_reshape_dimshuffle

上级 f4e0933c
...@@ -137,11 +137,12 @@ def local_reshape_dimshuffle(node): ...@@ -137,11 +137,12 @@ def local_reshape_dimshuffle(node):
if input_.owner and isinstance(input_.owner.op, DimShuffle): if input_.owner and isinstance(input_.owner.op, DimShuffle):
new_order = input_.owner.op.new_order new_order = input_.owner.op.new_order
offset = 0 offset = 0
for i, dim in enumerate(new_order): for dim in new_order:
if dim == 'x': if dim == 'x':
offset += 1
continue continue
elif i != dim + offset: elif dim != offset:
return False return False
return [T.reshape(input_.owner.inputs[0], tuple(node.inputs[1].owner.inputs))] else:
offset += 1
return [T.reshape(input_.owner.inputs[0], node.inputs[1])]
return False return False
...@@ -7,7 +7,7 @@ import theano ...@@ -7,7 +7,7 @@ import theano
from theano import function, config from theano import function, config
from theano import scalar from theano import scalar
from theano.gof import FunctionGraph from theano.gof import FunctionGraph
from theano.tensor.gof.opt import out2in from theano.gof.opt import out2in
from theano.tensor.opt_uncanonicalize import ( from theano.tensor.opt_uncanonicalize import (
local_alloc_dimshuffle, local_alloc_dimshuffle,
local_reshape_dimshuffle local_reshape_dimshuffle
...@@ -118,11 +118,11 @@ def test_local_alloc_dimshuffle(): ...@@ -118,11 +118,11 @@ def test_local_alloc_dimshuffle():
alloc_dimshuffle = out2in(local_alloc_dimshuffle) alloc_dimshuffle = out2in(local_alloc_dimshuffle)
x = T.vector('x') x = tensor.vector('x')
m = T.iscalar('m') m = tensor.iscalar('m')
y = x.dimshuffle('x', 0) y = x.dimshuffle('x', 0)
out = T.alloc(y, m, 1, x.shape[0]) out = tensor.alloc(y, m, 1, x.shape[0])
g = FunctionGraph([x, m], [out]) g = FunctionGraph([x, m], [out])
alloc_dimshuffle(g) alloc_dimshuffle(g)
...@@ -135,10 +135,10 @@ def test_local_reshape_dimshuffle(): ...@@ -135,10 +135,10 @@ def test_local_reshape_dimshuffle():
reshape_dimshuffle = out2in(local_reshape_dimshuffle) reshape_dimshuffle = out2in(local_reshape_dimshuffle)
x = T.matrix('x') x = tensor.matrix('x')
y = x.dimshuffle('x', 0, 'x', 1) y = x.dimshuffle('x', 0, 'x', 1)
out = T.reshape(y, (1, x.shape[0] * x.shape[1], 1)) out = tensor.reshape(y, (1, x.shape[0] * x.shape[1], 1))
g = FunctionGraph([x], [out]) g = FunctionGraph([x], [out])
reshape_dimshuffle(g) reshape_dimshuffle(g)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论