提交 9bc2a2f5 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Remove unnecessary checks and unused variable in Scan rewrites

上级 2ce0ce13
...@@ -220,9 +220,6 @@ def scan_push_out_non_seq(fgraph, node): ...@@ -220,9 +220,6 @@ def scan_push_out_non_seq(fgraph, node):
it to the outer function to be executed only once, before the `Scan` `Op`, it to the outer function to be executed only once, before the `Scan` `Op`,
reduces the amount of computation that needs to be performed. reduces the amount of computation that needs to be performed.
""" """
if not isinstance(node.op, Scan):
return False
node_inputs, node_outputs = node.op.inner_inputs, node.op.inner_outputs node_inputs, node_outputs = node.op.inner_inputs, node.op.inner_outputs
local_fgraph_topo = io_toposort(node_inputs, node_outputs) local_fgraph_topo = io_toposort(node_inputs, node_outputs)
...@@ -430,9 +427,6 @@ def scan_push_out_seq(fgraph, node): ...@@ -430,9 +427,6 @@ def scan_push_out_seq(fgraph, node):
many times on many smaller tensors. In many cases, this optimization can many times on many smaller tensors. In many cases, this optimization can
increase memory usage but, in some specific cases, it can also decrease it. increase memory usage but, in some specific cases, it can also decrease it.
""" """
if not isinstance(node.op, Scan):
return False
node_inputs, node_outputs = node.op.inner_inputs, node.op.inner_outputs node_inputs, node_outputs = node.op.inner_inputs, node.op.inner_outputs
local_fgraph_topo = io_toposort(node_inputs, node_outputs) local_fgraph_topo = io_toposort(node_inputs, node_outputs)
...@@ -696,7 +690,6 @@ def push_out_inner_vars( ...@@ -696,7 +690,6 @@ def push_out_inner_vars(
old_scan_args: ScanArgs, old_scan_args: ScanArgs,
) -> tuple[list[Variable], ScanArgs, dict[Variable, Variable]]: ) -> tuple[list[Variable], ScanArgs, dict[Variable, Variable]]:
tmp_outer_vars: list[Variable | None] = [] tmp_outer_vars: list[Variable | None] = []
new_scan_node = old_scan_node
new_scan_args = old_scan_args new_scan_args = old_scan_args
replacements: dict[Variable, Variable] = {} replacements: dict[Variable, Variable] = {}
...@@ -843,10 +836,11 @@ def scan_push_out_add(fgraph, node): ...@@ -843,10 +836,11 @@ def scan_push_out_add(fgraph, node):
# Don't perform the optimization on `as_while` `Scan`s. Because these # Don't perform the optimization on `as_while` `Scan`s. Because these
# `Scan`s don't run for a predetermined number of steps, handling them is # `Scan`s don't run for a predetermined number of steps, handling them is
# more complicated and this optimization doesn't support it at the moment. # more complicated and this optimization doesn't support it at the moment.
if not (isinstance(node.op, Scan) and not node.op.info.as_while): op = node.op
if op.info.as_while:
return False return False
op = node.op # apply_ancestors(args.inner_outputs)
# Use `ScanArgs` to parse the inputs and outputs of scan for ease of # Use `ScanArgs` to parse the inputs and outputs of scan for ease of
# use # use
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论