提交 e8bf9c1f authored 作者: sentient07's avatar sentient07

A few changes to the parameters of `transform`

上级 7e2ad39b
...@@ -1237,7 +1237,7 @@ class LocalOptGroup(LocalOptimizer): ...@@ -1237,7 +1237,7 @@ class LocalOptGroup(LocalOptimizer):
""" """
def __init__(self, apply_all_opts=False, *optimizers): def __init__(self, *optimizers, **kwargs):
if len(optimizers) == 1 and isinstance(optimizers[0], list): if len(optimizers) == 1 and isinstance(optimizers[0], list):
# This happen when created by LocalGroupDB. # This happen when created by LocalGroupDB.
optimizers = tuple(optimizers[0]) optimizers = tuple(optimizers[0])
...@@ -1246,7 +1246,10 @@ class LocalOptGroup(LocalOptimizer): ...@@ -1246,7 +1246,10 @@ class LocalOptGroup(LocalOptimizer):
for opt in optimizers) for opt in optimizers)
self.retains_inputs = all(getattr(opt, 'retains_inputs', False) self.retains_inputs = all(getattr(opt, 'retains_inputs', False)
for opt in optimizers) for opt in optimizers)
self.apply_all_opts = apply_all_opts try:
self.apply_all_opts = kwargs['apply_all_opts']
except KeyError:
self.apply_all_opts = False
def __str__(self): def __str__(self):
return getattr(self, '__name__', return getattr(self, '__name__',
...@@ -1262,16 +1265,19 @@ class LocalOptGroup(LocalOptimizer): ...@@ -1262,16 +1265,19 @@ class LocalOptGroup(LocalOptimizer):
return t return t
def transform(self, node):
    """Try each optimizer in ``self.opts`` on `node`.

    Parameters
    ----------
    node
        The Apply node that the local optimizers should try to replace.

    Returns
    -------
    list or False
        A list of replacement variables produced by the last optimizer
        that applied, or ``False`` when no optimizer applied.

    Notes
    -----
    When ``self.apply_all_opts`` is True, every optimizer is tried in
    turn; each successful single-output replacement re-targets `node`
    so later optimizers see the already-rewritten graph. Otherwise the
    first successful replacement is returned immediately.
    """
    repl = False
    counter = 0  # number of optimizers that actually fired
    for opt in self.opts:
        new_repl = opt.transform(node)
        if not new_repl:
            continue
        counter += 1
        # Keep the latest *successful* replacement. The previous code
        # overwrote `repl` with every opt's result, so a trailing opt
        # that returned False discarded earlier successful rewrites
        # and the method wrongly reported no replacement.
        repl = new_repl
        if self.apply_all_opts:
            # Chaining is only defined for single-output replacements.
            assert len(repl) == 1
            # Operate on the rewritten node for the remaining opts.
            # NOTE(review): repl[0].owner may be None for constants /
            # graph inputs — presumably the member opts never produce
            # those here; confirm against callers.
            node = repl[0].owner
        else:
            return repl
    if counter >= 2:
        print("No of times the node is optimized : " + str(counter))
    return repl
def print_summary(self, stream=sys.stdout, level=0, depth=-1): def print_summary(self, stream=sys.stdout, level=0, depth=-1):
......
...@@ -22,7 +22,7 @@ from theano import gof ...@@ -22,7 +22,7 @@ from theano import gof
from theano.compat import izip from theano.compat import izip
from theano.gof import opt, InconsistencyError, TopoOptimizer, graph from theano.gof import opt, InconsistencyError, TopoOptimizer, graph
from theano.gof import Variable, Constant from theano.gof import Variable, Constant
from theano.gof.opt import copy_stack_trace, in2out, out2in from theano.gof.opt import copy_stack_trace, in2out, out2in, LocalOptGroup
from theano.gof.utils import MethodNotDefined from theano.gof.utils import MethodNotDefined
from theano.gradient import DisconnectedType from theano.gradient import DisconnectedType
from theano.configparser import config from theano.configparser import config
...@@ -377,6 +377,7 @@ def register_useless(lopt, *tags, **kwargs): ...@@ -377,6 +377,7 @@ def register_useless(lopt, *tags, **kwargs):
return register return register
else: else:
name = kwargs.pop('name', None) or lopt.__name__ name = kwargs.pop('name', None) or lopt.__name__
compile.mode.local_useless.register(name, lopt, 'last', 'fast_run', compile.mode.local_useless.register(name, lopt, 'last', 'fast_run',
*tags, **kwargs) *tags, **kwargs)
return lopt return lopt
......
Markdown 格式
0%
您添加了 0 人 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论