提交 25ffd68b authored 作者: sentient07's avatar sentient07

Added fix for new nodes not the part of graph

上级 9d0f2bae
...@@ -1268,20 +1268,20 @@ class LocalOptGroup(LocalOptimizer): ...@@ -1268,20 +1268,20 @@ class LocalOptGroup(LocalOptimizer):
def transform(self, node): def transform(self, node):
if len(self.opts) == 0: if len(self.opts) == 0:
return return
def apply_mult_opts(opt_list, node, single_opts=True): def apply_mult_opts(opt_list, node, multiple_opts=False):
repl = False repl = False
for opt in opt_list: for opt in opt_list:
repl = opt.transform(node) repl = opt.transform(node)
if not repl: if not repl:
continue continue
else: else:
if single_opts or not repl[0].owner: if not multiple_opts or not repl[0].owner:
return repl return repl
assert len(repl) == 1 assert len(repl) == 1
# Ensuring not the input of graph # Ensuring not the input of graph
assert repl[0].owner assert repl[0].owner
new_node = repl[0].owner new_node = repl[0].owner
apply_mult_opts(opt_list, new_node, False) apply_mult_opts(opt_list, new_node, True)
return repl return repl
return apply_mult_opts(self.opts, node, self.apply_all_opts) return apply_mult_opts(self.opts, node, self.apply_all_opts)
......
...@@ -2539,6 +2539,11 @@ def local_useless_subtensor(node): ...@@ -2539,6 +2539,11 @@ def local_useless_subtensor(node):
list/vector or the ARange op. list/vector or the ARange op.
""" """
# If the optimization is tried over a node that is not a part of graph before
if not hasattr(node, 'fgraph'):
return
# This optimization needs ShapeOpt and fgraph.shape_feature # This optimization needs ShapeOpt and fgraph.shape_feature
if not hasattr(node.fgraph, 'shape_feature'): if not hasattr(node.fgraph, 'shape_feature'):
return return
...@@ -4812,6 +4817,10 @@ class Canonizer(gof.LocalOptimizer): ...@@ -4812,6 +4817,10 @@ class Canonizer(gof.LocalOptimizer):
assert len(node.outputs) == 1 assert len(node.outputs) == 1
out = node.outputs[0] out = node.outputs[0]
# Condition for replacement variable not being a part of the graph
if not hasattr(out, 'clients'):
return False
# check if any of the clients of this node would be part of # check if any of the clients of this node would be part of
# this canonized graph... if so, we do nothing and wait for # this canonized graph... if so, we do nothing and wait for
# them to be transformed. # them to be transformed.
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论