提交 4948903d authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Harmonize Scan rewrite and tag names

上级 20b6a20c
...@@ -184,7 +184,7 @@ def numba_funcify_Scan(op, node, **kwargs): ...@@ -184,7 +184,7 @@ def numba_funcify_Scan(op, node, **kwargs):
# rotation for initially truncated storage. # rotation for initially truncated storage.
output_storage_post_proc_stmts: list[str] = [] output_storage_post_proc_stmts: list[str] = []
# In truncated storage situations (e.g. created by `save_mem_new_scan`), # In truncated storage situations (e.g. created by `scan_save_mem`),
# the taps and output storage overlap, instead of the standard situation in # the taps and output storage overlap, instead of the standard situation in
# which the output storage is large enough to contain both the initial taps # which the output storage is large enough to contain both the initial taps
# values and the output storage. In this truncated case, we use the # values and the output storage. In this truncated case, we use the
......
...@@ -209,7 +209,7 @@ def remove_constants_and_unused_inputs_scan(fgraph, node): ...@@ -209,7 +209,7 @@ def remove_constants_and_unused_inputs_scan(fgraph, node):
@node_rewriter([Scan]) @node_rewriter([Scan])
def push_out_non_seq_scan(fgraph, node): def scan_push_out_non_seq(fgraph, node):
r"""Push out the variables inside the `Scan` that depend only on non-sequences. r"""Push out the variables inside the `Scan` that depend only on non-sequences.
This optimization pushes, out of `Scan`'s inner function and into the outer This optimization pushes, out of `Scan`'s inner function and into the outer
...@@ -417,10 +417,10 @@ def push_out_non_seq_scan(fgraph, node): ...@@ -417,10 +417,10 @@ def push_out_non_seq_scan(fgraph, node):
@node_rewriter([Scan]) @node_rewriter([Scan])
def push_out_seq_scan(fgraph, node): def scan_push_out_seq(fgraph, node):
r"""Push out the variables inside the `Scan` that depend only on constants and sequences. r"""Push out the variables inside the `Scan` that depend only on constants and sequences.
This optimization resembles `push_out_non_seq_scan` but it tries to push--out of This optimization resembles `scan_push_out_non_seq` but it tries to push--out of
the inner function--the computation that only relies on sequence and the inner function--the computation that only relies on sequence and
non-sequence inputs. The idea behind this optimization is that, when it is non-sequence inputs. The idea behind this optimization is that, when it is
possible to do so, it is generally more computationally efficient to perform possible to do so, it is generally more computationally efficient to perform
...@@ -822,10 +822,10 @@ def add_nitsot_outputs( ...@@ -822,10 +822,10 @@ def add_nitsot_outputs(
@node_rewriter([Scan]) @node_rewriter([Scan])
def push_out_add_scan(fgraph, node): def scan_push_out_add(fgraph, node):
r"""Push `Add` operations performed at the end of the inner graph to the outside. r"""Push `Add` operations performed at the end of the inner graph to the outside.
Like `push_out_seq_scan`, this optimization aims to replace many operations Like `scan_push_out_seq`, this optimization aims to replace many operations
on small tensors by few operations on large tensors. It can also lead to on small tensors by few operations on large tensors. It can also lead to
increased memory usage. increased memory usage.
""" """
...@@ -1185,7 +1185,7 @@ def while_scan_merge_subtensor_last_element(fgraph, scan_node): ...@@ -1185,7 +1185,7 @@ def while_scan_merge_subtensor_last_element(fgraph, scan_node):
@node_rewriter([Scan]) @node_rewriter([Scan])
def save_mem_new_scan(fgraph, node): def scan_save_mem(fgraph, node):
r"""Graph optimizer that reduces scan memory consumption. r"""Graph optimizer that reduces scan memory consumption.
This optimization attempts to determine if a `Scan` node, during its execution, This optimization attempts to determine if a `Scan` node, during its execution,
...@@ -2282,7 +2282,7 @@ def scan_merge_inouts(fgraph, node): ...@@ -2282,7 +2282,7 @@ def scan_merge_inouts(fgraph, node):
@node_rewriter([Scan]) @node_rewriter([Scan])
def push_out_dot1_scan(fgraph, node): def scan_push_out_dot1(fgraph, node):
r""" r"""
This is another optimization that attempts to detect certain patterns of This is another optimization that attempts to detect certain patterns of
computation in a `Scan` `Op`'s inner function and move this computation to the computation in a `Scan` `Op`'s inner function and move this computation to the
...@@ -2483,7 +2483,7 @@ optdb.register("scan_eqopt2", scan_eqopt2, "fast_run", "scan", position=1.6) ...@@ -2483,7 +2483,7 @@ optdb.register("scan_eqopt2", scan_eqopt2, "fast_run", "scan", position=1.6)
# ScanSaveMem should execute only once per node. # ScanSaveMem should execute only once per node.
optdb.register( optdb.register(
"scan_save_mem", "scan_save_mem",
in2out(save_mem_new_scan, ignore_newtrees=True), in2out(scan_save_mem, ignore_newtrees=True),
"fast_run", "fast_run",
"scan", "scan",
position=1.61, position=1.61,
...@@ -2511,8 +2511,9 @@ scan_seqopt1.register( ...@@ -2511,8 +2511,9 @@ scan_seqopt1.register(
scan_seqopt1.register( scan_seqopt1.register(
"scan_pushout_nonseqs_ops", "scan_push_out_non_seq",
in2out(push_out_non_seq_scan, ignore_newtrees=True), in2out(scan_push_out_non_seq, ignore_newtrees=True),
"scan_pushout_nonseqs_ops", # For backcompat: so it can be tagged with old name
"fast_run", "fast_run",
"scan", "scan",
"scan_pushout", "scan_pushout",
...@@ -2521,8 +2522,9 @@ scan_seqopt1.register( ...@@ -2521,8 +2522,9 @@ scan_seqopt1.register(
scan_seqopt1.register( scan_seqopt1.register(
"scan_pushout_seqs_ops", "scan_push_out_seq",
in2out(push_out_seq_scan, ignore_newtrees=True), in2out(scan_push_out_seq, ignore_newtrees=True),
"scan_pushout_seqs_ops", # For backcompat: so it can be tagged with old name
"fast_run", "fast_run",
"scan", "scan",
"scan_pushout", "scan_pushout",
...@@ -2531,8 +2533,9 @@ scan_seqopt1.register( ...@@ -2531,8 +2533,9 @@ scan_seqopt1.register(
scan_seqopt1.register( scan_seqopt1.register(
"scan_pushout_dot1", "scan_push_out_dot1",
in2out(push_out_dot1_scan, ignore_newtrees=True), in2out(scan_push_out_dot1, ignore_newtrees=True),
"scan_pushout_dot1", # For backcompat: so it can be tagged with old name
"fast_run", "fast_run",
"more_mem", "more_mem",
"scan", "scan",
...@@ -2542,9 +2545,10 @@ scan_seqopt1.register( ...@@ -2542,9 +2545,10 @@ scan_seqopt1.register(
scan_seqopt1.register( scan_seqopt1.register(
"scan_pushout_add", "scan_push_out_add",
# TODO: Perhaps this should be an `EquilibriumGraphRewriter`? # TODO: Perhaps this should be an `EquilibriumGraphRewriter`?
in2out(push_out_add_scan, ignore_newtrees=False), in2out(scan_push_out_add, ignore_newtrees=False),
"scan_pushout_add", # For backcompat: so it can be tagged with old name
"fast_run", "fast_run",
"more_mem", "more_mem",
"scan", "scan",
......
...@@ -304,7 +304,7 @@ class TestPushOutDot: ...@@ -304,7 +304,7 @@ class TestPushOutDot:
class TestPushOutNonSeqScan: class TestPushOutNonSeqScan:
""" """
Tests for the `push_out_non_seq_scan` optimization in the case where the inner Tests for the `scan_push_out_non_seq` optimization in the case where the inner
function of a `Scan` `Op` has an output which is the result of a `Dot` product function of a `Scan` `Op` has an output which is the result of a `Dot` product
on a non-sequence matrix input to `Scan` and a vector that is the result of on a non-sequence matrix input to `Scan` and a vector that is the result of
computation in the inner function. computation in the inner function.
...@@ -595,7 +595,7 @@ class TestPushOutNonSeqScan: ...@@ -595,7 +595,7 @@ class TestPushOutNonSeqScan:
class TestPushOutAddScan: class TestPushOutAddScan:
""" """
Test case for the `push_out_add_scan` optimization in the case where the `Scan` Test case for the `scan_push_out_add` optimization in the case where the `Scan`
is used to compute the sum over the dot products between the corresponding is used to compute the sum over the dot products between the corresponding
elements of two lists of matrices. elements of two lists of matrices.
...@@ -1208,7 +1208,7 @@ class TestScanInplaceOptimizer: ...@@ -1208,7 +1208,7 @@ class TestScanInplaceOptimizer:
class TestSaveMem: class TestSaveMem:
mode = get_default_mode().including("scan_save_mem", "save_mem_new_scan") mode = get_default_mode().including("scan_save_mem", "scan_save_mem")
def test_save_mem(self): def test_save_mem(self):
rng = np.random.default_rng(utt.fetch_seed()) rng = np.random.default_rng(utt.fetch_seed())
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论