提交 f4e0933c authored 作者: Olivier Mastropietro's avatar Olivier Mastropietro

Fix out2in import on test and added examples on docstring

上级 9164239a
......@@ -107,6 +107,8 @@ def local_alloc_dimshuffle(node):
"""
If a dimshuffle is inside an alloc and only adds dimension to the
left, remove it.
Alloc(DimShuffle(x), ...) - > Alloc(x, ...)
"""
if node.op == T.alloc:
input_ = node.inputs[0]
......@@ -114,7 +116,7 @@ def local_alloc_dimshuffle(node):
# check if it only adds dimension to the left
new_order = input_.owner.op.new_order
expected_new_order = ('x',) * (input_.ndim - input_.owner.inputs[0].ndim) + \
tuple(range(input_.owner.inputs[0].ndim))
tuple(range(input_.owner.inputs[0].ndim))
if new_order != expected_new_order:
return False
return [T.alloc(input_.owner.inputs[0], *node.inputs[1:])]
......@@ -127,6 +129,8 @@ def local_reshape_dimshuffle(node):
"""
If a dimshuffle is inside a reshape and does not change the order
of dimensions, remove it.
Reshape(Dimshuffle(x), shp) -> Reshape(x, shp)
"""
if isinstance(node.op, T.Reshape):
input_ = node.inputs[0]
......
......@@ -7,7 +7,7 @@ import theano
from theano import function, config
from theano import scalar
from theano.gof import FunctionGraph
from theano.tensor.opt import out2in
from theano.tensor.gof.opt import out2in
from theano.tensor.opt_uncanonicalize import (
local_alloc_dimshuffle,
local_reshape_dimshuffle
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论