提交 f1ce6f5d authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Rename opt alias to aesara_opt in aesara.graph.optdb

上级 5e8398b5
...@@ -6,12 +6,12 @@ from io import StringIO ...@@ -6,12 +6,12 @@ from io import StringIO
from typing import Dict, Optional, Sequence, Union from typing import Dict, Optional, Sequence, Union
from aesara.configdefaults import config from aesara.configdefaults import config
from aesara.graph import opt from aesara.graph import opt as aesara_opt
from aesara.misc.ordered_set import OrderedSet from aesara.misc.ordered_set import OrderedSet
from aesara.utils import DefaultOrderedDict from aesara.utils import DefaultOrderedDict
OptimizersType = Union[opt.GlobalOptimizer, opt.LocalOptimizer] OptimizersType = Union[aesara_opt.GlobalOptimizer, aesara_opt.LocalOptimizer]
class OptimizationDatabase: class OptimizationDatabase:
...@@ -58,7 +58,12 @@ class OptimizationDatabase: ...@@ -58,7 +58,12 @@ class OptimizationDatabase:
""" """
if not isinstance( if not isinstance(
optimizer, (OptimizationDatabase, opt.GlobalOptimizer, opt.LocalOptimizer) optimizer,
(
OptimizationDatabase,
aesara_opt.GlobalOptimizer,
aesara_opt.LocalOptimizer,
),
): ):
raise TypeError(f"{optimizer} is not a valid optimizer type.") raise TypeError(f"{optimizer} is not a valid optimizer type.")
...@@ -357,12 +362,12 @@ class EquilibriumDB(OptimizationDatabase): ...@@ -357,12 +362,12 @@ class EquilibriumDB(OptimizationDatabase):
final_opts = None final_opts = None
if len(cleanup_opts) == 0: if len(cleanup_opts) == 0:
cleanup_opts = None cleanup_opts = None
return opt.EquilibriumOptimizer( return aesara_opt.EquilibriumOptimizer(
opts, opts,
max_use_ratio=config.optdb__max_use_ratio, max_use_ratio=config.optdb__max_use_ratio,
ignore_newtrees=self.ignore_newtrees, ignore_newtrees=self.ignore_newtrees,
tracks_on_change_inputs=self.tracks_on_change_inputs, tracks_on_change_inputs=self.tracks_on_change_inputs,
failure_callback=opt.NavigatorOptimizer.warn_inplace, failure_callback=aesara_opt.NavigatorOptimizer.warn_inplace,
final_optimizers=final_opts, final_optimizers=final_opts,
cleanup_optimizers=cleanup_opts, cleanup_optimizers=cleanup_opts,
) )
...@@ -382,9 +387,9 @@ class SequenceDB(OptimizationDatabase): ...@@ -382,9 +387,9 @@ class SequenceDB(OptimizationDatabase):
""" """
seq_opt = opt.SeqOptimizer seq_opt = aesara_opt.SeqOptimizer
def __init__(self, failure_callback=opt.SeqOptimizer.warn): def __init__(self, failure_callback=aesara_opt.SeqOptimizer.warn):
super().__init__() super().__init__()
self.__position__ = {} self.__position__ = {}
self.failure_callback = failure_callback self.failure_callback = failure_callback
...@@ -488,7 +493,7 @@ class LocalGroupDB(SequenceDB): ...@@ -488,7 +493,7 @@ class LocalGroupDB(SequenceDB):
self, self,
apply_all_opts: bool = False, apply_all_opts: bool = False,
profile: bool = False, profile: bool = False,
local_opt=opt.LocalOptGroup, local_opt=aesara_opt.LocalOptGroup,
): ):
super().__init__(failure_callback=None) super().__init__(failure_callback=None)
self.apply_all_opts = apply_all_opts self.apply_all_opts = apply_all_opts
...@@ -520,7 +525,7 @@ class TopoDB(OptimizationDatabase): ...@@ -520,7 +525,7 @@ class TopoDB(OptimizationDatabase):
self.failure_callback = failure_callback self.failure_callback = failure_callback
def query(self, *tags, **kwtags): def query(self, *tags, **kwtags):
return opt.TopoOptimizer( return aesara_opt.TopoOptimizer(
self.db.query(*tags, **kwtags), self.db.query(*tags, **kwtags),
self.order, self.order,
self.ignore_newtrees, self.ignore_newtrees,
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论