提交 e30d7b90 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Change SequenceDB's position argument to a keyword

上级 aab9e8c8
...@@ -854,7 +854,7 @@ def inline_ofg_expansion(fgraph, node): ...@@ -854,7 +854,7 @@ def inline_ofg_expansion(fgraph, node):
optdb.register( optdb.register(
"inline_ofg_expansion", "inline_ofg_expansion",
in2out(inline_ofg_expansion), in2out(inline_ofg_expansion),
-0.01,
"fast_compile", "fast_compile",
"fast_run", "fast_run",
position=-0.01,
) )
...@@ -190,7 +190,9 @@ class PrintCurrentFunctionGraph(GlobalOptimizer): ...@@ -190,7 +190,9 @@ class PrintCurrentFunctionGraph(GlobalOptimizer):
optdb = SequenceDB() optdb = SequenceDB()
optdb.register("merge1", MergeOptimizer(), 0, "fast_run", "fast_compile", "merge") optdb.register(
"merge1", MergeOptimizer(), "fast_run", "fast_compile", "merge", position=0
)
# After scan1 opt at 0.5 and before ShapeOpt at 1 # After scan1 opt at 0.5 and before ShapeOpt at 1
# This should only remove nodes. # This should only remove nodes.
...@@ -201,21 +203,23 @@ local_useless = LocalGroupDB(apply_all_opts=True, profile=True) ...@@ -201,21 +203,23 @@ local_useless = LocalGroupDB(apply_all_opts=True, profile=True)
optdb.register( optdb.register(
"useless", "useless",
TopoDB(local_useless, failure_callback=NavigatorOptimizer.warn_inplace), TopoDB(local_useless, failure_callback=NavigatorOptimizer.warn_inplace),
0.6,
"fast_run", "fast_run",
"fast_compile", "fast_compile",
position=0.6,
) )
optdb.register("merge1.1", MergeOptimizer(), 0.65, "fast_run", "fast_compile", "merge") optdb.register(
"merge1.1", MergeOptimizer(), "fast_run", "fast_compile", "merge", position=0.65
)
# rearranges elemwise expressions # rearranges elemwise expressions
optdb.register( optdb.register(
"canonicalize", "canonicalize",
EquilibriumDB(ignore_newtrees=False), EquilibriumDB(ignore_newtrees=False),
1,
"fast_run", "fast_run",
"fast_compile", "fast_compile",
"canonicalize_db", "canonicalize_db",
position=1,
) )
# Register in the canonizer Equilibrium as a clean up opt the merge opt. # Register in the canonizer Equilibrium as a clean up opt the merge opt.
# Without this, as the equilibrium have ignore_newtrees=False, we # Without this, as the equilibrium have ignore_newtrees=False, we
...@@ -228,41 +232,47 @@ optdb["canonicalize"].register( ...@@ -228,41 +232,47 @@ optdb["canonicalize"].register(
"merge", MergeOptimizer(), "fast_run", "fast_compile", cleanup=True "merge", MergeOptimizer(), "fast_run", "fast_compile", cleanup=True
) )
optdb.register("merge1.2", MergeOptimizer(), 1.2, "fast_run", "fast_compile", "merge") optdb.register(
"merge1.2", MergeOptimizer(), "fast_run", "fast_compile", "merge", position=1.2
)
optdb.register( optdb.register(
"Print1.21", "Print1.21",
PrintCurrentFunctionGraph("Post-canonicalize"), PrintCurrentFunctionGraph("Post-canonicalize"),
1.21, position=1.21,
) # 'fast_run', 'fast_compile') ) # 'fast_run', 'fast_compile')
# replace unstable subgraphs # replace unstable subgraphs
optdb.register("stabilize", EquilibriumDB(), 1.5, "fast_run") optdb.register("stabilize", EquilibriumDB(), "fast_run", position=1.5)
optdb.register( optdb.register(
"Print1.51", "Print1.51",
PrintCurrentFunctionGraph("Post-stabilize"), PrintCurrentFunctionGraph("Post-stabilize"),
1.51, position=1.51,
) # 'fast_run', 'fast_compile') ) # 'fast_run', 'fast_compile')
# misc special cases for speed # misc special cases for speed
optdb.register("specialize", EquilibriumDB(), 2, "fast_run", "fast_compile_gpu") optdb.register(
"specialize", EquilibriumDB(), "fast_run", "fast_compile_gpu", position=2
)
# misc special cases for speed that break canonicalization # misc special cases for speed that break canonicalization
optdb.register("uncanonicalize", EquilibriumDB(), 3, "fast_run") optdb.register("uncanonicalize", EquilibriumDB(), "fast_run", position=3)
# misc special cases for speed that are dependent on the device. # misc special cases for speed that are dependent on the device.
optdb.register( optdb.register(
"specialize_device", EquilibriumDB(), 48.6, "fast_compile", "fast_run" "specialize_device", EquilibriumDB(), "fast_compile", "fast_run", position=48.6
) # must be after gpu stuff at 48.5 ) # must be after gpu stuff at 48.5
# especially constant merge # especially constant merge
optdb.register("merge2", MergeOptimizer(), 49, "fast_run", "merge") optdb.register("merge2", MergeOptimizer(), "fast_run", "merge", position=49)
optdb.register("add_destroy_handler", AddDestroyHandler(), 49.5, "fast_run", "inplace") optdb.register(
"add_destroy_handler", AddDestroyHandler(), "fast_run", "inplace", position=49.5
)
# final pass just to make sure # final pass just to make sure
optdb.register("merge3", MergeOptimizer(), 100, "fast_run", "merge") optdb.register("merge3", MergeOptimizer(), "fast_run", "merge", position=100)
_tags: Union[Tuple[str, str], Tuple] _tags: Union[Tuple[str, str], Tuple]
...@@ -272,7 +282,7 @@ if config.check_stack_trace in ("raise", "warn", "log"): ...@@ -272,7 +282,7 @@ if config.check_stack_trace in ("raise", "warn", "log"):
if config.check_stack_trace == "off": if config.check_stack_trace == "off":
_tags = () _tags = ()
optdb.register("CheckStackTrace", CheckStackTraceOptimization(), -1, *_tags) optdb.register("CheckStackTrace", CheckStackTraceOptimization(), *_tags, position=-1)
del _tags del _tags
......
...@@ -2090,5 +2090,10 @@ gpuablas_opt_inplace = in2out( ...@@ -2090,5 +2090,10 @@ gpuablas_opt_inplace = in2out(
) )
optdb.register( optdb.register(
"InplaceGpuaBlasOpt", gpuablas_opt_inplace, 70.0, "fast_run", "inplace", "gpuarray" "InplaceGpuaBlasOpt",
gpuablas_opt_inplace,
"fast_run",
"inplace",
"gpuarray",
position=70.0,
) )
...@@ -432,11 +432,11 @@ optdb.register( ...@@ -432,11 +432,11 @@ optdb.register(
local_dnn_convgi_inplace, local_dnn_convgi_inplace,
name="local_dnna_conv_inplace", name="local_dnna_conv_inplace",
), ),
70.0,
"fast_run", "fast_run",
"inplace", "inplace",
"gpuarray", "gpuarray",
"cudnn", "cudnn",
position=70.0,
) )
...@@ -837,7 +837,7 @@ class NoCuDNNRaise(GlobalOptimizer): ...@@ -837,7 +837,7 @@ class NoCuDNNRaise(GlobalOptimizer):
) )
gpu_seqopt.register("NoCuDNNRaise", NoCuDNNRaise(), 0, "cudnn") gpu_seqopt.register("NoCuDNNRaise", NoCuDNNRaise(), "cudnn", position=0)
@register_inplace() @register_inplace()
......
...@@ -204,23 +204,28 @@ _logger = logging.getLogger("aesara.gpuarray.opt") ...@@ -204,23 +204,28 @@ _logger = logging.getLogger("aesara.gpuarray.opt")
gpu_seqopt.register( gpu_seqopt.register(
"gpuarray_graph_optimization", "gpuarray_graph_optimization",
GraphToGPUDB(), GraphToGPUDB(),
-0.5,
"fast_compile", "fast_compile",
"fast_run", "fast_run",
"gpuarray", "gpuarray",
position=-0.5,
) )
gpu_seqopt.register( gpu_seqopt.register(
"gpuarray_local_optimizations", "gpuarray_local_optimizations",
gpu_optimizer, gpu_optimizer,
1,
"fast_compile", "fast_compile",
"fast_run", "fast_run",
"gpuarray", "gpuarray",
"gpuarray_local_optimiziations", "gpuarray_local_optimiziations",
position=1,
) )
gpu_seqopt.register( gpu_seqopt.register(
"gpuarray_cut_transfers", gpu_cut_copies, 2, "fast_compile", "fast_run", "gpuarray" "gpuarray_cut_transfers",
gpu_cut_copies,
"fast_compile",
"fast_run",
"gpuarray",
position=2,
) )
register_opt("fast_compile")(aesara.tensor.basic_opt.local_track_shape_i) register_opt("fast_compile")(aesara.tensor.basic_opt.local_track_shape_i)
...@@ -280,10 +285,10 @@ class InputToGpuOptimizer(GlobalOptimizer): ...@@ -280,10 +285,10 @@ class InputToGpuOptimizer(GlobalOptimizer):
gpu_seqopt.register( gpu_seqopt.register(
"InputToGpuArrayOptimizer", "InputToGpuArrayOptimizer",
InputToGpuOptimizer(), InputToGpuOptimizer(),
0,
"fast_run", "fast_run",
"fast_compile", "fast_compile",
"merge", "merge",
position=0,
) )
...@@ -702,8 +707,8 @@ optdb.register( ...@@ -702,8 +707,8 @@ optdb.register(
"local_gpua_alloc_empty_to_zeros", "local_gpua_alloc_empty_to_zeros",
aesara.graph.opt.in2out(local_gpua_alloc_empty_to_zeros), aesara.graph.opt.in2out(local_gpua_alloc_empty_to_zeros),
# After move to gpu and merge2, before inplace. # After move to gpu and merge2, before inplace.
49.3,
"alloc_empty_to_zeros", "alloc_empty_to_zeros",
position=49.3,
) )
...@@ -866,27 +871,27 @@ gpu_local_elemwise_fusion = aesara.tensor.basic_opt.local_elemwise_fusion_op( ...@@ -866,27 +871,27 @@ gpu_local_elemwise_fusion = aesara.tensor.basic_opt.local_elemwise_fusion_op(
) )
optdb.register( optdb.register(
"gpua_elemwise_fusion", "gpua_elemwise_fusion",
# 48.5 move to gpu
# 48.6 specialize
# 49 cpu fusion
# 49.5 add destroy handler
aesara.tensor.basic_opt.FusionOptimizer(gpu_local_elemwise_fusion), aesara.tensor.basic_opt.FusionOptimizer(gpu_local_elemwise_fusion),
49,
"fast_run", "fast_run",
"fusion", "fusion",
"local_elemwise_fusion", "local_elemwise_fusion",
"gpuarray", "gpuarray",
# 48.5 move to gpu
# 48.6 specialize
# 49 cpu fusion
# 49.5 add destroy handler
position=49,
) )
inplace_gpu_elemwise_opt = aesara.tensor.basic_opt.InplaceElemwiseOptimizer(GpuElemwise) inplace_gpu_elemwise_opt = aesara.tensor.basic_opt.InplaceElemwiseOptimizer(GpuElemwise)
optdb.register( optdb.register(
"gpua_inplace_opt", "gpua_inplace_opt",
inplace_gpu_elemwise_opt, inplace_gpu_elemwise_opt,
75,
"inplace_elemwise_optimizer", "inplace_elemwise_optimizer",
"fast_run", "fast_run",
"inplace", "inplace",
"gpuarray", "gpuarray",
position=75,
) )
register_opt(aesara.tensor.basic_opt.local_useless_elemwise) register_opt(aesara.tensor.basic_opt.local_useless_elemwise)
...@@ -2608,7 +2613,9 @@ assert_no_cpu_op = aesara.graph.opt.in2out( ...@@ -2608,7 +2613,9 @@ assert_no_cpu_op = aesara.graph.opt.in2out(
local_assert_no_cpu_op, name="assert_no_cpu_op" local_assert_no_cpu_op, name="assert_no_cpu_op"
) )
# 49.2 is after device specialization & fusion optimizations for last transfers # 49.2 is after device specialization & fusion optimizations for last transfers
optdb.register("gpua_assert_no_cpu_op", assert_no_cpu_op, 49.2, "assert_no_cpu_op") optdb.register(
"gpua_assert_no_cpu_op", assert_no_cpu_op, "assert_no_cpu_op", position=49.2
)
def tensor_to_gpu(x, context_name): def tensor_to_gpu(x, context_name):
...@@ -2961,97 +2968,97 @@ def local_gpu_ctc(fgraph, op, context_name, inputs, outputs): ...@@ -2961,97 +2968,97 @@ def local_gpu_ctc(fgraph, op, context_name, inputs, outputs):
optdb.register( optdb.register(
"gpua_scanOp_make_inplace", "gpua_scanOp_make_inplace",
ScanInplaceOptimizer(typeInfer=_scan_type_infer, gpua_flag=True), ScanInplaceOptimizer(typeInfer=_scan_type_infer, gpua_flag=True),
75,
"gpuarray", "gpuarray",
"inplace", "inplace",
"scan", "scan",
position=75,
) )
abstractconv_groupopt.register( abstractconv_groupopt.register(
"local_abstractconv_dnn", "local_abstractconv_dnn",
local_abstractconv_cudnn, local_abstractconv_cudnn,
20,
"conv_dnn", "conv_dnn",
"gpuarray", "gpuarray",
"fast_compile", "fast_compile",
"fast_run", "fast_run",
"cudnn", "cudnn",
position=20,
) )
abstractconv_groupopt.register( abstractconv_groupopt.register(
"local_abstractconv_gw_dnn", "local_abstractconv_gw_dnn",
local_abstractconv_gw_cudnn, local_abstractconv_gw_cudnn,
20,
"conv_dnn", "conv_dnn",
"gpuarray", "gpuarray",
"fast_compile", "fast_compile",
"fast_run", "fast_run",
"cudnn", "cudnn",
position=20,
) )
abstractconv_groupopt.register( abstractconv_groupopt.register(
"local_abstractconv_gi_dnn", "local_abstractconv_gi_dnn",
local_abstractconv_gi_cudnn, local_abstractconv_gi_cudnn,
20,
"conv_dnn", "conv_dnn",
"gpuarray", "gpuarray",
"fast_compile", "fast_compile",
"fast_run", "fast_run",
"cudnn", "cudnn",
position=20,
) )
# The GEMM-based convolution comes last to catch all remaining cases. # The GEMM-based convolution comes last to catch all remaining cases.
# It can be disabled by excluding 'conv_gemm'. # It can be disabled by excluding 'conv_gemm'.
abstractconv_groupopt.register( abstractconv_groupopt.register(
"local_abstractconv_gemm", "local_abstractconv_gemm",
local_abstractconv_gemm, local_abstractconv_gemm,
30,
"conv_gemm", "conv_gemm",
"gpuarray", "gpuarray",
"fast_compile", "fast_compile",
"fast_run", "fast_run",
position=30,
) )
abstractconv_groupopt.register( abstractconv_groupopt.register(
"local_abstractconv3d_gemm", "local_abstractconv3d_gemm",
local_abstractconv3d_gemm, local_abstractconv3d_gemm,
30,
"conv_gemm", "conv_gemm",
"gpuarray", "gpuarray",
"fast_compile", "fast_compile",
"fast_run", "fast_run",
position=30,
) )
abstractconv_groupopt.register( abstractconv_groupopt.register(
"local_abstractconv_gradweights_gemm", "local_abstractconv_gradweights_gemm",
local_abstractconv_gradweights_gemm, local_abstractconv_gradweights_gemm,
30,
"conv_gemm", "conv_gemm",
"gpuarray", "gpuarray",
"fast_compile", "fast_compile",
"fast_run", "fast_run",
position=30,
) )
abstractconv_groupopt.register( abstractconv_groupopt.register(
"local_abstractconv3d_gradweights_gemm", "local_abstractconv3d_gradweights_gemm",
local_abstractconv3d_gradweights_gemm, local_abstractconv3d_gradweights_gemm,
30,
"conv_gemm", "conv_gemm",
"gpuarray", "gpuarray",
"fast_compile", "fast_compile",
"fast_run", "fast_run",
position=30,
) )
abstractconv_groupopt.register( abstractconv_groupopt.register(
"local_abstractconv_gradinputs", "local_abstractconv_gradinputs",
local_abstractconv_gradinputs_gemm, local_abstractconv_gradinputs_gemm,
30,
"conv_gemm", "conv_gemm",
"gpuarray", "gpuarray",
"fast_compile", "fast_compile",
"fast_run", "fast_run",
position=30,
) )
abstractconv_groupopt.register( abstractconv_groupopt.register(
"local_abstractconv3d_gradinputs", "local_abstractconv3d_gradinputs",
local_abstractconv3d_gradinputs_gemm, local_abstractconv3d_gradinputs_gemm,
30,
"conv_gemm", "conv_gemm",
"gpuarray", "gpuarray",
"fast_compile", "fast_compile",
"fast_run", "fast_run",
position=30,
) )
conv_metaopt = ConvMetaOptimizer() conv_metaopt = ConvMetaOptimizer()
......
...@@ -60,8 +60,8 @@ gpu_seqopt = SequenceDB() ...@@ -60,8 +60,8 @@ gpu_seqopt = SequenceDB()
optdb.register( optdb.register(
"gpuarray_opt", "gpuarray_opt",
gpu_seqopt, gpu_seqopt,
optdb.__position__.get("add_destroy_handler", 49.5) - 1,
"gpuarray", "gpuarray",
position=optdb.__position__.get("add_destroy_handler", 49.5) - 1,
) )
...@@ -123,11 +123,11 @@ def register_inplace(*tags, **kwargs): ...@@ -123,11 +123,11 @@ def register_inplace(*tags, **kwargs):
optdb.register( optdb.register(
name, name,
TopoOptimizer(local_opt, failure_callback=TopoOptimizer.warn_inplace), TopoOptimizer(local_opt, failure_callback=TopoOptimizer.warn_inplace),
60,
"fast_run", "fast_run",
"inplace", "inplace",
"gpuarray", "gpuarray",
*tags, *tags,
position=60,
) )
return local_opt return local_opt
......
...@@ -34,6 +34,7 @@ class OptimizationDatabase: ...@@ -34,6 +34,7 @@ class OptimizationDatabase:
optimizer: Union["OptimizationDatabase", OptimizersType], optimizer: Union["OptimizationDatabase", OptimizersType],
*tags: str, *tags: str,
use_db_name_as_tag=True, use_db_name_as_tag=True,
**kwargs,
): ):
"""Register a new optimizer to the database. """Register a new optimizer to the database.
...@@ -339,10 +340,10 @@ class EquilibriumDB(OptimizationDatabase): ...@@ -339,10 +340,10 @@ class EquilibriumDB(OptimizationDatabase):
self.__final__ = {} self.__final__ = {}
self.__cleanup__ = {} self.__cleanup__ = {}
def register(self, name, obj, *tags, final_opt=False, cleanup=False, **kwtags): def register(self, name, obj, *tags, final_opt=False, cleanup=False, **kwargs):
if final_opt and cleanup: if final_opt and cleanup:
raise ValueError("`final_opt` and `cleanup` cannot both be true.") raise ValueError("`final_opt` and `cleanup` cannot both be true.")
super().register(name, obj, *tags, **kwtags) super().register(name, obj, *tags, **kwargs)
self.__final__[name] = final_opt self.__final__[name] = final_opt
self.__cleanup__[name] = cleanup self.__cleanup__[name] = cleanup
...@@ -387,7 +388,9 @@ class SequenceDB(OptimizationDatabase): ...@@ -387,7 +388,9 @@ class SequenceDB(OptimizationDatabase):
self.__position__ = {} self.__position__ = {}
self.failure_callback = failure_callback self.failure_callback = failure_callback
def register(self, name, obj, position: Union[str, int, float], *tags, **kwargs): def register(
self, name, obj, *tags, position: Union[str, int, float] = "last", **kwargs
):
super().register(name, obj, *tags, **kwargs) super().register(name, obj, *tags, **kwargs)
if position == "last": if position == "last":
if len(self.__position__) == 0: if len(self.__position__) == 0:
...@@ -493,7 +496,7 @@ class LocalGroupDB(SequenceDB): ...@@ -493,7 +496,7 @@ class LocalGroupDB(SequenceDB):
self.__name__: str = "" self.__name__: str = ""
def register(self, name, obj, *tags, position="last", **kwargs): def register(self, name, obj, *tags, position="last", **kwargs):
super().register(name, obj, position, *tags, **kwargs) super().register(name, obj, *tags, position=position, **kwargs)
def query(self, *tags, **kwtags): def query(self, *tags, **kwtags):
opts = list(super().query(*tags, **kwtags)) opts = list(super().query(*tags, **kwtags))
......
...@@ -420,9 +420,9 @@ def cond_make_inplace(fgraph, node): ...@@ -420,9 +420,9 @@ def cond_make_inplace(fgraph, node):
optdb.register( optdb.register(
"cond_make_inplace", "cond_make_inplace",
in2out(cond_make_inplace, ignore_newtrees=True), in2out(cond_make_inplace, ignore_newtrees=True),
95,
"fast_run", "fast_run",
"inplace", "inplace",
position=95,
) )
# XXX: Optimizations commented pending further debugging (certain optimizations # XXX: Optimizations commented pending further debugging (certain optimizations
...@@ -456,8 +456,8 @@ where, each of the optimization do the following things: ...@@ -456,8 +456,8 @@ where, each of the optimization do the following things:
`ifelse_lift` (def cond_lift_single_if): `ifelse_lift` (def cond_lift_single_if):
""" """
# optdb.register('ifelse_equilibriumOpt', ifelse_equilibrium, .5, 'fast_run', # optdb.register('ifelse_equilibriumOpt', ifelse_equilibrium, 'fast_run',
# 'ifelse') # 'ifelse', position=.5)
acceptable_ops = ( acceptable_ops = (
...@@ -768,26 +768,26 @@ def cond_merge_random_op(fgraph, main_node): ...@@ -768,26 +768,26 @@ def cond_merge_random_op(fgraph, main_node):
# #
# ifelse_seqopt.register('ifelse_condPushOut_equilibrium', # ifelse_seqopt.register('ifelse_condPushOut_equilibrium',
# pushout_equilibrium, # pushout_equilibrium,
# 1, 'fast_run', 'ifelse') # 'fast_run', 'ifelse', position=1)
# #
# ifelse_seqopt.register('merge_nodes_1', # ifelse_seqopt.register('merge_nodes_1',
# graph.opt.MergeOptimizer(skip_const_merge=False), # graph.opt.MergeOptimizer(skip_const_merge=False),
# 2, 'fast_run', 'ifelse') # 'fast_run', 'ifelse', position=2)
# #
# #
# ifelse_seqopt.register('ifelse_sameCondTrue', # ifelse_seqopt.register('ifelse_sameCondTrue',
# in2out(cond_merge_ifs_true, # in2out(cond_merge_ifs_true,
# ignore_newtrees=True), # ignore_newtrees=True),
# 3, 'fast_run', 'ifelse') # 'fast_run', 'ifelse', position=3)
# #
# #
# ifelse_seqopt.register('ifelse_sameCondFalse', # ifelse_seqopt.register('ifelse_sameCondFalse',
# in2out(cond_merge_ifs_false, # in2out(cond_merge_ifs_false,
# ignore_newtrees=True), # ignore_newtrees=True),
# 4, 'fast_run', 'ifelse') # 'fast_run', 'ifelse', position=4)
# #
# #
# ifelse_seqopt.register('ifelse_removeIdenetical', # ifelse_seqopt.register('ifelse_removeIdenetical',
# in2out(cond_remove_identical, # in2out(cond_remove_identical,
# ignore_newtrees=True), # ignore_newtrees=True),
# 7, 'fast_run', 'ifelse') # 'fast_run', 'ifelse', position=7)
...@@ -1358,7 +1358,7 @@ def mrg_random_make_inplace(fgraph, node): ...@@ -1358,7 +1358,7 @@ def mrg_random_make_inplace(fgraph, node):
optdb.register( optdb.register(
"random_make_inplace_mrg", "random_make_inplace_mrg",
in2out(mrg_random_make_inplace, ignore_newtrees=True), in2out(mrg_random_make_inplace, ignore_newtrees=True),
99,
"fast_run", "fast_run",
"inplace", "inplace",
position=99,
) )
...@@ -2336,68 +2336,68 @@ scan_eqopt2 = EquilibriumDB() ...@@ -2336,68 +2336,68 @@ scan_eqopt2 = EquilibriumDB()
# scan_eqopt1 before ShapeOpt at 0.1 # scan_eqopt1 before ShapeOpt at 0.1
# This is needed to don't have ShapeFeature trac old Scan that we # This is needed to don't have ShapeFeature trac old Scan that we
# don't want to reintroduce. # don't want to reintroduce.
optdb.register("scan_eqopt1", scan_eqopt1, 0.05, "fast_run", "scan") optdb.register("scan_eqopt1", scan_eqopt1, "fast_run", "scan", position=0.05)
# We run before blas opt at 1.7 and specialize 2.0 # We run before blas opt at 1.7 and specialize 2.0
# but after stabilize at 1.5. Should we put it before stabilize? # but after stabilize at 1.5. Should we put it before stabilize?
optdb.register("scan_eqopt2", scan_eqopt2, 1.6, "fast_run", "scan") optdb.register("scan_eqopt2", scan_eqopt2, "fast_run", "scan", position=1.6)
# ScanSaveMem should execute only once per node. # ScanSaveMem should execute only once per node.
optdb.register( optdb.register(
"scan_save_mem", "scan_save_mem",
in2out(save_mem_new_scan, ignore_newtrees=True), in2out(save_mem_new_scan, ignore_newtrees=True),
1.61,
"fast_run", "fast_run",
"scan", "scan",
position=1.61,
) )
optdb.register( optdb.register(
"scan_make_inplace", "scan_make_inplace",
ScanInplaceOptimizer(typeInfer=None), ScanInplaceOptimizer(typeInfer=None),
75,
"fast_run", "fast_run",
"inplace", "inplace",
"scan", "scan",
position=75,
) )
scan_eqopt1.register("all_pushout_opt", scan_seqopt1, 1, "fast_run", "scan") scan_eqopt1.register("all_pushout_opt", scan_seqopt1, "fast_run", "scan", position=1)
scan_seqopt1.register( scan_seqopt1.register(
"scan_remove_constants_and_unused_inputs0", "scan_remove_constants_and_unused_inputs0",
in2out(remove_constants_and_unused_inputs_scan, ignore_newtrees=True), in2out(remove_constants_and_unused_inputs_scan, ignore_newtrees=True),
1,
"remove_constants_and_unused_inputs_scan", "remove_constants_and_unused_inputs_scan",
"fast_run", "fast_run",
"scan", "scan",
position=1,
) )
scan_seqopt1.register( scan_seqopt1.register(
"scan_pushout_nonseqs_ops", "scan_pushout_nonseqs_ops",
in2out(push_out_non_seq_scan, ignore_newtrees=True), in2out(push_out_non_seq_scan, ignore_newtrees=True),
2,
"fast_run", "fast_run",
"scan", "scan",
"scan_pushout", "scan_pushout",
position=2,
) )
scan_seqopt1.register( scan_seqopt1.register(
"scan_pushout_seqs_ops", "scan_pushout_seqs_ops",
in2out(push_out_seq_scan, ignore_newtrees=True), in2out(push_out_seq_scan, ignore_newtrees=True),
3,
"fast_run", "fast_run",
"scan", "scan",
"scan_pushout", "scan_pushout",
position=3,
) )
scan_seqopt1.register( scan_seqopt1.register(
"scan_pushout_dot1", "scan_pushout_dot1",
in2out(push_out_dot1_scan, ignore_newtrees=True), in2out(push_out_dot1_scan, ignore_newtrees=True),
4,
"fast_run", "fast_run",
"more_mem", "more_mem",
"scan", "scan",
"scan_pushout", "scan_pushout",
position=4,
) )
...@@ -2405,62 +2405,62 @@ scan_seqopt1.register( ...@@ -2405,62 +2405,62 @@ scan_seqopt1.register(
"scan_pushout_add", "scan_pushout_add",
# TODO: Perhaps this should be an `EquilibriumOptimizer`? # TODO: Perhaps this should be an `EquilibriumOptimizer`?
in2out(push_out_add_scan, ignore_newtrees=False), in2out(push_out_add_scan, ignore_newtrees=False),
5,
"fast_run", "fast_run",
"more_mem", "more_mem",
"scan", "scan",
"scan_pushout", "scan_pushout",
position=5,
) )
scan_eqopt2.register( scan_eqopt2.register(
"constant_folding_for_scan2", "constant_folding_for_scan2",
in2out(basic_opt.constant_folding, ignore_newtrees=True), in2out(basic_opt.constant_folding, ignore_newtrees=True),
1,
"fast_run", "fast_run",
"scan", "scan",
position=1,
) )
scan_eqopt2.register( scan_eqopt2.register(
"scan_remove_constants_and_unused_inputs1", "scan_remove_constants_and_unused_inputs1",
in2out(remove_constants_and_unused_inputs_scan, ignore_newtrees=True), in2out(remove_constants_and_unused_inputs_scan, ignore_newtrees=True),
2,
"remove_constants_and_unused_inputs_scan", "remove_constants_and_unused_inputs_scan",
"fast_run", "fast_run",
"scan", "scan",
position=2,
) )
# after const merge but before stabilize so that we can have identity # after const merge but before stabilize so that we can have identity
# for equivalent nodes but we still have the chance to hoist stuff out # for equivalent nodes but we still have the chance to hoist stuff out
# of the scan later. # of the scan later.
scan_eqopt2.register("scan_merge", ScanMerge(), 4, "fast_run", "scan") scan_eqopt2.register("scan_merge", ScanMerge(), "fast_run", "scan", position=4)
# After Merge optimization # After Merge optimization
scan_eqopt2.register( scan_eqopt2.register(
"scan_remove_constants_and_unused_inputs2", "scan_remove_constants_and_unused_inputs2",
in2out(remove_constants_and_unused_inputs_scan, ignore_newtrees=True), in2out(remove_constants_and_unused_inputs_scan, ignore_newtrees=True),
5,
"remove_constants_and_unused_inputs_scan", "remove_constants_and_unused_inputs_scan",
"fast_run", "fast_run",
"scan", "scan",
position=5,
) )
scan_eqopt2.register( scan_eqopt2.register(
"scan_merge_inouts", "scan_merge_inouts",
in2out(scan_merge_inouts, ignore_newtrees=True), in2out(scan_merge_inouts, ignore_newtrees=True),
6,
"fast_run", "fast_run",
"scan", "scan",
position=6,
) )
# After everything else # After everything else
scan_eqopt2.register( scan_eqopt2.register(
"scan_remove_constants_and_unused_inputs3", "scan_remove_constants_and_unused_inputs3",
in2out(remove_constants_and_unused_inputs_scan, ignore_newtrees=True), in2out(remove_constants_and_unused_inputs_scan, ignore_newtrees=True),
8,
"remove_constants_and_unused_inputs_scan", "remove_constants_and_unused_inputs_scan",
"fast_run", "fast_run",
"scan", "scan",
position=8,
) )
...@@ -75,9 +75,9 @@ def local_inplace_remove0(fgraph, node): ...@@ -75,9 +75,9 @@ def local_inplace_remove0(fgraph, node):
aesara.compile.optdb.register( aesara.compile.optdb.register(
"local_inplace_remove0", "local_inplace_remove0",
TopoOptimizer(local_inplace_remove0, failure_callback=TopoOptimizer.warn_inplace), TopoOptimizer(local_inplace_remove0, failure_callback=TopoOptimizer.warn_inplace),
60,
"fast_run", "fast_run",
"inplace", "inplace",
position=60,
) )
...@@ -216,9 +216,9 @@ aesara.compile.optdb.register( ...@@ -216,9 +216,9 @@ aesara.compile.optdb.register(
TopoOptimizer( TopoOptimizer(
local_inplace_addsd_ccode, failure_callback=TopoOptimizer.warn_inplace local_inplace_addsd_ccode, failure_callback=TopoOptimizer.warn_inplace
), ),
60,
"fast_run", "fast_run",
"inplace", "inplace",
position=60,
) )
...@@ -248,8 +248,8 @@ aesara.compile.optdb.register( ...@@ -248,8 +248,8 @@ aesara.compile.optdb.register(
"local_addsd_ccode", "local_addsd_ccode",
TopoOptimizer(local_addsd_ccode), TopoOptimizer(local_addsd_ccode),
# Must be after local_inplace_addsd_ccode at 60 # Must be after local_inplace_addsd_ccode at 60
61,
"fast_run", "fast_run",
position=61,
) )
......
...@@ -474,11 +474,11 @@ inplace_elemwise_optimizer = InplaceElemwiseOptimizer(Elemwise) ...@@ -474,11 +474,11 @@ inplace_elemwise_optimizer = InplaceElemwiseOptimizer(Elemwise)
compile.optdb.register( compile.optdb.register(
"inplace_elemwise_opt", "inplace_elemwise_opt",
inplace_elemwise_optimizer, inplace_elemwise_optimizer,
75,
"inplace_opt", # for historic reason "inplace_opt", # for historic reason
"inplace_elemwise_optimizer", "inplace_elemwise_optimizer",
"fast_run", "fast_run",
"inplace", "inplace",
position=75,
) )
...@@ -493,7 +493,7 @@ def register_useless(lopt, *tags, **kwargs): ...@@ -493,7 +493,7 @@ def register_useless(lopt, *tags, **kwargs):
name = kwargs.pop("name", None) or lopt.__name__ name = kwargs.pop("name", None) or lopt.__name__
compile.mode.local_useless.register( compile.mode.local_useless.register(
name, lopt, "last", "fast_run", *tags, **kwargs name, lopt, "fast_run", *tags, position="last", **kwargs
) )
return lopt return lopt
...@@ -1475,12 +1475,12 @@ class UnShapeOptimizer(GlobalOptimizer): ...@@ -1475,12 +1475,12 @@ class UnShapeOptimizer(GlobalOptimizer):
# Register it after merge1 optimization at 0. We don't want to track # Register it after merge1 optimization at 0. We don't want to track
# the shape of merged node. # the shape of merged node.
aesara.compile.mode.optdb.register( aesara.compile.mode.optdb.register(
"ShapeOpt", ShapeOptimizer(), 0.1, "fast_run", "fast_compile" "ShapeOpt", ShapeOptimizer(), "fast_run", "fast_compile", position=0.1
) )
# Not enabled by default for now. Some crossentropy opt use the # Not enabled by default for now. Some crossentropy opt use the
# shape_feature. They are at step 2.01. uncanonicalize is at step # shape_feature. They are at step 2.01. uncanonicalize is at step
# 3. After it goes to 48.5 that move to the gpu. So 10 seems reasonable. # 3. After it goes to 48.5 that move to the gpu. So 10 seems reasonable.
aesara.compile.mode.optdb.register("UnShapeOpt", UnShapeOptimizer(), 10) aesara.compile.mode.optdb.register("UnShapeOpt", UnShapeOptimizer(), position=10)
@register_specialize("local_alloc_elemwise") @register_specialize("local_alloc_elemwise")
...@@ -1741,11 +1741,11 @@ def local_fill_to_alloc(fgraph, node): ...@@ -1741,11 +1741,11 @@ def local_fill_to_alloc(fgraph, node):
# Register this after stabilize at 1.5 to make sure stabilize don't # Register this after stabilize at 1.5 to make sure stabilize don't
# get affected by less canonicalized graph due to alloc. # get affected by less canonicalized graph due to alloc.
compile.optdb.register( compile.optdb.register(
"local_fill_to_alloc", in2out(local_fill_to_alloc), 1.51, "fast_run" "local_fill_to_alloc", in2out(local_fill_to_alloc), "fast_run", position=1.51
) )
# Needed to clean some extra alloc added by local_fill_to_alloc # Needed to clean some extra alloc added by local_fill_to_alloc
compile.optdb.register( compile.optdb.register(
"local_elemwise_alloc", in2out(local_elemwise_alloc), 1.52, "fast_run" "local_elemwise_alloc", in2out(local_elemwise_alloc), "fast_run", position=1.52
) )
...@@ -1856,8 +1856,8 @@ compile.optdb.register( ...@@ -1856,8 +1856,8 @@ compile.optdb.register(
"local_alloc_empty_to_zeros", "local_alloc_empty_to_zeros",
in2out(local_alloc_empty_to_zeros), in2out(local_alloc_empty_to_zeros),
# After move to gpu and merge2, before inplace. # After move to gpu and merge2, before inplace.
49.3,
"alloc_empty_to_zeros", "alloc_empty_to_zeros",
position=49.3,
) )
...@@ -3369,28 +3369,28 @@ if config.tensor__local_elemwise_fusion: ...@@ -3369,28 +3369,28 @@ if config.tensor__local_elemwise_fusion:
fuse_seqopt.register( fuse_seqopt.register(
"composite_elemwise_fusion", "composite_elemwise_fusion",
FusionOptimizer(local_elemwise_fusion), FusionOptimizer(local_elemwise_fusion),
1,
"fast_run", "fast_run",
"fusion", "fusion",
position=1,
) )
compile.optdb.register( compile.optdb.register(
"elemwise_fusion", "elemwise_fusion",
fuse_seqopt, fuse_seqopt,
49,
"fast_run", "fast_run",
"fusion", "fusion",
"local_elemwise_fusion", "local_elemwise_fusion",
"FusionOptimizer", "FusionOptimizer",
position=49,
) )
else: else:
_logger.debug("not enabling optimization fusion elemwise in fast_run") _logger.debug("not enabling optimization fusion elemwise in fast_run")
compile.optdb.register( compile.optdb.register(
"elemwise_fusion", "elemwise_fusion",
FusionOptimizer(local_elemwise_fusion), FusionOptimizer(local_elemwise_fusion),
49,
"fusion", "fusion",
"local_elemwise_fusion", "local_elemwise_fusion",
"FusionOptimizer", "FusionOptimizer",
position=49,
) )
......
...@@ -1798,15 +1798,19 @@ def local_dot22_to_ger_or_gemv(fgraph, node): ...@@ -1798,15 +1798,19 @@ def local_dot22_to_ger_or_gemv(fgraph, node):
blas_optdb = SequenceDB() blas_optdb = SequenceDB()
# run after numerical stability optimizations (1.5) # run after numerical stability optimizations (1.5)
optdb.register("BlasOpt", blas_optdb, 1.7, "fast_run", "fast_compile") optdb.register("BlasOpt", blas_optdb, "fast_run", "fast_compile", position=1.7)
# run before specialize (2.0) because specialize is basically a # run before specialize (2.0) because specialize is basically a
# free-for-all that makes the graph crazy. # free-for-all that makes the graph crazy.
# fast_compile is needed to have GpuDot22 created. # fast_compile is needed to have GpuDot22 created.
blas_optdb.register( blas_optdb.register(
"local_dot_to_dot22", in2out(local_dot_to_dot22), 0, "fast_run", "fast_compile" "local_dot_to_dot22",
in2out(local_dot_to_dot22),
"fast_run",
"fast_compile",
position=0,
) )
blas_optdb.register("gemm_optimizer", GemmOptimizer(), 10, "fast_run") blas_optdb.register("gemm_optimizer", GemmOptimizer(), "fast_run", position=10)
blas_optdb.register( blas_optdb.register(
"local_gemm_to_gemv", "local_gemm_to_gemv",
EquilibriumOptimizer( EquilibriumOptimizer(
...@@ -1819,8 +1823,8 @@ blas_optdb.register( ...@@ -1819,8 +1823,8 @@ blas_optdb.register(
max_use_ratio=5, max_use_ratio=5,
ignore_newtrees=False, ignore_newtrees=False,
), ),
15,
"fast_run", "fast_run",
position=15,
) )
...@@ -1830,7 +1834,12 @@ blas_opt_inplace = in2out( ...@@ -1830,7 +1834,12 @@ blas_opt_inplace = in2out(
local_inplace_gemm, local_inplace_gemv, local_inplace_ger, name="blas_opt_inplace" local_inplace_gemm, local_inplace_gemv, local_inplace_ger, name="blas_opt_inplace"
) )
optdb.register( optdb.register(
"InplaceBlasOpt", blas_opt_inplace, 70.0, "fast_run", "inplace", "blas_opt_inplace" "InplaceBlasOpt",
blas_opt_inplace,
"fast_run",
"inplace",
"blas_opt_inplace",
position=70.0,
) )
...@@ -2048,7 +2057,10 @@ def local_dot22_to_dot22scalar(fgraph, node): ...@@ -2048,7 +2057,10 @@ def local_dot22_to_dot22scalar(fgraph, node):
# must happen after gemm as the gemm optimizer don't understant # must happen after gemm as the gemm optimizer don't understant
# dot22scalar and gemm give more speed up then dot22scalar # dot22scalar and gemm give more speed up then dot22scalar
blas_optdb.register( blas_optdb.register(
"local_dot22_to_dot22scalar", in2out(local_dot22_to_dot22scalar), 11, "fast_run" "local_dot22_to_dot22scalar",
in2out(local_dot22_to_dot22scalar),
"fast_run",
position=11,
) )
......
...@@ -730,15 +730,15 @@ def make_c_gemv_destructive(fgraph, node): ...@@ -730,15 +730,15 @@ def make_c_gemv_destructive(fgraph, node):
# ##### ####### ####### # ##### ####### #######
blas_optdb.register( blas_optdb.register(
"use_c_blas", in2out(use_c_ger, use_c_gemv), 20, "fast_run", "c_blas" "use_c_blas", in2out(use_c_ger, use_c_gemv), "fast_run", "c_blas", position=20
) )
# this matches the InplaceBlasOpt defined in blas.py # this matches the InplaceBlasOpt defined in blas.py
optdb.register( optdb.register(
"c_blas_destructive", "c_blas_destructive",
in2out(make_c_ger_destructive, make_c_gemv_destructive, name="c_blas_destructive"), in2out(make_c_ger_destructive, make_c_gemv_destructive, name="c_blas_destructive"),
70.0,
"fast_run", "fast_run",
"inplace", "inplace",
"c_blas", "c_blas",
position=70.0,
) )
...@@ -79,13 +79,13 @@ if have_fblas: ...@@ -79,13 +79,13 @@ if have_fblas:
# C implementations should be scheduled earlier than this, so that they take # C implementations should be scheduled earlier than this, so that they take
# precedence. Once the original Ger is replaced, then these optimizations # precedence. Once the original Ger is replaced, then these optimizations
# have no effect. # have no effect.
blas_optdb.register("scipy_blas", use_scipy_blas, 100, "fast_run") blas_optdb.register("scipy_blas", use_scipy_blas, "fast_run", position=100)
# this matches the InplaceBlasOpt defined in blas.py # this matches the InplaceBlasOpt defined in blas.py
optdb.register( optdb.register(
"make_scipy_blas_destructive", "make_scipy_blas_destructive",
make_scipy_blas_destructive, make_scipy_blas_destructive,
70.0,
"fast_run", "fast_run",
"inplace", "inplace",
position=70.0,
) )
...@@ -2870,9 +2870,9 @@ def local_add_mul_fusion(fgraph, node): ...@@ -2870,9 +2870,9 @@ def local_add_mul_fusion(fgraph, node):
fuse_seqopt.register( fuse_seqopt.register(
"local_add_mul_fusion", "local_add_mul_fusion",
FusionOptimizer(local_add_mul_fusion), FusionOptimizer(local_add_mul_fusion),
0,
"fast_run", "fast_run",
"fusion", "fusion",
position=0,
) )
......
...@@ -1949,10 +1949,10 @@ def crossentropy_to_crossentropy_with_softmax(fgraph): ...@@ -1949,10 +1949,10 @@ def crossentropy_to_crossentropy_with_softmax(fgraph):
optdb.register( optdb.register(
"crossentropy_to_crossentropy_with_softmax", "crossentropy_to_crossentropy_with_softmax",
crossentropy_to_crossentropy_with_softmax, crossentropy_to_crossentropy_with_softmax,
2.01,
"fast_run", "fast_run",
"xent", "xent",
"fast_compile_gpu", "fast_compile_gpu",
position=2.01,
) )
......
...@@ -912,21 +912,21 @@ register_specialize_device(bn_groupopt, "fast_compile", "fast_run") ...@@ -912,21 +912,21 @@ register_specialize_device(bn_groupopt, "fast_compile", "fast_run")
bn_groupopt.register( bn_groupopt.register(
"local_abstract_batch_norm_train", "local_abstract_batch_norm_train",
local_abstract_batch_norm_train, local_abstract_batch_norm_train,
30,
"fast_compile", "fast_compile",
"fast_run", "fast_run",
position=30,
) )
bn_groupopt.register( bn_groupopt.register(
"local_abstract_batch_norm_train_grad", "local_abstract_batch_norm_train_grad",
local_abstract_batch_norm_train_grad, local_abstract_batch_norm_train_grad,
30,
"fast_compile", "fast_compile",
"fast_run", "fast_run",
position=30,
) )
bn_groupopt.register( bn_groupopt.register(
"local_abstract_batch_norm_inference", "local_abstract_batch_norm_inference",
local_abstract_batch_norm_inference, local_abstract_batch_norm_inference,
30,
"fast_compile", "fast_compile",
"fast_run", "fast_run",
position=30,
) )
...@@ -320,7 +320,7 @@ aesara.compile.optdb.register( ...@@ -320,7 +320,7 @@ aesara.compile.optdb.register(
TopoOptimizer( TopoOptimizer(
local_inplace_DiagonalSubtensor, failure_callback=TopoOptimizer.warn_inplace local_inplace_DiagonalSubtensor, failure_callback=TopoOptimizer.warn_inplace
), ),
60,
"fast_run", "fast_run",
"inplace", "inplace",
position=60,
) )
...@@ -54,9 +54,9 @@ compile.optdb.register( ...@@ -54,9 +54,9 @@ compile.optdb.register(
TopoOptimizer( TopoOptimizer(
local_inplace_sparse_block_gemv, failure_callback=TopoOptimizer.warn_inplace local_inplace_sparse_block_gemv, failure_callback=TopoOptimizer.warn_inplace
), ),
60,
"fast_run", "fast_run",
"inplace", "inplace",
position=60,
) # DEBUG ) # DEBUG
...@@ -78,9 +78,9 @@ compile.optdb.register( ...@@ -78,9 +78,9 @@ compile.optdb.register(
local_inplace_sparse_block_outer, local_inplace_sparse_block_outer,
failure_callback=TopoOptimizer.warn_inplace, failure_callback=TopoOptimizer.warn_inplace,
), ),
60,
"fast_run", "fast_run",
"inplace", "inplace",
position=60,
) # DEBUG ) # DEBUG
...@@ -500,69 +500,69 @@ register_specialize_device(conv_groupopt, "fast_compile", "fast_run") ...@@ -500,69 +500,69 @@ register_specialize_device(conv_groupopt, "fast_compile", "fast_run")
conv_groupopt.register( conv_groupopt.register(
"local_abstractconv_gemm", "local_abstractconv_gemm",
local_abstractconv_gemm, local_abstractconv_gemm,
30,
"conv_gemm", "conv_gemm",
"fast_compile", "fast_compile",
"fast_run", "fast_run",
position=30,
) )
conv_groupopt.register( conv_groupopt.register(
"local_abstractconv_gradweight_gemm", "local_abstractconv_gradweight_gemm",
local_abstractconv_gradweight_gemm, local_abstractconv_gradweight_gemm,
30,
"conv_gemm", "conv_gemm",
"fast_compile", "fast_compile",
"fast_run", "fast_run",
position=30,
) )
conv_groupopt.register( conv_groupopt.register(
"local_abstractconv_gradinputs_gemm", "local_abstractconv_gradinputs_gemm",
local_abstractconv_gradinputs_gemm, local_abstractconv_gradinputs_gemm,
30,
"conv_gemm", "conv_gemm",
"fast_compile", "fast_compile",
"fast_run", "fast_run",
position=30,
) )
conv_groupopt.register( conv_groupopt.register(
"local_abstractconv3d_gemm", "local_abstractconv3d_gemm",
local_abstractconv3d_gemm, local_abstractconv3d_gemm,
30,
"conv_gemm", "conv_gemm",
"fast_compile", "fast_compile",
"fast_run", "fast_run",
position=30,
) )
conv_groupopt.register( conv_groupopt.register(
"local_abstractconv3d_gradweight_gemm", "local_abstractconv3d_gradweight_gemm",
local_abstractconv3d_gradweight_gemm, local_abstractconv3d_gradweight_gemm,
30,
"conv_gemm", "conv_gemm",
"fast_compile", "fast_compile",
"fast_run", "fast_run",
position=30,
) )
conv_groupopt.register( conv_groupopt.register(
"local_abstractconv3d_gradinputs_gemm", "local_abstractconv3d_gradinputs_gemm",
local_abstractconv3d_gradinputs_gemm, local_abstractconv3d_gradinputs_gemm,
30,
"conv_gemm", "conv_gemm",
"fast_compile", "fast_compile",
"fast_run", "fast_run",
position=30,
) )
# Legacy convolution # Legacy convolution
conv_groupopt.register( conv_groupopt.register(
"local_conv2d_cpu", local_conv2d_cpu, 40, "fast_compile", "fast_run" "local_conv2d_cpu", local_conv2d_cpu, "fast_compile", "fast_run", position=40
) )
conv_groupopt.register( conv_groupopt.register(
"local_conv2d_gradweight_cpu", "local_conv2d_gradweight_cpu",
local_conv2d_gradweight_cpu, local_conv2d_gradweight_cpu,
40,
"fast_compile", "fast_compile",
"fast_run", "fast_run",
position=40,
) )
conv_groupopt.register( conv_groupopt.register(
"local_conv2d_gradinputs_cpu", "local_conv2d_gradinputs_cpu",
local_conv2d_gradinputs_cpu, local_conv2d_gradinputs_cpu,
40,
"fast_compile", "fast_compile",
"fast_run", "fast_run",
position=40,
) )
...@@ -602,7 +602,7 @@ def local_abstractconv_check(fgraph, node): ...@@ -602,7 +602,7 @@ def local_abstractconv_check(fgraph, node):
optdb.register( optdb.register(
"AbstractConvCheck", "AbstractConvCheck",
in2out(local_abstractconv_check, name="AbstractConvCheck"), in2out(local_abstractconv_check, name="AbstractConvCheck"),
48.7,
"fast_compile", "fast_compile",
"fast_run", "fast_run",
position=48.7,
) )
...@@ -55,9 +55,9 @@ def random_make_inplace(fgraph, node): ...@@ -55,9 +55,9 @@ def random_make_inplace(fgraph, node):
optdb.register( optdb.register(
"random_make_inplace", "random_make_inplace",
in2out(random_make_inplace, ignore_newtrees=True), in2out(random_make_inplace, ignore_newtrees=True),
99,
"fast_run", "fast_run",
"inplace", "inplace",
position=99,
) )
......
...@@ -89,7 +89,7 @@ def register_useless(lopt, *tags, **kwargs): ...@@ -89,7 +89,7 @@ def register_useless(lopt, *tags, **kwargs):
name = kwargs.pop("name", None) or lopt.__name__ name = kwargs.pop("name", None) or lopt.__name__
compile.mode.local_useless.register( compile.mode.local_useless.register(
name, lopt, "last", "fast_run", *tags, **kwargs name, lopt, "fast_run", *tags, position="last", **kwargs
) )
return lopt return lopt
...@@ -1230,9 +1230,9 @@ def local_IncSubtensor_serialize(fgraph, node): ...@@ -1230,9 +1230,9 @@ def local_IncSubtensor_serialize(fgraph, node):
compile.optdb.register( compile.optdb.register(
"pre_local_IncSubtensor_serialize", "pre_local_IncSubtensor_serialize",
in2out(local_IncSubtensor_serialize), in2out(local_IncSubtensor_serialize),
# Just before canonizer
0.99,
"fast_run", "fast_run",
# Just before canonizer
position=0.99,
) )
...@@ -1267,9 +1267,9 @@ compile.optdb.register( ...@@ -1267,9 +1267,9 @@ compile.optdb.register(
TopoOptimizer( TopoOptimizer(
local_inplace_setsubtensor, failure_callback=TopoOptimizer.warn_inplace local_inplace_setsubtensor, failure_callback=TopoOptimizer.warn_inplace
), ),
60,
"fast_run", "fast_run",
"inplace", "inplace",
position=60,
) )
...@@ -1288,9 +1288,9 @@ compile.optdb.register( ...@@ -1288,9 +1288,9 @@ compile.optdb.register(
TopoOptimizer( TopoOptimizer(
local_inplace_AdvancedIncSubtensor1, failure_callback=TopoOptimizer.warn_inplace local_inplace_AdvancedIncSubtensor1, failure_callback=TopoOptimizer.warn_inplace
), ),
60,
"fast_run", "fast_run",
"inplace", "inplace",
position=60,
) )
...@@ -1313,9 +1313,9 @@ compile.optdb.register( ...@@ -1313,9 +1313,9 @@ compile.optdb.register(
TopoOptimizer( TopoOptimizer(
local_inplace_AdvancedIncSubtensor, failure_callback=TopoOptimizer.warn_inplace local_inplace_AdvancedIncSubtensor, failure_callback=TopoOptimizer.warn_inplace
), ),
60,
"fast_run", "fast_run",
"inplace", "inplace",
position=60,
) )
......
...@@ -19,7 +19,7 @@ def typed_list_inplace_opt(fgraph, node): ...@@ -19,7 +19,7 @@ def typed_list_inplace_opt(fgraph, node):
optdb.register( optdb.register(
"typed_list_inplace_opt", "typed_list_inplace_opt",
TopoOptimizer(typed_list_inplace_opt, failure_callback=TopoOptimizer.warn_inplace), TopoOptimizer(typed_list_inplace_opt, failure_callback=TopoOptimizer.warn_inplace),
60,
"fast_run", "fast_run",
"inplace", "inplace",
position=60,
) )
...@@ -687,7 +687,7 @@ it to :obj:`optdb` as follows: ...@@ -687,7 +687,7 @@ it to :obj:`optdb` as follows:
.. testcode:: .. testcode::
# optdb.register(name, optimizer, order, *tags) # optdb.register(name, optimizer, order, *tags)
optdb.register('simplify', simplify, 0.5, 'fast_run') optdb.register('simplify', simplify, 'fast_run', position=0.5)
Once this is done, the ``FAST_RUN`` mode will automatically include your Once this is done, the ``FAST_RUN`` mode will automatically include your
optimization (since you gave it the ``'fast_run'`` tag). Of course, optimization (since you gave it the ``'fast_run'`` tag). Of course,
......
...@@ -56,7 +56,7 @@ class TestDB: ...@@ -56,7 +56,7 @@ class TestDB:
assert isinstance(res, opt.SeqOptimizer) assert isinstance(res, opt.SeqOptimizer)
assert res.data == [] assert res.data == []
seq_db.register("b", TestOpt(), 1) seq_db.register("b", TestOpt(), position=1)
from io import StringIO from io import StringIO
...@@ -69,7 +69,7 @@ class TestDB: ...@@ -69,7 +69,7 @@ class TestDB:
assert "names {'b'}" in res assert "names {'b'}" in res
with pytest.raises(TypeError, match=r"`position` must be.*"): with pytest.raises(TypeError, match=r"`position` must be.*"):
seq_db.register("c", TestOpt(), object()) seq_db.register("c", TestOpt(), position=object())
def test_LocalGroupDB(self): def test_LocalGroupDB(self):
lg_db = LocalGroupDB() lg_db = LocalGroupDB()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论