提交 a24cd432 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Rename Scan optimization tags

上级 730f790e
This source diff could not be displayed because it is too large. You can view the blob instead.
...@@ -1823,11 +1823,10 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph): ...@@ -1823,11 +1823,10 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
# case we write about. # case we write about.
raise raise
ne = ValueError( ne = ValueError(
"An output of the scan has changed shape. " "An output of the Scan has changed shape. "
"This may be caused by a pushout optimization." "This may be caused by a push-out optimization."
" Try adding " " Try adding 'optimizer_excluding=scan_pushout'"
"'optimizer_excluding=scanOp_pushout_output' " " to your Aesara flags."
"to your Aesara flags."
) )
raise ne from e raise ne from e
......
...@@ -786,7 +786,7 @@ def add_nitsot_outputs( ...@@ -786,7 +786,7 @@ def add_nitsot_outputs(
fgraph.replace_all_validate_remove( fgraph.replace_all_validate_remove(
list(zip(old_scan_node.outputs, new_node_old_outputs)), list(zip(old_scan_node.outputs, new_node_old_outputs)),
remove=[old_scan_node], remove=[old_scan_node],
reason="scan_pushout_output", reason="scan_pushout_add",
) )
return new_scan_node, {} return new_scan_node, {}
...@@ -1003,7 +1003,7 @@ class ScanInplaceOptimizer(GlobalOptimizer): ...@@ -1003,7 +1003,7 @@ class ScanInplaceOptimizer(GlobalOptimizer):
fgraph.replace_all_validate_remove( fgraph.replace_all_validate_remove(
list(zip(node.outputs, new_outs)), list(zip(node.outputs, new_outs)),
remove=[node], remove=[node],
reason="scanOp_make_inplace", reason="scan_make_inplace",
) )
return new_outs[0].owner return new_outs[0].owner
except InconsistencyError: except InconsistencyError:
...@@ -1892,7 +1892,7 @@ class ScanMerge(GlobalOptimizer): ...@@ -1892,7 +1892,7 @@ class ScanMerge(GlobalOptimizer):
if len(subset) > 1: if len(subset) > 1:
proposal = self.merge(subset) proposal = self.merge(subset)
fgraph.replace_all_validate_remove( fgraph.replace_all_validate_remove(
proposal, remove=subset, reason="scanOp_merge" proposal, remove=subset, reason="scan_merge"
) )
...@@ -2341,14 +2341,14 @@ optdb.register("scan_eqopt1", scan_eqopt1, 0.05, "fast_run", "scan") ...@@ -2341,14 +2341,14 @@ optdb.register("scan_eqopt1", scan_eqopt1, 0.05, "fast_run", "scan")
optdb.register("scan_eqopt2", scan_eqopt2, 1.6, "fast_run", "scan") optdb.register("scan_eqopt2", scan_eqopt2, 1.6, "fast_run", "scan")
# ScanSaveMem should execute only once per node. # ScanSaveMem should execute only once per node.
optdb.register( optdb.register(
"scanOp_save_mem", "scan_save_mem",
in2out(save_mem_new_scan, ignore_newtrees=True), in2out(save_mem_new_scan, ignore_newtrees=True),
1.61, 1.61,
"fast_run", "fast_run",
"scan", "scan",
) )
optdb.register( optdb.register(
"scanOp_make_inplace", "scan_make_inplace",
ScanInplaceOptimizer(typeInfer=None), ScanInplaceOptimizer(typeInfer=None),
75, 75,
"fast_run", "fast_run",
...@@ -2360,7 +2360,7 @@ scan_eqopt1.register("all_pushout_opt", scan_seqopt1, 1, "fast_run", "scan") ...@@ -2360,7 +2360,7 @@ scan_eqopt1.register("all_pushout_opt", scan_seqopt1, 1, "fast_run", "scan")
scan_seqopt1.register( scan_seqopt1.register(
"scanOp_remove_constants_and_unused_inputs0", "scan_remove_constants_and_unused_inputs0",
in2out(remove_constants_and_unused_inputs_scan, ignore_newtrees=True), in2out(remove_constants_and_unused_inputs_scan, ignore_newtrees=True),
1, 1,
"remove_constants_and_unused_inputs_scan", "remove_constants_and_unused_inputs_scan",
...@@ -2370,20 +2370,22 @@ scan_seqopt1.register( ...@@ -2370,20 +2370,22 @@ scan_seqopt1.register(
scan_seqopt1.register( scan_seqopt1.register(
"scanOp_pushout_nonseqs_ops", "scan_pushout_nonseqs_ops",
in2out(push_out_non_seq_scan, ignore_newtrees=True), in2out(push_out_non_seq_scan, ignore_newtrees=True),
2, 2,
"fast_run", "fast_run",
"scan", "scan",
"scan_pushout",
) )
scan_seqopt1.register( scan_seqopt1.register(
"scanOp_pushout_seqs_ops", "scan_pushout_seqs_ops",
in2out(push_out_seq_scan, ignore_newtrees=True), in2out(push_out_seq_scan, ignore_newtrees=True),
3, 3,
"fast_run", "fast_run",
"scan", "scan",
"scan_pushout",
) )
...@@ -2394,17 +2396,19 @@ scan_seqopt1.register( ...@@ -2394,17 +2396,19 @@ scan_seqopt1.register(
"fast_run", "fast_run",
"more_mem", "more_mem",
"scan", "scan",
"scan_pushout",
) )
scan_seqopt1.register( scan_seqopt1.register(
"scanOp_pushout_output", "scan_pushout_add",
# TODO: Perhaps this should be an `EquilibriumOptimizer`? # TODO: Perhaps this should be an `EquilibriumOptimizer`?
in2out(push_out_add_scan, ignore_newtrees=False), in2out(push_out_add_scan, ignore_newtrees=False),
5, 5,
"fast_run", "fast_run",
"more_mem", "more_mem",
"scan", "scan",
"scan_pushout",
) )
...@@ -2418,7 +2422,7 @@ scan_eqopt2.register( ...@@ -2418,7 +2422,7 @@ scan_eqopt2.register(
scan_eqopt2.register( scan_eqopt2.register(
"scanOp_remove_constants_and_unused_inputs1", "scan_remove_constants_and_unused_inputs1",
in2out(remove_constants_and_unused_inputs_scan, ignore_newtrees=True), in2out(remove_constants_and_unused_inputs_scan, ignore_newtrees=True),
2, 2,
"remove_constants_and_unused_inputs_scan", "remove_constants_and_unused_inputs_scan",
...@@ -2430,11 +2434,11 @@ scan_eqopt2.register( ...@@ -2430,11 +2434,11 @@ scan_eqopt2.register(
# after const merge but before stabilize so that we can have identity # after const merge but before stabilize so that we can have identity
# for equivalent nodes but we still have the chance to hoist stuff out # for equivalent nodes but we still have the chance to hoist stuff out
# of the scan later. # of the scan later.
scan_eqopt2.register("scanOp_merge", ScanMerge(), 4, "fast_run", "scan") scan_eqopt2.register("scan_merge", ScanMerge(), 4, "fast_run", "scan")
# After Merge optimization # After Merge optimization
scan_eqopt2.register( scan_eqopt2.register(
"scanop_remove_constants_and_unused_inputs2", "scan_remove_constants_and_unused_inputs2",
in2out(remove_constants_and_unused_inputs_scan, ignore_newtrees=True), in2out(remove_constants_and_unused_inputs_scan, ignore_newtrees=True),
5, 5,
"remove_constants_and_unused_inputs_scan", "remove_constants_and_unused_inputs_scan",
...@@ -2443,17 +2447,16 @@ scan_eqopt2.register( ...@@ -2443,17 +2447,16 @@ scan_eqopt2.register(
) )
scan_eqopt2.register( scan_eqopt2.register(
"scanOp_merge_inouts", "scan_merge_inouts",
in2out(scan_merge_inouts, ignore_newtrees=True), in2out(scan_merge_inouts, ignore_newtrees=True),
6, 6,
"scan_merge_inouts",
"fast_run", "fast_run",
"scan", "scan",
) )
# After everything else # After everything else
scan_eqopt2.register( scan_eqopt2.register(
"scanOp_remove_constants_and_unused_inputs3", "scan_remove_constants_and_unused_inputs3",
in2out(remove_constants_and_unused_inputs_scan, ignore_newtrees=True), in2out(remove_constants_and_unused_inputs_scan, ignore_newtrees=True),
8, 8,
"remove_constants_and_unused_inputs_scan", "remove_constants_and_unused_inputs_scan",
......
...@@ -58,7 +58,7 @@ from aesara.link.utils import raise_with_op ...@@ -58,7 +58,7 @@ from aesara.link.utils import raise_with_op
def get_version(): def get_version():
return 0.300 return 0.301
@cython.boundscheck(False) @cython.boundscheck(False)
def perform( def perform(
...@@ -545,11 +545,10 @@ def perform( ...@@ -545,11 +545,10 @@ def perform(
if i == 0: if i == 0:
raise raise
raise ValueError( raise ValueError(
"An output of the scan has changed shape. " "An output of the Scan has changed shape. "
"This may be caused by a pushout optimization." "This may be caused by a push-out optimization."
" Try adding " " Try adding 'optimizer_excluding=scan_pushout'"
"'optimizer_excluding=scanOp_pushout_output' " " to your Aesara flags.")
"to your Aesara flags.")
# 5.6 Copy over the values for outputs corresponding to shared # 5.6 Copy over the values for outputs corresponding to shared
# variables # variables
......
...@@ -21,7 +21,7 @@ if not config.cxx: ...@@ -21,7 +21,7 @@ if not config.cxx:
_logger = logging.getLogger("aesara.scan.scan_perform") _logger = logging.getLogger("aesara.scan.scan_perform")
version = 0.300 # must match constant returned in function get_version() version = 0.301 # must match constant returned in function get_version()
need_reload = False need_reload = False
......
...@@ -159,9 +159,9 @@ Could lower the memory usage, but raise computation time: ...@@ -159,9 +159,9 @@ Could lower the memory usage, but raise computation time:
<aesara.tensor.nnet.batchnorm.batch_normalization>`. It uses less memory <aesara.tensor.nnet.batchnorm.batch_normalization>`. It uses less memory
than building a corresponding Aesara graph. than building a corresponding Aesara graph.
- Disable one or more scan optimizations: - Disable one or more scan optimizations:
- ``optimizer_excluding=scanOp_pushout_seqs_ops`` - ``optimizer_excluding=scan_pushout_seqs_ops``
- ``optimizer_excluding=scan_pushout_dot1`` - ``optimizer_excluding=scan_pushout_dot1``
- ``optimizer_excluding=scanOp_pushout_output`` - ``optimizer_excluding=scan_pushout_add``
- Disable all optimizations tagged as raising memory usage: - Disable all optimizations tagged as raising memory usage:
``optimizer_excluding=more_mem`` (currently only the 3 scan optimizations above) ``optimizer_excluding=more_mem`` (currently only the 3 scan optimizations above)
- `float16 <https://github.com/Theano/Theano/issues/2908>`_. - `float16 <https://github.com/Theano/Theano/issues/2908>`_.
......
...@@ -2523,7 +2523,7 @@ class TestScan: ...@@ -2523,7 +2523,7 @@ class TestScan:
sy, upy = scan(sum, sequences=[y]) sy, upy = scan(sum, sequences=[y])
f = function( f = function(
[x, y], [sx, sy], mode=mode_with_opt.excluding("scanOp_pushout_seqs_ops") [x, y], [sx, sy], mode=mode_with_opt.excluding("scan_pushout_seqs_ops")
) )
topo = f.maker.fgraph.toposort() topo = f.maker.fgraph.toposort()
scans = [n for n in topo if isinstance(n.op, Scan)] scans = [n for n in topo if isinstance(n.op, Scan)]
...@@ -2533,7 +2533,7 @@ class TestScan: ...@@ -2533,7 +2533,7 @@ class TestScan:
sy, upy = scan(sum, sequences=[y], n_steps=3) sy, upy = scan(sum, sequences=[y], n_steps=3)
f = function( f = function(
[x, y], [sx, sy], mode=mode_with_opt.excluding("scanOp_pushout_seqs_ops") [x, y], [sx, sy], mode=mode_with_opt.excluding("scan_pushout_seqs_ops")
) )
topo = f.maker.fgraph.toposort() topo = f.maker.fgraph.toposort()
scans = [n for n in topo if isinstance(n.op, Scan)] scans = [n for n in topo if isinstance(n.op, Scan)]
...@@ -2543,7 +2543,7 @@ class TestScan: ...@@ -2543,7 +2543,7 @@ class TestScan:
sy, upy = scan(sum, sequences=[y], n_steps=4) sy, upy = scan(sum, sequences=[y], n_steps=4)
f = function( f = function(
[x, y], [sx, sy], mode=mode_with_opt.excluding("scanOp_pushout_seqs_ops") [x, y], [sx, sy], mode=mode_with_opt.excluding("scan_pushout_seqs_ops")
) )
topo = f.maker.fgraph.toposort() topo = f.maker.fgraph.toposort()
scans = [n for n in topo if isinstance(n.op, Scan)] scans = [n for n in topo if isinstance(n.op, Scan)]
...@@ -2553,7 +2553,7 @@ class TestScan: ...@@ -2553,7 +2553,7 @@ class TestScan:
sy, upy = scan(sum, sequences=[x]) sy, upy = scan(sum, sequences=[x])
f = function( f = function(
[x], [sx, sy], mode=mode_with_opt.excluding("scanOp_pushout_seqs_ops") [x], [sx, sy], mode=mode_with_opt.excluding("scan_pushout_seqs_ops")
) )
topo = f.maker.fgraph.toposort() topo = f.maker.fgraph.toposort()
scans = [n for n in topo if isinstance(n.op, Scan)] scans = [n for n in topo if isinstance(n.op, Scan)]
...@@ -2563,7 +2563,7 @@ class TestScan: ...@@ -2563,7 +2563,7 @@ class TestScan:
sy, upy = scan(sum, sequences=[x], mode="FAST_COMPILE") sy, upy = scan(sum, sequences=[x], mode="FAST_COMPILE")
f = function( f = function(
[x], [sx, sy], mode=mode_with_opt.excluding("scanOp_pushout_seqs_ops") [x], [sx, sy], mode=mode_with_opt.excluding("scan_pushout_seqs_ops")
) )
topo = f.maker.fgraph.toposort() topo = f.maker.fgraph.toposort()
scans = [n for n in topo if isinstance(n.op, Scan)] scans = [n for n in topo if isinstance(n.op, Scan)]
...@@ -2573,7 +2573,7 @@ class TestScan: ...@@ -2573,7 +2573,7 @@ class TestScan:
sy, upy = scan(sum, sequences=[x], truncate_gradient=1) sy, upy = scan(sum, sequences=[x], truncate_gradient=1)
f = function( f = function(
[x], [sx, sy], mode=mode_with_opt.excluding("scanOp_pushout_seqs_ops") [x], [sx, sy], mode=mode_with_opt.excluding("scan_pushout_seqs_ops")
) )
topo = f.maker.fgraph.toposort() topo = f.maker.fgraph.toposort()
scans = [n for n in topo if isinstance(n.op, Scan)] scans = [n for n in topo if isinstance(n.op, Scan)]
...@@ -2596,7 +2596,7 @@ class TestScan: ...@@ -2596,7 +2596,7 @@ class TestScan:
sz, upz = scan(sum, sequences=[sx], n_steps=4, name="Z") sz, upz = scan(sum, sequences=[sx], n_steps=4, name="Z")
f = function( f = function(
[x, y], [sy, sz], mode=mode_with_opt.excluding("scanOp_pushout_seqs_ops") [x, y], [sy, sz], mode=mode_with_opt.excluding("scan_pushout_seqs_ops")
) )
topo = f.maker.fgraph.toposort() topo = f.maker.fgraph.toposort()
scans = [n for n in topo if isinstance(n.op, Scan)] scans = [n for n in topo if isinstance(n.op, Scan)]
......
...@@ -156,7 +156,7 @@ class TestPushOutScanOutputDot: ...@@ -156,7 +156,7 @@ class TestPushOutScanOutputDot:
opt_mode = mode.including("scan") opt_mode = mode.including("scan")
f_opt = aesara.function([v, m], jacobian(output, v), mode=opt_mode) f_opt = aesara.function([v, m], jacobian(output, v), mode=opt_mode)
no_opt_mode = mode.excluding("scanOp_pushout_output") no_opt_mode = mode.excluding("scan_pushout_add")
f_no_opt = aesara.function([v, m], jacobian(output, v), mode=no_opt_mode) f_no_opt = aesara.function([v, m], jacobian(output, v), mode=no_opt_mode)
# Ensure that the optimization was performed correctly in f_opt # Ensure that the optimization was performed correctly in f_opt
...@@ -198,7 +198,7 @@ class TestPushOutScanOutputDot: ...@@ -198,7 +198,7 @@ class TestPushOutScanOutputDot:
opt_mode = mode.including("scan") opt_mode = mode.including("scan")
f_opt = aesara.function([a, b], outputs, mode=opt_mode) f_opt = aesara.function([a, b], outputs, mode=opt_mode)
no_opt_mode = mode.excluding("scanOp_pushout_output") no_opt_mode = mode.excluding("scan_pushout_add")
f_no_opt = aesara.function([a, b], outputs, mode=no_opt_mode) f_no_opt = aesara.function([a, b], outputs, mode=no_opt_mode)
# Ensure that the optimization was performed correctly in f_opt # Ensure that the optimization was performed correctly in f_opt
...@@ -244,7 +244,7 @@ class TestPushOutScanOutputDot: ...@@ -244,7 +244,7 @@ class TestPushOutScanOutputDot:
opt_mode = mode.including("scan") opt_mode = mode.including("scan")
f_opt = aesara.function([a, b], outputs, mode=opt_mode) f_opt = aesara.function([a, b], outputs, mode=opt_mode)
no_opt_mode = mode.excluding("scanOp_pushout_output") no_opt_mode = mode.excluding("scan_pushout_add")
f_no_opt = aesara.function([a, b], outputs, mode=no_opt_mode) f_no_opt = aesara.function([a, b], outputs, mode=no_opt_mode)
# Ensure that the optimization was performed correctly in f_opt # Ensure that the optimization was performed correctly in f_opt
...@@ -346,7 +346,7 @@ class TestPushOutSumOfDot: ...@@ -346,7 +346,7 @@ class TestPushOutSumOfDot:
grad1 = grad(cost, [U, V, W]) grad1 = grad(cost, [U, V, W])
f_opt = aesara.function(inputs=[x, ri, zi], outputs=grad1, mode=opt_mode) f_opt = aesara.function(inputs=[x, ri, zi], outputs=grad1, mode=opt_mode)
no_opt_mode = mode.excluding("scanOp_pushout_output") no_opt_mode = mode.excluding("scan_pushout_add")
h, _ = aesara.scan( h, _ = aesara.scan(
rnn_step1, rnn_step1,
sequences=[x, ri, zi], sequences=[x, ri, zi],
...@@ -405,7 +405,7 @@ class TestPushOutSumOfDot: ...@@ -405,7 +405,7 @@ class TestPushOutSumOfDot:
output = h[-1] output = h[-1]
f_opt = aesara.function([input1, input2, input3], output, mode=opt_mode) f_opt = aesara.function([input1, input2, input3], output, mode=opt_mode)
no_opt_mode = mode.excluding("scanOp_pushout_output") no_opt_mode = mode.excluding("scan_pushout_add")
h, _ = aesara.scan( h, _ = aesara.scan(
inner_fct, inner_fct,
sequences=[input1, input2, input3], sequences=[input1, input2, input3],
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论