提交 0cf0187d authored 作者: Frederic Bastien's avatar Frederic Bastien

Mark more scan optimization as removing the scan.

上级 8c58dfb8
...@@ -1782,7 +1782,10 @@ class NavigatorOptimizer(Optimizer): ...@@ -1782,7 +1782,10 @@ class NavigatorOptimizer(Optimizer):
if replacements is False or replacements is None: if replacements is False or replacements is None:
return False return False
old_vars = node.outputs old_vars = node.outputs
remove = []
if isinstance(replacements, dict): if isinstance(replacements, dict):
if "remove" in replacements:
remove = replacements.pop("remove")
old_vars = list(replacements.keys()) old_vars = list(replacements.keys())
replacements = list(replacements.values()) replacements = list(replacements.values())
elif not isinstance(replacements, (tuple, list)): elif not isinstance(replacements, (tuple, list)):
...@@ -1805,7 +1808,9 @@ class NavigatorOptimizer(Optimizer): ...@@ -1805,7 +1808,9 @@ class NavigatorOptimizer(Optimizer):
if len(repl_pairs) == 0: if len(repl_pairs) == 0:
return False return False
try: try:
fgraph.replace_all_validate(repl_pairs, reason=lopt) fgraph.replace_all_validate_remove(repl_pairs,
reason=lopt,
remove=remove)
return True return True
except Exception as e: except Exception as e:
# This means the replacements were rejected by the fgraph. # This means the replacements were rejected by the fgraph.
......
...@@ -202,7 +202,7 @@ def remove_constants_and_unused_inputs_scan(node): ...@@ -202,7 +202,7 @@ def remove_constants_and_unused_inputs_scan(node):
# DEBUG CHECK # DEBUG CHECK
nwScan = scan_op.Scan(nw_inner, op_outs, nw_info) nwScan = scan_op.Scan(nw_inner, op_outs, nw_info)
nw_outs = nwScan(*nw_outer, **dict(return_list=True)) nw_outs = nwScan(*nw_outer, **dict(return_list=True))
return nw_outs return dict([("remove", [node])] + zip(node.outputs, nw_outs))
else: else:
return False return False
...@@ -1964,8 +1964,10 @@ def scan_merge_inouts(node): ...@@ -1964,8 +1964,10 @@ def scan_merge_inouts(node):
outputs = [outputs] outputs = [outputs]
na = scan_args(outer_inputs, outputs, op.inputs, op.outputs, op.info) na = scan_args(outer_inputs, outputs, op.inputs, op.outputs, op.info)
remove = [node]
else: else:
na = a na = a
remove = []
# Now that the identical external inputs have been merged, we do a new # Now that the identical external inputs have been merged, we do a new
# loop in order to merge external outputs that compute the same things # loop in order to merge external outputs that compute the same things
...@@ -2070,6 +2072,7 @@ def scan_merge_inouts(node): ...@@ -2070,6 +2072,7 @@ def scan_merge_inouts(node):
new_outer_out_mit_mot.append(outer_omm) new_outer_out_mit_mot.append(outer_omm)
na.outer_out_mit_mot = new_outer_out_mit_mot na.outer_out_mit_mot = new_outer_out_mit_mot
return dict([("remove", remove)] + zip(node.outputs, na.outer_outputs))
return na.outer_outputs return na.outer_outputs
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论