提交 a24cd432 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Rename Scan optimization tags

上级 730f790e
This source diff could not be displayed because it is too large. You can view the blob instead.
......@@ -1823,11 +1823,10 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
# case we write about.
raise
ne = ValueError(
"An output of the scan has changed shape. "
"This may be caused by a pushout optimization."
" Try adding "
"'optimizer_excluding=scanOp_pushout_output' "
"to your Aesara flags."
"An output of the Scan has changed shape. "
"This may be caused by a push-out optimization."
" Try adding 'optimizer_excluding=scan_pushout'"
" to your Aesara flags."
)
raise ne from e
......
......@@ -786,7 +786,7 @@ def add_nitsot_outputs(
fgraph.replace_all_validate_remove(
list(zip(old_scan_node.outputs, new_node_old_outputs)),
remove=[old_scan_node],
reason="scan_pushout_output",
reason="scan_pushout_add",
)
return new_scan_node, {}
......@@ -1003,7 +1003,7 @@ class ScanInplaceOptimizer(GlobalOptimizer):
fgraph.replace_all_validate_remove(
list(zip(node.outputs, new_outs)),
remove=[node],
reason="scanOp_make_inplace",
reason="scan_make_inplace",
)
return new_outs[0].owner
except InconsistencyError:
......@@ -1892,7 +1892,7 @@ class ScanMerge(GlobalOptimizer):
if len(subset) > 1:
proposal = self.merge(subset)
fgraph.replace_all_validate_remove(
proposal, remove=subset, reason="scanOp_merge"
proposal, remove=subset, reason="scan_merge"
)
......@@ -2341,14 +2341,14 @@ optdb.register("scan_eqopt1", scan_eqopt1, 0.05, "fast_run", "scan")
optdb.register("scan_eqopt2", scan_eqopt2, 1.6, "fast_run", "scan")
# ScanSaveMem should execute only once per node.
optdb.register(
"scanOp_save_mem",
"scan_save_mem",
in2out(save_mem_new_scan, ignore_newtrees=True),
1.61,
"fast_run",
"scan",
)
optdb.register(
"scanOp_make_inplace",
"scan_make_inplace",
ScanInplaceOptimizer(typeInfer=None),
75,
"fast_run",
......@@ -2360,7 +2360,7 @@ scan_eqopt1.register("all_pushout_opt", scan_seqopt1, 1, "fast_run", "scan")
scan_seqopt1.register(
"scanOp_remove_constants_and_unused_inputs0",
"scan_remove_constants_and_unused_inputs0",
in2out(remove_constants_and_unused_inputs_scan, ignore_newtrees=True),
1,
"remove_constants_and_unused_inputs_scan",
......@@ -2370,20 +2370,22 @@ scan_seqopt1.register(
scan_seqopt1.register(
"scanOp_pushout_nonseqs_ops",
"scan_pushout_nonseqs_ops",
in2out(push_out_non_seq_scan, ignore_newtrees=True),
2,
"fast_run",
"scan",
"scan_pushout",
)
scan_seqopt1.register(
"scanOp_pushout_seqs_ops",
"scan_pushout_seqs_ops",
in2out(push_out_seq_scan, ignore_newtrees=True),
3,
"fast_run",
"scan",
"scan_pushout",
)
......@@ -2394,17 +2396,19 @@ scan_seqopt1.register(
"fast_run",
"more_mem",
"scan",
"scan_pushout",
)
scan_seqopt1.register(
"scanOp_pushout_output",
"scan_pushout_add",
# TODO: Perhaps this should be an `EquilibriumOptimizer`?
in2out(push_out_add_scan, ignore_newtrees=False),
5,
"fast_run",
"more_mem",
"scan",
"scan_pushout",
)
......@@ -2418,7 +2422,7 @@ scan_eqopt2.register(
scan_eqopt2.register(
"scanOp_remove_constants_and_unused_inputs1",
"scan_remove_constants_and_unused_inputs1",
in2out(remove_constants_and_unused_inputs_scan, ignore_newtrees=True),
2,
"remove_constants_and_unused_inputs_scan",
......@@ -2430,11 +2434,11 @@ scan_eqopt2.register(
# after const merge but before stabilize so that we can have identity
# for equivalent nodes but we still have the chance to hoist stuff out
# of the scan later.
scan_eqopt2.register("scanOp_merge", ScanMerge(), 4, "fast_run", "scan")
scan_eqopt2.register("scan_merge", ScanMerge(), 4, "fast_run", "scan")
# After Merge optimization
scan_eqopt2.register(
"scanop_remove_constants_and_unused_inputs2",
"scan_remove_constants_and_unused_inputs2",
in2out(remove_constants_and_unused_inputs_scan, ignore_newtrees=True),
5,
"remove_constants_and_unused_inputs_scan",
......@@ -2443,17 +2447,16 @@ scan_eqopt2.register(
)
scan_eqopt2.register(
"scanOp_merge_inouts",
"scan_merge_inouts",
in2out(scan_merge_inouts, ignore_newtrees=True),
6,
"scan_merge_inouts",
"fast_run",
"scan",
)
# After everything else
scan_eqopt2.register(
"scanOp_remove_constants_and_unused_inputs3",
"scan_remove_constants_and_unused_inputs3",
in2out(remove_constants_and_unused_inputs_scan, ignore_newtrees=True),
8,
"remove_constants_and_unused_inputs_scan",
......
......@@ -58,7 +58,7 @@ from aesara.link.utils import raise_with_op
def get_version():
return 0.300
return 0.301
@cython.boundscheck(False)
def perform(
......@@ -545,11 +545,10 @@ def perform(
if i == 0:
raise
raise ValueError(
"An output of the scan has changed shape. "
"This may be caused by a pushout optimization."
" Try adding "
"'optimizer_excluding=scanOp_pushout_output' "
"to your Aesara flags.")
"An output of the Scan has changed shape. "
"This may be caused by a push-out optimization."
" Try adding 'optimizer_excluding=scan_pushout'"
" to your Aesara flags.")
# 5.6 Copy over the values for outputs corresponding to shared
# variables
......
......@@ -21,7 +21,7 @@ if not config.cxx:
_logger = logging.getLogger("aesara.scan.scan_perform")
version = 0.300 # must match constant returned in function get_version()
version = 0.301 # must match constant returned in function get_version()
need_reload = False
......
......@@ -159,9 +159,9 @@ Could lower the memory usage, but raise computation time:
<aesara.tensor.nnet.batchnorm.batch_normalization>`. It use less memory
then building a corresponding Aesara graph, as it uses less memory.
- Disable one or more scan optimizations:
- ``optimizer_excluding=scanOp_pushout_seqs_ops``
- ``optimizer_excluding=scan_pushout_seqs_ops``
- ``optimizer_excluding=scan_pushout_dot1``
- ``optimizer_excluding=scanOp_pushout_output``
- ``optimizer_excluding=scan_pushout_add``
- Disable all optimization tagged as raising memory usage:
``optimizer_excluding=more_mem`` (currently only the 3 scan optimizations above)
- `float16 <https://github.com/Theano/Theano/issues/2908>`_.
......
......@@ -2523,7 +2523,7 @@ class TestScan:
sy, upy = scan(sum, sequences=[y])
f = function(
[x, y], [sx, sy], mode=mode_with_opt.excluding("scanOp_pushout_seqs_ops")
[x, y], [sx, sy], mode=mode_with_opt.excluding("scan_pushout_seqs_ops")
)
topo = f.maker.fgraph.toposort()
scans = [n for n in topo if isinstance(n.op, Scan)]
......@@ -2533,7 +2533,7 @@ class TestScan:
sy, upy = scan(sum, sequences=[y], n_steps=3)
f = function(
[x, y], [sx, sy], mode=mode_with_opt.excluding("scanOp_pushout_seqs_ops")
[x, y], [sx, sy], mode=mode_with_opt.excluding("scan_pushout_seqs_ops")
)
topo = f.maker.fgraph.toposort()
scans = [n for n in topo if isinstance(n.op, Scan)]
......@@ -2543,7 +2543,7 @@ class TestScan:
sy, upy = scan(sum, sequences=[y], n_steps=4)
f = function(
[x, y], [sx, sy], mode=mode_with_opt.excluding("scanOp_pushout_seqs_ops")
[x, y], [sx, sy], mode=mode_with_opt.excluding("scan_pushout_seqs_ops")
)
topo = f.maker.fgraph.toposort()
scans = [n for n in topo if isinstance(n.op, Scan)]
......@@ -2553,7 +2553,7 @@ class TestScan:
sy, upy = scan(sum, sequences=[x])
f = function(
[x], [sx, sy], mode=mode_with_opt.excluding("scanOp_pushout_seqs_ops")
[x], [sx, sy], mode=mode_with_opt.excluding("scan_pushout_seqs_ops")
)
topo = f.maker.fgraph.toposort()
scans = [n for n in topo if isinstance(n.op, Scan)]
......@@ -2563,7 +2563,7 @@ class TestScan:
sy, upy = scan(sum, sequences=[x], mode="FAST_COMPILE")
f = function(
[x], [sx, sy], mode=mode_with_opt.excluding("scanOp_pushout_seqs_ops")
[x], [sx, sy], mode=mode_with_opt.excluding("scan_pushout_seqs_ops")
)
topo = f.maker.fgraph.toposort()
scans = [n for n in topo if isinstance(n.op, Scan)]
......@@ -2573,7 +2573,7 @@ class TestScan:
sy, upy = scan(sum, sequences=[x], truncate_gradient=1)
f = function(
[x], [sx, sy], mode=mode_with_opt.excluding("scanOp_pushout_seqs_ops")
[x], [sx, sy], mode=mode_with_opt.excluding("scan_pushout_seqs_ops")
)
topo = f.maker.fgraph.toposort()
scans = [n for n in topo if isinstance(n.op, Scan)]
......@@ -2596,7 +2596,7 @@ class TestScan:
sz, upz = scan(sum, sequences=[sx], n_steps=4, name="Z")
f = function(
[x, y], [sy, sz], mode=mode_with_opt.excluding("scanOp_pushout_seqs_ops")
[x, y], [sy, sz], mode=mode_with_opt.excluding("scan_pushout_seqs_ops")
)
topo = f.maker.fgraph.toposort()
scans = [n for n in topo if isinstance(n.op, Scan)]
......
......@@ -156,7 +156,7 @@ class TestPushOutScanOutputDot:
opt_mode = mode.including("scan")
f_opt = aesara.function([v, m], jacobian(output, v), mode=opt_mode)
no_opt_mode = mode.excluding("scanOp_pushout_output")
no_opt_mode = mode.excluding("scan_pushout_add")
f_no_opt = aesara.function([v, m], jacobian(output, v), mode=no_opt_mode)
# Ensure that the optimization was performed correctly in f_opt
......@@ -198,7 +198,7 @@ class TestPushOutScanOutputDot:
opt_mode = mode.including("scan")
f_opt = aesara.function([a, b], outputs, mode=opt_mode)
no_opt_mode = mode.excluding("scanOp_pushout_output")
no_opt_mode = mode.excluding("scan_pushout_add")
f_no_opt = aesara.function([a, b], outputs, mode=no_opt_mode)
# Ensure that the optimization was performed correctly in f_opt
......@@ -244,7 +244,7 @@ class TestPushOutScanOutputDot:
opt_mode = mode.including("scan")
f_opt = aesara.function([a, b], outputs, mode=opt_mode)
no_opt_mode = mode.excluding("scanOp_pushout_output")
no_opt_mode = mode.excluding("scan_pushout_add")
f_no_opt = aesara.function([a, b], outputs, mode=no_opt_mode)
# Ensure that the optimization was performed correctly in f_opt
......@@ -346,7 +346,7 @@ class TestPushOutSumOfDot:
grad1 = grad(cost, [U, V, W])
f_opt = aesara.function(inputs=[x, ri, zi], outputs=grad1, mode=opt_mode)
no_opt_mode = mode.excluding("scanOp_pushout_output")
no_opt_mode = mode.excluding("scan_pushout_add")
h, _ = aesara.scan(
rnn_step1,
sequences=[x, ri, zi],
......@@ -405,7 +405,7 @@ class TestPushOutSumOfDot:
output = h[-1]
f_opt = aesara.function([input1, input2, input3], output, mode=opt_mode)
no_opt_mode = mode.excluding("scanOp_pushout_output")
no_opt_mode = mode.excluding("scan_pushout_add")
h, _ = aesara.scan(
inner_fct,
sequences=[input1, input2, input3],
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论