提交 e8bf9c1f authored 作者: sentient07's avatar sentient07

Few change to param of transform

上级 7e2ad39b
......@@ -1237,7 +1237,7 @@ class LocalOptGroup(LocalOptimizer):
"""
def __init__(self, apply_all_opts=False, *optimizers):
def __init__(self, *optimizers, **kwargs):
if len(optimizers) == 1 and isinstance(optimizers[0], list):
# This happen when created by LocalGroupDB.
optimizers = tuple(optimizers[0])
......@@ -1246,7 +1246,10 @@ class LocalOptGroup(LocalOptimizer):
for opt in optimizers)
self.retains_inputs = all(getattr(opt, 'retains_inputs', False)
for opt in optimizers)
self.apply_all_opts = apply_all_opts
try:
self.apply_all_opts = kwargs['apply_all_opts']
except KeyError:
self.apply_all_opts = False
def __str__(self):
return getattr(self, '__name__',
......@@ -1262,16 +1265,19 @@ class LocalOptGroup(LocalOptimizer):
return t
def transform(self, node):
repl = None
repl = False
counter = 0
for opt in self.opts:
repl = opt.transform(node)
if repl:
if self.apply_all_opts is True:
counter += 1
if self.apply_all_opts:
assert len(repl) == 1
node = repl.owner
node = repl[0].owner
continue
return repl
if counter >=2:
print("No of times the node is optimized : " + str(counter))
return repl
def print_summary(self, stream=sys.stdout, level=0, depth=-1):
......
......@@ -22,7 +22,7 @@ from theano import gof
from theano.compat import izip
from theano.gof import opt, InconsistencyError, TopoOptimizer, graph
from theano.gof import Variable, Constant
from theano.gof.opt import copy_stack_trace, in2out, out2in
from theano.gof.opt import copy_stack_trace, in2out, out2in, LocalOptGroup
from theano.gof.utils import MethodNotDefined
from theano.gradient import DisconnectedType
from theano.configparser import config
......@@ -377,6 +377,7 @@ def register_useless(lopt, *tags, **kwargs):
return register
else:
name = kwargs.pop('name', None) or lopt.__name__
compile.mode.local_useless.register(name, lopt, 'last', 'fast_run',
*tags, **kwargs)
return lopt
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论