提交 e30d7b90 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Change SequenceDB's position argument to a keyword

上级 aab9e8c8
......@@ -854,7 +854,7 @@ def inline_ofg_expansion(fgraph, node):
optdb.register(
"inline_ofg_expansion",
in2out(inline_ofg_expansion),
-0.01,
"fast_compile",
"fast_run",
position=-0.01,
)
......@@ -190,7 +190,9 @@ class PrintCurrentFunctionGraph(GlobalOptimizer):
optdb = SequenceDB()
optdb.register("merge1", MergeOptimizer(), 0, "fast_run", "fast_compile", "merge")
optdb.register(
"merge1", MergeOptimizer(), "fast_run", "fast_compile", "merge", position=0
)
# After scan1 opt at 0.5 and before ShapeOpt at 1
# This should only remove nodes.
......@@ -201,21 +203,23 @@ local_useless = LocalGroupDB(apply_all_opts=True, profile=True)
optdb.register(
"useless",
TopoDB(local_useless, failure_callback=NavigatorOptimizer.warn_inplace),
0.6,
"fast_run",
"fast_compile",
position=0.6,
)
optdb.register("merge1.1", MergeOptimizer(), 0.65, "fast_run", "fast_compile", "merge")
optdb.register(
"merge1.1", MergeOptimizer(), "fast_run", "fast_compile", "merge", position=0.65
)
# rearranges elemwise expressions
optdb.register(
"canonicalize",
EquilibriumDB(ignore_newtrees=False),
1,
"fast_run",
"fast_compile",
"canonicalize_db",
position=1,
)
# Register the merge opt in the canonicalize Equilibrium as a clean-up opt.
# Without this, as the equilibrium has ignore_newtrees=False, we
......@@ -228,41 +232,47 @@ optdb["canonicalize"].register(
"merge", MergeOptimizer(), "fast_run", "fast_compile", cleanup=True
)
optdb.register("merge1.2", MergeOptimizer(), 1.2, "fast_run", "fast_compile", "merge")
optdb.register(
"merge1.2", MergeOptimizer(), "fast_run", "fast_compile", "merge", position=1.2
)
optdb.register(
"Print1.21",
PrintCurrentFunctionGraph("Post-canonicalize"),
1.21,
position=1.21,
) # 'fast_run', 'fast_compile')
# replace unstable subgraphs
optdb.register("stabilize", EquilibriumDB(), 1.5, "fast_run")
optdb.register("stabilize", EquilibriumDB(), "fast_run", position=1.5)
optdb.register(
"Print1.51",
PrintCurrentFunctionGraph("Post-stabilize"),
1.51,
position=1.51,
) # 'fast_run', 'fast_compile')
# misc special cases for speed
optdb.register("specialize", EquilibriumDB(), 2, "fast_run", "fast_compile_gpu")
optdb.register(
"specialize", EquilibriumDB(), "fast_run", "fast_compile_gpu", position=2
)
# misc special cases for speed that break canonicalization
optdb.register("uncanonicalize", EquilibriumDB(), 3, "fast_run")
optdb.register("uncanonicalize", EquilibriumDB(), "fast_run", position=3)
# misc special cases for speed that are dependent on the device.
optdb.register(
"specialize_device", EquilibriumDB(), 48.6, "fast_compile", "fast_run"
"specialize_device", EquilibriumDB(), "fast_compile", "fast_run", position=48.6
) # must be after gpu stuff at 48.5
# especially constant merge
optdb.register("merge2", MergeOptimizer(), 49, "fast_run", "merge")
optdb.register("merge2", MergeOptimizer(), "fast_run", "merge", position=49)
optdb.register("add_destroy_handler", AddDestroyHandler(), 49.5, "fast_run", "inplace")
optdb.register(
"add_destroy_handler", AddDestroyHandler(), "fast_run", "inplace", position=49.5
)
# final pass just to make sure
optdb.register("merge3", MergeOptimizer(), 100, "fast_run", "merge")
optdb.register("merge3", MergeOptimizer(), "fast_run", "merge", position=100)
_tags: Union[Tuple[str, str], Tuple]
......@@ -272,7 +282,7 @@ if config.check_stack_trace in ("raise", "warn", "log"):
if config.check_stack_trace == "off":
_tags = ()
optdb.register("CheckStackTrace", CheckStackTraceOptimization(), -1, *_tags)
optdb.register("CheckStackTrace", CheckStackTraceOptimization(), *_tags, position=-1)
del _tags
......
......@@ -2090,5 +2090,10 @@ gpuablas_opt_inplace = in2out(
)
optdb.register(
"InplaceGpuaBlasOpt", gpuablas_opt_inplace, 70.0, "fast_run", "inplace", "gpuarray"
"InplaceGpuaBlasOpt",
gpuablas_opt_inplace,
"fast_run",
"inplace",
"gpuarray",
position=70.0,
)
......@@ -432,11 +432,11 @@ optdb.register(
local_dnn_convgi_inplace,
name="local_dnna_conv_inplace",
),
70.0,
"fast_run",
"inplace",
"gpuarray",
"cudnn",
position=70.0,
)
......@@ -837,7 +837,7 @@ class NoCuDNNRaise(GlobalOptimizer):
)
gpu_seqopt.register("NoCuDNNRaise", NoCuDNNRaise(), 0, "cudnn")
gpu_seqopt.register("NoCuDNNRaise", NoCuDNNRaise(), "cudnn", position=0)
@register_inplace()
......
......@@ -204,23 +204,28 @@ _logger = logging.getLogger("aesara.gpuarray.opt")
gpu_seqopt.register(
"gpuarray_graph_optimization",
GraphToGPUDB(),
-0.5,
"fast_compile",
"fast_run",
"gpuarray",
position=-0.5,
)
gpu_seqopt.register(
"gpuarray_local_optimizations",
gpu_optimizer,
1,
"fast_compile",
"fast_run",
"gpuarray",
"gpuarray_local_optimiziations",
position=1,
)
gpu_seqopt.register(
"gpuarray_cut_transfers", gpu_cut_copies, 2, "fast_compile", "fast_run", "gpuarray"
"gpuarray_cut_transfers",
gpu_cut_copies,
"fast_compile",
"fast_run",
"gpuarray",
position=2,
)
register_opt("fast_compile")(aesara.tensor.basic_opt.local_track_shape_i)
......@@ -280,10 +285,10 @@ class InputToGpuOptimizer(GlobalOptimizer):
gpu_seqopt.register(
"InputToGpuArrayOptimizer",
InputToGpuOptimizer(),
0,
"fast_run",
"fast_compile",
"merge",
position=0,
)
......@@ -702,8 +707,8 @@ optdb.register(
"local_gpua_alloc_empty_to_zeros",
aesara.graph.opt.in2out(local_gpua_alloc_empty_to_zeros),
# After move to gpu and merge2, before inplace.
49.3,
"alloc_empty_to_zeros",
position=49.3,
)
......@@ -866,27 +871,27 @@ gpu_local_elemwise_fusion = aesara.tensor.basic_opt.local_elemwise_fusion_op(
)
optdb.register(
"gpua_elemwise_fusion",
# 48.5 move to gpu
# 48.6 specialize
# 49 cpu fusion
# 49.5 add destroy handler
aesara.tensor.basic_opt.FusionOptimizer(gpu_local_elemwise_fusion),
49,
"fast_run",
"fusion",
"local_elemwise_fusion",
"gpuarray",
# 48.5 move to gpu
# 48.6 specialize
# 49 cpu fusion
# 49.5 add destroy handler
position=49,
)
inplace_gpu_elemwise_opt = aesara.tensor.basic_opt.InplaceElemwiseOptimizer(GpuElemwise)
optdb.register(
"gpua_inplace_opt",
inplace_gpu_elemwise_opt,
75,
"inplace_elemwise_optimizer",
"fast_run",
"inplace",
"gpuarray",
position=75,
)
register_opt(aesara.tensor.basic_opt.local_useless_elemwise)
......@@ -2608,7 +2613,9 @@ assert_no_cpu_op = aesara.graph.opt.in2out(
local_assert_no_cpu_op, name="assert_no_cpu_op"
)
# 49.2 is after device specialization & fusion optimizations for last transfers
optdb.register("gpua_assert_no_cpu_op", assert_no_cpu_op, 49.2, "assert_no_cpu_op")
optdb.register(
"gpua_assert_no_cpu_op", assert_no_cpu_op, "assert_no_cpu_op", position=49.2
)
def tensor_to_gpu(x, context_name):
......@@ -2961,97 +2968,97 @@ def local_gpu_ctc(fgraph, op, context_name, inputs, outputs):
optdb.register(
"gpua_scanOp_make_inplace",
ScanInplaceOptimizer(typeInfer=_scan_type_infer, gpua_flag=True),
75,
"gpuarray",
"inplace",
"scan",
position=75,
)
abstractconv_groupopt.register(
"local_abstractconv_dnn",
local_abstractconv_cudnn,
20,
"conv_dnn",
"gpuarray",
"fast_compile",
"fast_run",
"cudnn",
position=20,
)
abstractconv_groupopt.register(
"local_abstractconv_gw_dnn",
local_abstractconv_gw_cudnn,
20,
"conv_dnn",
"gpuarray",
"fast_compile",
"fast_run",
"cudnn",
position=20,
)
abstractconv_groupopt.register(
"local_abstractconv_gi_dnn",
local_abstractconv_gi_cudnn,
20,
"conv_dnn",
"gpuarray",
"fast_compile",
"fast_run",
"cudnn",
position=20,
)
# The GEMM-based convolution comes last to catch all remaining cases.
# It can be disabled by excluding 'conv_gemm'.
abstractconv_groupopt.register(
"local_abstractconv_gemm",
local_abstractconv_gemm,
30,
"conv_gemm",
"gpuarray",
"fast_compile",
"fast_run",
position=30,
)
abstractconv_groupopt.register(
"local_abstractconv3d_gemm",
local_abstractconv3d_gemm,
30,
"conv_gemm",
"gpuarray",
"fast_compile",
"fast_run",
position=30,
)
abstractconv_groupopt.register(
"local_abstractconv_gradweights_gemm",
local_abstractconv_gradweights_gemm,
30,
"conv_gemm",
"gpuarray",
"fast_compile",
"fast_run",
position=30,
)
abstractconv_groupopt.register(
"local_abstractconv3d_gradweights_gemm",
local_abstractconv3d_gradweights_gemm,
30,
"conv_gemm",
"gpuarray",
"fast_compile",
"fast_run",
position=30,
)
abstractconv_groupopt.register(
"local_abstractconv_gradinputs",
local_abstractconv_gradinputs_gemm,
30,
"conv_gemm",
"gpuarray",
"fast_compile",
"fast_run",
position=30,
)
abstractconv_groupopt.register(
"local_abstractconv3d_gradinputs",
local_abstractconv3d_gradinputs_gemm,
30,
"conv_gemm",
"gpuarray",
"fast_compile",
"fast_run",
position=30,
)
conv_metaopt = ConvMetaOptimizer()
......
......@@ -60,8 +60,8 @@ gpu_seqopt = SequenceDB()
optdb.register(
"gpuarray_opt",
gpu_seqopt,
optdb.__position__.get("add_destroy_handler", 49.5) - 1,
"gpuarray",
position=optdb.__position__.get("add_destroy_handler", 49.5) - 1,
)
......@@ -123,11 +123,11 @@ def register_inplace(*tags, **kwargs):
optdb.register(
name,
TopoOptimizer(local_opt, failure_callback=TopoOptimizer.warn_inplace),
60,
"fast_run",
"inplace",
"gpuarray",
*tags,
position=60,
)
return local_opt
......
......@@ -34,6 +34,7 @@ class OptimizationDatabase:
optimizer: Union["OptimizationDatabase", OptimizersType],
*tags: str,
use_db_name_as_tag=True,
**kwargs,
):
"""Register a new optimizer to the database.
......@@ -339,10 +340,10 @@ class EquilibriumDB(OptimizationDatabase):
self.__final__ = {}
self.__cleanup__ = {}
def register(self, name, obj, *tags, final_opt=False, cleanup=False, **kwtags):
def register(self, name, obj, *tags, final_opt=False, cleanup=False, **kwargs):
if final_opt and cleanup:
raise ValueError("`final_opt` and `cleanup` cannot both be true.")
super().register(name, obj, *tags, **kwtags)
super().register(name, obj, *tags, **kwargs)
self.__final__[name] = final_opt
self.__cleanup__[name] = cleanup
......@@ -387,7 +388,9 @@ class SequenceDB(OptimizationDatabase):
self.__position__ = {}
self.failure_callback = failure_callback
def register(self, name, obj, position: Union[str, int, float], *tags, **kwargs):
def register(
self, name, obj, *tags, position: Union[str, int, float] = "last", **kwargs
):
super().register(name, obj, *tags, **kwargs)
if position == "last":
if len(self.__position__) == 0:
......@@ -493,7 +496,7 @@ class LocalGroupDB(SequenceDB):
self.__name__: str = ""
def register(self, name, obj, *tags, position="last", **kwargs):
super().register(name, obj, position, *tags, **kwargs)
super().register(name, obj, *tags, position=position, **kwargs)
def query(self, *tags, **kwtags):
opts = list(super().query(*tags, **kwtags))
......
......@@ -420,9 +420,9 @@ def cond_make_inplace(fgraph, node):
optdb.register(
"cond_make_inplace",
in2out(cond_make_inplace, ignore_newtrees=True),
95,
"fast_run",
"inplace",
position=95,
)
# XXX: Optimizations commented pending further debugging (certain optimizations
......@@ -456,8 +456,8 @@ where, each of the optimization do the following things:
`ifelse_lift` (def cond_lift_single_if):
"""
# optdb.register('ifelse_equilibriumOpt', ifelse_equilibrium, .5, 'fast_run',
# 'ifelse')
# optdb.register('ifelse_equilibriumOpt', ifelse_equilibrium, 'fast_run',
# 'ifelse', position=.5)
acceptable_ops = (
......@@ -768,26 +768,26 @@ def cond_merge_random_op(fgraph, main_node):
#
# ifelse_seqopt.register('ifelse_condPushOut_equilibrium',
# pushout_equilibrium,
# 1, 'fast_run', 'ifelse')
# 'fast_run', 'ifelse', position=1)
#
# ifelse_seqopt.register('merge_nodes_1',
# graph.opt.MergeOptimizer(skip_const_merge=False),
# 2, 'fast_run', 'ifelse')
# 'fast_run', 'ifelse', position=2)
#
#
# ifelse_seqopt.register('ifelse_sameCondTrue',
# in2out(cond_merge_ifs_true,
# ignore_newtrees=True),
# 3, 'fast_run', 'ifelse')
# 'fast_run', 'ifelse', position=3)
#
#
# ifelse_seqopt.register('ifelse_sameCondFalse',
# in2out(cond_merge_ifs_false,
# ignore_newtrees=True),
# 4, 'fast_run', 'ifelse')
# 'fast_run', 'ifelse', position=4)
#
#
# ifelse_seqopt.register('ifelse_removeIdenetical',
# in2out(cond_remove_identical,
# ignore_newtrees=True),
# 7, 'fast_run', 'ifelse')
# 'fast_run', 'ifelse', position=7)
......@@ -1358,7 +1358,7 @@ def mrg_random_make_inplace(fgraph, node):
optdb.register(
"random_make_inplace_mrg",
in2out(mrg_random_make_inplace, ignore_newtrees=True),
99,
"fast_run",
"inplace",
position=99,
)
......@@ -2336,68 +2336,68 @@ scan_eqopt2 = EquilibriumDB()
# scan_eqopt1 before ShapeOpt at 0.1
# This is needed so that ShapeFeature doesn't track old Scan nodes that we
# don't want to reintroduce.
optdb.register("scan_eqopt1", scan_eqopt1, 0.05, "fast_run", "scan")
optdb.register("scan_eqopt1", scan_eqopt1, "fast_run", "scan", position=0.05)
# We run before blas opt at 1.7 and specialize 2.0
# but after stabilize at 1.5. Should we put it before stabilize?
optdb.register("scan_eqopt2", scan_eqopt2, 1.6, "fast_run", "scan")
optdb.register("scan_eqopt2", scan_eqopt2, "fast_run", "scan", position=1.6)
# ScanSaveMem should execute only once per node.
optdb.register(
"scan_save_mem",
in2out(save_mem_new_scan, ignore_newtrees=True),
1.61,
"fast_run",
"scan",
position=1.61,
)
optdb.register(
"scan_make_inplace",
ScanInplaceOptimizer(typeInfer=None),
75,
"fast_run",
"inplace",
"scan",
position=75,
)
scan_eqopt1.register("all_pushout_opt", scan_seqopt1, 1, "fast_run", "scan")
scan_eqopt1.register("all_pushout_opt", scan_seqopt1, "fast_run", "scan", position=1)
scan_seqopt1.register(
"scan_remove_constants_and_unused_inputs0",
in2out(remove_constants_and_unused_inputs_scan, ignore_newtrees=True),
1,
"remove_constants_and_unused_inputs_scan",
"fast_run",
"scan",
position=1,
)
scan_seqopt1.register(
"scan_pushout_nonseqs_ops",
in2out(push_out_non_seq_scan, ignore_newtrees=True),
2,
"fast_run",
"scan",
"scan_pushout",
position=2,
)
scan_seqopt1.register(
"scan_pushout_seqs_ops",
in2out(push_out_seq_scan, ignore_newtrees=True),
3,
"fast_run",
"scan",
"scan_pushout",
position=3,
)
scan_seqopt1.register(
"scan_pushout_dot1",
in2out(push_out_dot1_scan, ignore_newtrees=True),
4,
"fast_run",
"more_mem",
"scan",
"scan_pushout",
position=4,
)
......@@ -2405,62 +2405,62 @@ scan_seqopt1.register(
"scan_pushout_add",
# TODO: Perhaps this should be an `EquilibriumOptimizer`?
in2out(push_out_add_scan, ignore_newtrees=False),
5,
"fast_run",
"more_mem",
"scan",
"scan_pushout",
position=5,
)
scan_eqopt2.register(
"constant_folding_for_scan2",
in2out(basic_opt.constant_folding, ignore_newtrees=True),
1,
"fast_run",
"scan",
position=1,
)
scan_eqopt2.register(
"scan_remove_constants_and_unused_inputs1",
in2out(remove_constants_and_unused_inputs_scan, ignore_newtrees=True),
2,
"remove_constants_and_unused_inputs_scan",
"fast_run",
"scan",
position=2,
)
# after const merge but before stabilize so that we can have identity
# for equivalent nodes but we still have the chance to hoist stuff out
# of the scan later.
scan_eqopt2.register("scan_merge", ScanMerge(), 4, "fast_run", "scan")
scan_eqopt2.register("scan_merge", ScanMerge(), "fast_run", "scan", position=4)
# After Merge optimization
scan_eqopt2.register(
"scan_remove_constants_and_unused_inputs2",
in2out(remove_constants_and_unused_inputs_scan, ignore_newtrees=True),
5,
"remove_constants_and_unused_inputs_scan",
"fast_run",
"scan",
position=5,
)
scan_eqopt2.register(
"scan_merge_inouts",
in2out(scan_merge_inouts, ignore_newtrees=True),
6,
"fast_run",
"scan",
position=6,
)
# After everything else
scan_eqopt2.register(
"scan_remove_constants_and_unused_inputs3",
in2out(remove_constants_and_unused_inputs_scan, ignore_newtrees=True),
8,
"remove_constants_and_unused_inputs_scan",
"fast_run",
"scan",
position=8,
)
......@@ -75,9 +75,9 @@ def local_inplace_remove0(fgraph, node):
aesara.compile.optdb.register(
"local_inplace_remove0",
TopoOptimizer(local_inplace_remove0, failure_callback=TopoOptimizer.warn_inplace),
60,
"fast_run",
"inplace",
position=60,
)
......@@ -216,9 +216,9 @@ aesara.compile.optdb.register(
TopoOptimizer(
local_inplace_addsd_ccode, failure_callback=TopoOptimizer.warn_inplace
),
60,
"fast_run",
"inplace",
position=60,
)
......@@ -248,8 +248,8 @@ aesara.compile.optdb.register(
"local_addsd_ccode",
TopoOptimizer(local_addsd_ccode),
# Must be after local_inplace_addsd_ccode at 60
61,
"fast_run",
position=61,
)
......
......@@ -474,11 +474,11 @@ inplace_elemwise_optimizer = InplaceElemwiseOptimizer(Elemwise)
compile.optdb.register(
"inplace_elemwise_opt",
inplace_elemwise_optimizer,
75,
"inplace_opt", # for historic reason
"inplace_elemwise_optimizer",
"fast_run",
"inplace",
position=75,
)
......@@ -493,7 +493,7 @@ def register_useless(lopt, *tags, **kwargs):
name = kwargs.pop("name", None) or lopt.__name__
compile.mode.local_useless.register(
name, lopt, "last", "fast_run", *tags, **kwargs
name, lopt, "fast_run", *tags, position="last", **kwargs
)
return lopt
......@@ -1475,12 +1475,12 @@ class UnShapeOptimizer(GlobalOptimizer):
# Register it after merge1 optimization at 0. We don't want to track
# the shape of merged nodes.
aesara.compile.mode.optdb.register(
"ShapeOpt", ShapeOptimizer(), 0.1, "fast_run", "fast_compile"
"ShapeOpt", ShapeOptimizer(), "fast_run", "fast_compile", position=0.1
)
# Not enabled by default for now. Some crossentropy opt use the
# shape_feature. They are at step 2.01. uncanonicalize is at step
# 3. After it goes to 48.5 that move to the gpu. So 10 seems reasonable.
aesara.compile.mode.optdb.register("UnShapeOpt", UnShapeOptimizer(), 10)
aesara.compile.mode.optdb.register("UnShapeOpt", UnShapeOptimizer(), position=10)
@register_specialize("local_alloc_elemwise")
......@@ -1741,11 +1741,11 @@ def local_fill_to_alloc(fgraph, node):
# Register this after stabilize at 1.5 to make sure stabilize don't
# get affected by less canonicalized graph due to alloc.
compile.optdb.register(
"local_fill_to_alloc", in2out(local_fill_to_alloc), 1.51, "fast_run"
"local_fill_to_alloc", in2out(local_fill_to_alloc), "fast_run", position=1.51
)
# Needed to clean some extra alloc added by local_fill_to_alloc
compile.optdb.register(
"local_elemwise_alloc", in2out(local_elemwise_alloc), 1.52, "fast_run"
"local_elemwise_alloc", in2out(local_elemwise_alloc), "fast_run", position=1.52
)
......@@ -1856,8 +1856,8 @@ compile.optdb.register(
"local_alloc_empty_to_zeros",
in2out(local_alloc_empty_to_zeros),
# After move to gpu and merge2, before inplace.
49.3,
"alloc_empty_to_zeros",
position=49.3,
)
......@@ -3369,28 +3369,28 @@ if config.tensor__local_elemwise_fusion:
fuse_seqopt.register(
"composite_elemwise_fusion",
FusionOptimizer(local_elemwise_fusion),
1,
"fast_run",
"fusion",
position=1,
)
compile.optdb.register(
"elemwise_fusion",
fuse_seqopt,
49,
"fast_run",
"fusion",
"local_elemwise_fusion",
"FusionOptimizer",
position=49,
)
else:
_logger.debug("not enabling optimization fusion elemwise in fast_run")
compile.optdb.register(
"elemwise_fusion",
FusionOptimizer(local_elemwise_fusion),
49,
"fusion",
"local_elemwise_fusion",
"FusionOptimizer",
position=49,
)
......
......@@ -1798,15 +1798,19 @@ def local_dot22_to_ger_or_gemv(fgraph, node):
blas_optdb = SequenceDB()
# run after numerical stability optimizations (1.5)
optdb.register("BlasOpt", blas_optdb, 1.7, "fast_run", "fast_compile")
optdb.register("BlasOpt", blas_optdb, "fast_run", "fast_compile", position=1.7)
# run before specialize (2.0) because specialize is basically a
# free-for-all that makes the graph crazy.
# fast_compile is needed to have GpuDot22 created.
blas_optdb.register(
"local_dot_to_dot22", in2out(local_dot_to_dot22), 0, "fast_run", "fast_compile"
"local_dot_to_dot22",
in2out(local_dot_to_dot22),
"fast_run",
"fast_compile",
position=0,
)
blas_optdb.register("gemm_optimizer", GemmOptimizer(), 10, "fast_run")
blas_optdb.register("gemm_optimizer", GemmOptimizer(), "fast_run", position=10)
blas_optdb.register(
"local_gemm_to_gemv",
EquilibriumOptimizer(
......@@ -1819,8 +1823,8 @@ blas_optdb.register(
max_use_ratio=5,
ignore_newtrees=False,
),
15,
"fast_run",
position=15,
)
......@@ -1830,7 +1834,12 @@ blas_opt_inplace = in2out(
local_inplace_gemm, local_inplace_gemv, local_inplace_ger, name="blas_opt_inplace"
)
optdb.register(
"InplaceBlasOpt", blas_opt_inplace, 70.0, "fast_run", "inplace", "blas_opt_inplace"
"InplaceBlasOpt",
blas_opt_inplace,
"fast_run",
"inplace",
"blas_opt_inplace",
position=70.0,
)
......@@ -2048,7 +2057,10 @@ def local_dot22_to_dot22scalar(fgraph, node):
# must happen after gemm as the gemm optimizer doesn't understand
# dot22scalar, and gemm gives more speed-up than dot22scalar
blas_optdb.register(
"local_dot22_to_dot22scalar", in2out(local_dot22_to_dot22scalar), 11, "fast_run"
"local_dot22_to_dot22scalar",
in2out(local_dot22_to_dot22scalar),
"fast_run",
position=11,
)
......
......@@ -730,15 +730,15 @@ def make_c_gemv_destructive(fgraph, node):
# ##### ####### #######
blas_optdb.register(
"use_c_blas", in2out(use_c_ger, use_c_gemv), 20, "fast_run", "c_blas"
"use_c_blas", in2out(use_c_ger, use_c_gemv), "fast_run", "c_blas", position=20
)
# this matches the InplaceBlasOpt defined in blas.py
optdb.register(
"c_blas_destructive",
in2out(make_c_ger_destructive, make_c_gemv_destructive, name="c_blas_destructive"),
70.0,
"fast_run",
"inplace",
"c_blas",
position=70.0,
)
......@@ -79,13 +79,13 @@ if have_fblas:
# C implementations should be scheduled earlier than this, so that they take
# precedence. Once the original Ger is replaced, then these optimizations
# have no effect.
blas_optdb.register("scipy_blas", use_scipy_blas, 100, "fast_run")
blas_optdb.register("scipy_blas", use_scipy_blas, "fast_run", position=100)
# this matches the InplaceBlasOpt defined in blas.py
optdb.register(
"make_scipy_blas_destructive",
make_scipy_blas_destructive,
70.0,
"fast_run",
"inplace",
position=70.0,
)
......@@ -2870,9 +2870,9 @@ def local_add_mul_fusion(fgraph, node):
fuse_seqopt.register(
"local_add_mul_fusion",
FusionOptimizer(local_add_mul_fusion),
0,
"fast_run",
"fusion",
position=0,
)
......
......@@ -1949,10 +1949,10 @@ def crossentropy_to_crossentropy_with_softmax(fgraph):
optdb.register(
"crossentropy_to_crossentropy_with_softmax",
crossentropy_to_crossentropy_with_softmax,
2.01,
"fast_run",
"xent",
"fast_compile_gpu",
position=2.01,
)
......
......@@ -912,21 +912,21 @@ register_specialize_device(bn_groupopt, "fast_compile", "fast_run")
bn_groupopt.register(
"local_abstract_batch_norm_train",
local_abstract_batch_norm_train,
30,
"fast_compile",
"fast_run",
position=30,
)
bn_groupopt.register(
"local_abstract_batch_norm_train_grad",
local_abstract_batch_norm_train_grad,
30,
"fast_compile",
"fast_run",
position=30,
)
bn_groupopt.register(
"local_abstract_batch_norm_inference",
local_abstract_batch_norm_inference,
30,
"fast_compile",
"fast_run",
position=30,
)
......@@ -320,7 +320,7 @@ aesara.compile.optdb.register(
TopoOptimizer(
local_inplace_DiagonalSubtensor, failure_callback=TopoOptimizer.warn_inplace
),
60,
"fast_run",
"inplace",
position=60,
)
......@@ -54,9 +54,9 @@ compile.optdb.register(
TopoOptimizer(
local_inplace_sparse_block_gemv, failure_callback=TopoOptimizer.warn_inplace
),
60,
"fast_run",
"inplace",
position=60,
) # DEBUG
......@@ -78,9 +78,9 @@ compile.optdb.register(
local_inplace_sparse_block_outer,
failure_callback=TopoOptimizer.warn_inplace,
),
60,
"fast_run",
"inplace",
position=60,
) # DEBUG
......@@ -500,69 +500,69 @@ register_specialize_device(conv_groupopt, "fast_compile", "fast_run")
conv_groupopt.register(
"local_abstractconv_gemm",
local_abstractconv_gemm,
30,
"conv_gemm",
"fast_compile",
"fast_run",
position=30,
)
conv_groupopt.register(
"local_abstractconv_gradweight_gemm",
local_abstractconv_gradweight_gemm,
30,
"conv_gemm",
"fast_compile",
"fast_run",
position=30,
)
conv_groupopt.register(
"local_abstractconv_gradinputs_gemm",
local_abstractconv_gradinputs_gemm,
30,
"conv_gemm",
"fast_compile",
"fast_run",
position=30,
)
conv_groupopt.register(
"local_abstractconv3d_gemm",
local_abstractconv3d_gemm,
30,
"conv_gemm",
"fast_compile",
"fast_run",
position=30,
)
conv_groupopt.register(
"local_abstractconv3d_gradweight_gemm",
local_abstractconv3d_gradweight_gemm,
30,
"conv_gemm",
"fast_compile",
"fast_run",
position=30,
)
conv_groupopt.register(
"local_abstractconv3d_gradinputs_gemm",
local_abstractconv3d_gradinputs_gemm,
30,
"conv_gemm",
"fast_compile",
"fast_run",
position=30,
)
# Legacy convolution
conv_groupopt.register(
"local_conv2d_cpu", local_conv2d_cpu, 40, "fast_compile", "fast_run"
"local_conv2d_cpu", local_conv2d_cpu, "fast_compile", "fast_run", position=40
)
conv_groupopt.register(
"local_conv2d_gradweight_cpu",
local_conv2d_gradweight_cpu,
40,
"fast_compile",
"fast_run",
position=40,
)
conv_groupopt.register(
"local_conv2d_gradinputs_cpu",
local_conv2d_gradinputs_cpu,
40,
"fast_compile",
"fast_run",
position=40,
)
......@@ -602,7 +602,7 @@ def local_abstractconv_check(fgraph, node):
optdb.register(
"AbstractConvCheck",
in2out(local_abstractconv_check, name="AbstractConvCheck"),
48.7,
"fast_compile",
"fast_run",
position=48.7,
)
......@@ -55,9 +55,9 @@ def random_make_inplace(fgraph, node):
optdb.register(
"random_make_inplace",
in2out(random_make_inplace, ignore_newtrees=True),
99,
"fast_run",
"inplace",
position=99,
)
......
......@@ -89,7 +89,7 @@ def register_useless(lopt, *tags, **kwargs):
name = kwargs.pop("name", None) or lopt.__name__
compile.mode.local_useless.register(
name, lopt, "last", "fast_run", *tags, **kwargs
name, lopt, "fast_run", *tags, position="last", **kwargs
)
return lopt
......@@ -1230,9 +1230,9 @@ def local_IncSubtensor_serialize(fgraph, node):
compile.optdb.register(
"pre_local_IncSubtensor_serialize",
in2out(local_IncSubtensor_serialize),
# Just before canonizer
0.99,
"fast_run",
# Just before canonizer
position=0.99,
)
......@@ -1267,9 +1267,9 @@ compile.optdb.register(
TopoOptimizer(
local_inplace_setsubtensor, failure_callback=TopoOptimizer.warn_inplace
),
60,
"fast_run",
"inplace",
position=60,
)
......@@ -1288,9 +1288,9 @@ compile.optdb.register(
TopoOptimizer(
local_inplace_AdvancedIncSubtensor1, failure_callback=TopoOptimizer.warn_inplace
),
60,
"fast_run",
"inplace",
position=60,
)
......@@ -1313,9 +1313,9 @@ compile.optdb.register(
TopoOptimizer(
local_inplace_AdvancedIncSubtensor, failure_callback=TopoOptimizer.warn_inplace
),
60,
"fast_run",
"inplace",
position=60,
)
......
......@@ -19,7 +19,7 @@ def typed_list_inplace_opt(fgraph, node):
optdb.register(
"typed_list_inplace_opt",
TopoOptimizer(typed_list_inplace_opt, failure_callback=TopoOptimizer.warn_inplace),
60,
"fast_run",
"inplace",
position=60,
)
......@@ -687,7 +687,7 @@ it to :obj:`optdb` as follows:
.. testcode::
# optdb.register(name, optimizer, order, *tags)
optdb.register('simplify', simplify, 0.5, 'fast_run')
optdb.register('simplify', simplify, 'fast_run', position=0.5)
Once this is done, the ``FAST_RUN`` mode will automatically include your
optimization (since you gave it the ``'fast_run'`` tag). Of course,
......
......@@ -56,7 +56,7 @@ class TestDB:
assert isinstance(res, opt.SeqOptimizer)
assert res.data == []
seq_db.register("b", TestOpt(), 1)
seq_db.register("b", TestOpt(), position=1)
from io import StringIO
......@@ -69,7 +69,7 @@ class TestDB:
assert "names {'b'}" in res
with pytest.raises(TypeError, match=r"`position` must be.*"):
seq_db.register("c", TestOpt(), object())
seq_db.register("c", TestOpt(), position=object())
def test_LocalGroupDB(self):
lg_db = LocalGroupDB()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论