提交 e51289c9 authored 作者: Frederic's avatar Frederic

make scan opt assert they remove the old scan object.

上级 89aaf210
...@@ -279,8 +279,10 @@ class PushOutNonSeqScan(gof.Optimizer): ...@@ -279,8 +279,10 @@ class PushOutNonSeqScan(gof.Optimizer):
# Reconstruct node # Reconstruct node
nwScan = scan_op.Scan(op_ins, op_outs, op.info) nwScan = scan_op.Scan(op_ins, op_outs, op.info)
nw_node = nwScan.make_node(* (node.inputs + nw_outer)) nw_node = nwScan.make_node(* (node.inputs + nw_outer))
fgraph.replace_all_validate(zip(node.outputs, nw_node.outputs), fgraph.replace_all_validate_remove(
reason='scan_push_computation_out') zip(node.outputs, nw_node.outputs),
remove=[node],
reason='scan_push_computation_out')
return True return True
elif to_keep == []: elif to_keep == []:
# Nothing in the inner graph should be kept # Nothing in the inner graph should be kept
...@@ -358,8 +360,9 @@ class ScanInplaceOptimizer(Optimizer): ...@@ -358,8 +360,9 @@ class ScanInplaceOptimizer(Optimizer):
new_outs = new_op.make_node(*inputs).outputs new_outs = new_op.make_node(*inputs).outputs
try: try:
fgraph.replace_all_validate( fgraph.replace_all_validate_remove(
zip(node.outputs, new_outs), zip(node.outputs, new_outs),
remove=[node],
reason=self.__class__.__name__) reason=self.__class__.__name__)
op = new_op op = new_op
node = new_outs[0].owner node = new_outs[0].owner
...@@ -826,7 +829,9 @@ class ScanSaveMem(gof.Optimizer): ...@@ -826,7 +829,9 @@ class ScanSaveMem(gof.Optimizer):
nw_pos = compress_map[idx] nw_pos = compress_map[idx]
old_new += [(o, new_outs[nw_pos])] old_new += [(o, new_outs[nw_pos])]
fgraph.replace_all_validate(old_new, reason='scan_save_mem') fgraph.replace_all_validate_remove(old_new,
remove=[node],
reason='scan_save_mem')
def apply(self, fgraph): def apply(self, fgraph):
...@@ -1021,7 +1026,9 @@ class ScanMerge(gof.Optimizer): ...@@ -1021,7 +1026,9 @@ class ScanMerge(gof.Optimizer):
for subset in all_sets: for subset in all_sets:
if len(subset) > 1: if len(subset) > 1:
proposal = self.merge(subset) proposal = self.merge(subset)
fgraph.replace_all_validate(proposal, reason='scan_merge') fgraph.replace_all_validate_remove(proposal,
remove=subset,
reason='scan_merge')
# after const merge but before stabilize so that we can have identity # after const merge but before stabilize so that we can have identity
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论