提交 aba7b816 authored 作者: Frederic Bastien's avatar Frederic Bastien 提交者: sentient07

Move in2out and out2in

上级 52812079
......@@ -1964,6 +1964,44 @@ class TopoOptimizer(NavigatorOptimizer):
'<TopoOptimizer instance>')
def out2in(*local_opts, **kwargs):
"""WRITEME """
name = (kwargs and kwargs.pop('name', None))
if len(local_opts) > 1:
# Don't wrap it uselessly if their is only 1 optimization.
local_opts = LocalOptGroup(*local_opts)
else:
local_opts, = local_opts
if not name:
name = local_opts.__name__
ret = TopoOptimizer(local_opts,
order='out_to_in',
failure_callback=TopoOptimizer.warn_inplace,
**kwargs)
if name:
ret.__name__ = name
return ret
def in2out(*local_opts, **kwargs):
"""WRITEME """
name = (kwargs and kwargs.pop('name', None))
if len(local_opts) > 1:
# Don't wrap it uselessly if their is only 1 optimization.
local_opts = LocalOptGroup(*local_opts)
else:
local_opts, = local_opts
if not name:
name = local_opts.__name__
ret = TopoOptimizer(local_opts,
order='in_to_out',
failure_callback=TopoOptimizer.warn_inplace,
**kwargs)
if name:
ret.__name__ = name
return ret
class OpKeyOptimizer(NavigatorOptimizer):
"""
WRITEME
......
......@@ -22,7 +22,7 @@ from theano import gof
from theano.compat import izip
from theano.gof import opt, InconsistencyError, TopoOptimizer, graph
from theano.gof import Variable, Constant
from theano.gof.opt import copy_stack_trace
from theano.gof.opt import copy_stack_trace, in2out, out2in
from theano.gof.utils import MethodNotDefined
from theano.gradient import DisconnectedType
from theano.configparser import config
......@@ -56,45 +56,6 @@ _logger = logging.getLogger('theano.tensor.opt')
# Utilities
def out2in(*local_opts, **kwargs):
"""WRITEME """
name = (kwargs and kwargs.pop('name', None))
if len(local_opts) > 1:
# Don't wrap it uselessly if their is only 1 optimization.
local_opts = opt.LocalOptGroup(*local_opts)
else:
local_opts, = local_opts
if not name:
name = local_opts.__name__
ret = opt.TopoOptimizer(local_opts,
order='out_to_in',
failure_callback=TopoOptimizer.warn_inplace,
**kwargs)
if name:
ret.__name__ = name
return ret
def in2out(*local_opts, **kwargs):
"""WRITEME """
name = (kwargs and kwargs.pop('name', None))
if len(local_opts) > 1:
# Don't wrap it uselessly if their is only 1 optimization.
local_opts = opt.LocalOptGroup(*local_opts)
else:
local_opts, = local_opts
if not name:
name = local_opts.__name__
ret = opt.TopoOptimizer(local_opts,
order='in_to_out',
failure_callback=TopoOptimizer.warn_inplace,
**kwargs)
if name:
ret.__name__ = name
return ret
def _fill_chain(new_out, orig_inputs):
for i in orig_inputs:
new_out = T.fill(i, new_out)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论