提交 204bab09 authored 作者: sentient07's avatar sentient07

Temp fix for nodes without fgraph feature

上级 05573157
...@@ -1242,14 +1242,15 @@ class LocalOptGroup(LocalOptimizer): ...@@ -1242,14 +1242,15 @@ class LocalOptGroup(LocalOptimizer):
# This happen when created by LocalGroupDB. # This happen when created by LocalGroupDB.
optimizers = tuple(optimizers[0]) optimizers = tuple(optimizers[0])
self.opts = optimizers self.opts = optimizers
assert isinstance(self.opts, tuple)
self.reentrant = any(getattr(opt, 'reentrant', True) self.reentrant = any(getattr(opt, 'reentrant', True)
for opt in optimizers) for opt in optimizers)
self.retains_inputs = all(getattr(opt, 'retains_inputs', False) self.retains_inputs = all(getattr(opt, 'retains_inputs', False)
for opt in optimizers) for opt in optimizers)
try:
self.apply_all_opts = kwargs['apply_all_opts'] self.apply_all_opts = kwargs.get('apply_all_opts', False)
except KeyError:
self.apply_all_opts = False
def __str__(self): def __str__(self):
return getattr(self, '__name__', return getattr(self, '__name__',
...@@ -1265,13 +1266,10 @@ class LocalOptGroup(LocalOptimizer): ...@@ -1265,13 +1266,10 @@ class LocalOptGroup(LocalOptimizer):
return t return t
def transform(self, node): def transform(self, node):
repl = False if len(self.opts) == 0:
counter = 0 return
def apply_mult_opts(opt_list, node, single_opts=True): def apply_mult_opts(opt_list, node, single_opts=True):
repl = False repl = False
assert isinstance(opt_list, tuple)
if len(opt_list) == 0:
return
for opt in opt_list: for opt in opt_list:
repl = opt.transform(node) repl = opt.transform(node)
if not repl: if not repl:
...@@ -1283,8 +1281,9 @@ class LocalOptGroup(LocalOptimizer): ...@@ -1283,8 +1281,9 @@ class LocalOptGroup(LocalOptimizer):
# Ensuring not the input of graph # Ensuring not the input of graph
assert repl[0].owner assert repl[0].owner
new_node = repl[0].owner new_node = repl[0].owner
apply_mult_opts(opt_list, repl[0].owner, False) if not getattr(new_node, 'fgraph', None):
continue
apply_mult_opts(opt_list, new_node, False)
return repl return repl
return apply_mult_opts(self.opts, node, self.apply_all_opts) return apply_mult_opts(self.opts, node, self.apply_all_opts)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论