提交 4948903d 作者: Ricardo Vieira 提交者: Ricardo Vieira

Harmonize Scan rewrite and tag names

上级 20b6a20c
......@@ -184,7 +184,7 @@ def numba_funcify_Scan(op, node, **kwargs):
# rotation for initially truncated storage.
output_storage_post_proc_stmts: list[str] = []
# In truncated storage situations (e.g. created by `save_mem_new_scan`),
# In truncated storage situations (e.g. created by `scan_save_mem`),
# the taps and output storage overlap, instead of the standard situation in
# which the output storage is large enough to contain both the initial taps
# values and the output storage. In this truncated case, we use the
......
......@@ -209,7 +209,7 @@ def remove_constants_and_unused_inputs_scan(fgraph, node):
@node_rewriter([Scan])
def push_out_non_seq_scan(fgraph, node):
def scan_push_out_non_seq(fgraph, node):
r"""Push out the variables inside the `Scan` that depend only on non-sequences.
This optimization pushes, out of `Scan`'s inner function and into the outer
......@@ -417,10 +417,10 @@ def push_out_non_seq_scan(fgraph, node):
@node_rewriter([Scan])
def push_out_seq_scan(fgraph, node):
def scan_push_out_seq(fgraph, node):
r"""Push out the variables inside the `Scan` that depend only on constants and sequences.
This optimization resembles `push_out_non_seq_scan` but it tries to push--out of
This optimization resembles `scan_push_out_non_seq` but it tries to push--out of
the inner function--the computation that only relies on sequence and
non-sequence inputs. The idea behind this optimization is that, when it is
possible to do so, it is generally more computationally efficient to perform
......@@ -822,10 +822,10 @@ def add_nitsot_outputs(
@node_rewriter([Scan])
def push_out_add_scan(fgraph, node):
def scan_push_out_add(fgraph, node):
r"""Push `Add` operations performed at the end of the inner graph to the outside.
Like `push_out_seq_scan`, this optimization aims to replace many operations
Like `scan_push_out_seq`, this optimization aims to replace many operations
on small tensors by few operations on large tensors. It can also lead to
increased memory usage.
"""
......@@ -1185,7 +1185,7 @@ def while_scan_merge_subtensor_last_element(fgraph, scan_node):
@node_rewriter([Scan])
def save_mem_new_scan(fgraph, node):
def scan_save_mem(fgraph, node):
r"""Graph optimizer that reduces scan memory consumption.
This optimization attempts to determine if a `Scan` node, during its execution,
......@@ -2282,7 +2282,7 @@ def scan_merge_inouts(fgraph, node):
@node_rewriter([Scan])
def push_out_dot1_scan(fgraph, node):
def scan_push_out_dot1(fgraph, node):
r"""
This is another optimization that attempts to detect certain patterns of
computation in a `Scan` `Op`'s inner function and move this computation to the
......@@ -2483,7 +2483,7 @@ optdb.register("scan_eqopt2", scan_eqopt2, "fast_run", "scan", position=1.6)
# ScanSaveMem should execute only once per node.
optdb.register(
"scan_save_mem",
in2out(save_mem_new_scan, ignore_newtrees=True),
in2out(scan_save_mem, ignore_newtrees=True),
"fast_run",
"scan",
position=1.61,
......@@ -2511,8 +2511,9 @@ scan_seqopt1.register(
scan_seqopt1.register(
"scan_pushout_nonseqs_ops",
in2out(push_out_non_seq_scan, ignore_newtrees=True),
"scan_push_out_non_seq",
in2out(scan_push_out_non_seq, ignore_newtrees=True),
"scan_pushout_nonseqs_ops", # For backcompat: so it can be tagged with old name
"fast_run",
"scan",
"scan_pushout",
......@@ -2521,8 +2522,9 @@ scan_seqopt1.register(
scan_seqopt1.register(
"scan_pushout_seqs_ops",
in2out(push_out_seq_scan, ignore_newtrees=True),
"scan_push_out_seq",
in2out(scan_push_out_seq, ignore_newtrees=True),
"scan_pushout_seqs_ops", # For backcompat: so it can be tagged with old name
"fast_run",
"scan",
"scan_pushout",
......@@ -2531,8 +2533,9 @@ scan_seqopt1.register(
scan_seqopt1.register(
"scan_pushout_dot1",
in2out(push_out_dot1_scan, ignore_newtrees=True),
"scan_push_out_dot1",
in2out(scan_push_out_dot1, ignore_newtrees=True),
"scan_pushout_dot1", # For backcompat: so it can be tagged with old name
"fast_run",
"more_mem",
"scan",
......@@ -2542,9 +2545,10 @@ scan_seqopt1.register(
scan_seqopt1.register(
"scan_pushout_add",
"scan_push_out_add",
# TODO: Perhaps this should be an `EquilibriumGraphRewriter`?
in2out(push_out_add_scan, ignore_newtrees=False),
in2out(scan_push_out_add, ignore_newtrees=False),
"scan_pushout_add", # For backcompat: so it can be tagged with old name
"fast_run",
"more_mem",
"scan",
......
......@@ -304,7 +304,7 @@ class TestPushOutDot:
class TestPushOutNonSeqScan:
"""
Tests for the `push_out_non_seq_scan` optimization in the case where the inner
Tests for the `scan_push_out_non_seq` optimization in the case where the inner
function of a `Scan` `Op` has an output which is the result of a `Dot` product
on a non-sequence matrix input to `Scan` and a vector that is the result of
computation in the inner function.
......@@ -595,7 +595,7 @@ class TestPushOutNonSeqScan:
class TestPushOutAddScan:
"""
Test case for the `push_out_add_scan` optimization in the case where the `Scan`
Test case for the `scan_push_out_add` optimization in the case where the `Scan`
is used to compute the sum over the dot products between the corresponding
elements of two lists of matrices.
......@@ -1208,7 +1208,7 @@ class TestScanInplaceOptimizer:
class TestSaveMem:
mode = get_default_mode().including("scan_save_mem", "save_mem_new_scan")
# Include the "scan_save_mem" rewrite in the default mode; listing the same name twice was redundant.
mode = get_default_mode().including("scan_save_mem")
def test_save_mem(self):
rng = np.random.default_rng(utt.fetch_seed())
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论