提交 8706f75a authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Make keyword arguments in optimizers explicit

上级 d0c7808a
......@@ -179,7 +179,7 @@ class SeqOptimizer(GlobalOptimizer, UserList):
elif config.on_opt_error == "pdb":
pdb.post_mortem(sys.exc_info()[2])
def __init__(self, *opts, **kw):
def __init__(self, *opts, failure_callback=None):
"""
Parameters
----------
......@@ -195,8 +195,7 @@ class SeqOptimizer(GlobalOptimizer, UserList):
super().__init__(opts)
self.failure_callback = kw.pop("failure_callback", None)
assert len(kw) == 0
self.failure_callback = failure_callback
def apply(self, fgraph):
"""Applies each `GlobalOptimizer` in ``self.data`` to `fgraph`."""
......@@ -1206,6 +1205,24 @@ def local_optimizer(
inplace: bool = False,
requirements: Optional[Tuple[type, ...]] = (),
):
r"""A decorator used to construct `FromFunctionLocalOptimizer` instances.
Parameters
----------
tracks :
The `Op` types or instances to which this optimization applies.
inplace :
A boolean indicating whether or not the optimization works in-place.
If ``True``, a `DestroyHandler` `Feature` is added automatically added
to the `FunctionGraph`\s applied to this optimization.
requirements :
`Feature` types required by this optimization.
"""
if requirements is None:
requirements = ()
def decorator(f):
if tracks is not None:
if len(tracks) == 0:
......@@ -1257,7 +1274,7 @@ class LocalOptGroup(LocalOptimizer):
to the outputs.
"""
def __init__(self, *optimizers, **kwargs):
def __init__(self, *optimizers, apply_all_opts=False, profile=False):
if len(optimizers) == 1 and isinstance(optimizers[0], list):
# This happen when created by LocalGroupDB.
optimizers = tuple(optimizers[0])
......@@ -1269,10 +1286,9 @@ class LocalOptGroup(LocalOptimizer):
getattr(opt, "retains_inputs", False) for opt in optimizers
)
self.apply_all_opts = kwargs.pop("apply_all_opts", False)
self.profile = kwargs.pop("profile", False)
self.apply_all_opts = apply_all_opts
self.profile = profile
self.track_map = defaultdict(lambda: [])
assert len(kwargs) == 0
if self.profile:
self.time_opts = {}
self.process_count = {}
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论