提交 aba7b816 authored 作者: Frederic Bastien's avatar Frederic Bastien 提交者: sentient07

Move in2out and out2in

上级 52812079
...@@ -1964,6 +1964,44 @@ class TopoOptimizer(NavigatorOptimizer): ...@@ -1964,6 +1964,44 @@ class TopoOptimizer(NavigatorOptimizer):
'<TopoOptimizer instance>') '<TopoOptimizer instance>')
def out2in(*local_opts, **kwargs):
"""WRITEME """
name = (kwargs and kwargs.pop('name', None))
if len(local_opts) > 1:
# Don't wrap it uselessly if their is only 1 optimization.
local_opts = LocalOptGroup(*local_opts)
else:
local_opts, = local_opts
if not name:
name = local_opts.__name__
ret = TopoOptimizer(local_opts,
order='out_to_in',
failure_callback=TopoOptimizer.warn_inplace,
**kwargs)
if name:
ret.__name__ = name
return ret
def in2out(*local_opts, **kwargs):
"""WRITEME """
name = (kwargs and kwargs.pop('name', None))
if len(local_opts) > 1:
# Don't wrap it uselessly if their is only 1 optimization.
local_opts = LocalOptGroup(*local_opts)
else:
local_opts, = local_opts
if not name:
name = local_opts.__name__
ret = TopoOptimizer(local_opts,
order='in_to_out',
failure_callback=TopoOptimizer.warn_inplace,
**kwargs)
if name:
ret.__name__ = name
return ret
class OpKeyOptimizer(NavigatorOptimizer): class OpKeyOptimizer(NavigatorOptimizer):
""" """
WRITEME WRITEME
......
...@@ -22,7 +22,7 @@ from theano import gof ...@@ -22,7 +22,7 @@ from theano import gof
from theano.compat import izip from theano.compat import izip
from theano.gof import opt, InconsistencyError, TopoOptimizer, graph from theano.gof import opt, InconsistencyError, TopoOptimizer, graph
from theano.gof import Variable, Constant from theano.gof import Variable, Constant
from theano.gof.opt import copy_stack_trace from theano.gof.opt import copy_stack_trace, in2out, out2in
from theano.gof.utils import MethodNotDefined from theano.gof.utils import MethodNotDefined
from theano.gradient import DisconnectedType from theano.gradient import DisconnectedType
from theano.configparser import config from theano.configparser import config
...@@ -56,45 +56,6 @@ _logger = logging.getLogger('theano.tensor.opt') ...@@ -56,45 +56,6 @@ _logger = logging.getLogger('theano.tensor.opt')
# Utilities # Utilities
def out2in(*local_opts, **kwargs):
"""WRITEME """
name = (kwargs and kwargs.pop('name', None))
if len(local_opts) > 1:
# Don't wrap it uselessly if their is only 1 optimization.
local_opts = opt.LocalOptGroup(*local_opts)
else:
local_opts, = local_opts
if not name:
name = local_opts.__name__
ret = opt.TopoOptimizer(local_opts,
order='out_to_in',
failure_callback=TopoOptimizer.warn_inplace,
**kwargs)
if name:
ret.__name__ = name
return ret
def in2out(*local_opts, **kwargs):
"""WRITEME """
name = (kwargs and kwargs.pop('name', None))
if len(local_opts) > 1:
# Don't wrap it uselessly if their is only 1 optimization.
local_opts = opt.LocalOptGroup(*local_opts)
else:
local_opts, = local_opts
if not name:
name = local_opts.__name__
ret = opt.TopoOptimizer(local_opts,
order='in_to_out',
failure_callback=TopoOptimizer.warn_inplace,
**kwargs)
if name:
ret.__name__ = name
return ret
def _fill_chain(new_out, orig_inputs): def _fill_chain(new_out, orig_inputs):
for i in orig_inputs: for i in orig_inputs:
new_out = T.fill(i, new_out) new_out = T.fill(i, new_out)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论