提交 f1ce6f5d authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Rename opt alias to aesara_opt in aesara.graph.optdb

上级 5e8398b5
......@@ -6,12 +6,12 @@ from io import StringIO
from typing import Dict, Optional, Sequence, Union
from aesara.configdefaults import config
from aesara.graph import opt
from aesara.graph import opt as aesara_opt
from aesara.misc.ordered_set import OrderedSet
from aesara.utils import DefaultOrderedDict
OptimizersType = Union[opt.GlobalOptimizer, opt.LocalOptimizer]
OptimizersType = Union[aesara_opt.GlobalOptimizer, aesara_opt.LocalOptimizer]
class OptimizationDatabase:
......@@ -58,7 +58,12 @@ class OptimizationDatabase:
"""
if not isinstance(
optimizer, (OptimizationDatabase, opt.GlobalOptimizer, opt.LocalOptimizer)
optimizer,
(
OptimizationDatabase,
aesara_opt.GlobalOptimizer,
aesara_opt.LocalOptimizer,
),
):
raise TypeError(f"{optimizer} is not a valid optimizer type.")
......@@ -357,12 +362,12 @@ class EquilibriumDB(OptimizationDatabase):
final_opts = None
if len(cleanup_opts) == 0:
cleanup_opts = None
return opt.EquilibriumOptimizer(
return aesara_opt.EquilibriumOptimizer(
opts,
max_use_ratio=config.optdb__max_use_ratio,
ignore_newtrees=self.ignore_newtrees,
tracks_on_change_inputs=self.tracks_on_change_inputs,
failure_callback=opt.NavigatorOptimizer.warn_inplace,
failure_callback=aesara_opt.NavigatorOptimizer.warn_inplace,
final_optimizers=final_opts,
cleanup_optimizers=cleanup_opts,
)
......@@ -382,9 +387,9 @@ class SequenceDB(OptimizationDatabase):
"""
seq_opt = opt.SeqOptimizer
seq_opt = aesara_opt.SeqOptimizer
def __init__(self, failure_callback=opt.SeqOptimizer.warn):
def __init__(self, failure_callback=aesara_opt.SeqOptimizer.warn):
super().__init__()
self.__position__ = {}
self.failure_callback = failure_callback
......@@ -488,7 +493,7 @@ class LocalGroupDB(SequenceDB):
self,
apply_all_opts: bool = False,
profile: bool = False,
local_opt=opt.LocalOptGroup,
local_opt=aesara_opt.LocalOptGroup,
):
super().__init__(failure_callback=None)
self.apply_all_opts = apply_all_opts
......@@ -520,7 +525,7 @@ class TopoDB(OptimizationDatabase):
self.failure_callback = failure_callback
def query(self, *tags, **kwtags):
return opt.TopoOptimizer(
return aesara_opt.TopoOptimizer(
self.db.query(*tags, **kwtags),
self.order,
self.ignore_newtrees,
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论