Commit 8706f75a authored by Brandon T. Willard, committed by Brandon T. Willard

Make keyword arguments in optimizers explicit

Parent d0c7808a
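
The commit applies one pattern throughout: constructors that previously accepted a catch-all `**kw`/`**kwargs`, consumed known keys with `kw.pop(...)`, and guarded the remainder with an `assert`, now declare their keyword-only arguments explicitly. A minimal sketch of the before/after pattern (the class names here are illustrative, not from the commit):

    # Before: an unknown keyword survives until the assert, which raises
    # a bare AssertionError and is skipped entirely under ``python -O``.
    class Before:
        def __init__(self, *opts, **kw):
            self.opts = opts
            self.failure_callback = kw.pop("failure_callback", None)
            assert len(kw) == 0

    # After: a misspelled keyword raises a descriptive TypeError at the
    # call site, e.g. ``__init__() got an unexpected keyword argument
    # 'failure_calback'``.
    class After:
        def __init__(self, *opts, failure_callback=None):
            self.opts = opts
            self.failure_callback = failure_callback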
@@ -179,7 +179,7 @@ class SeqOptimizer(GlobalOptimizer, UserList):
         elif config.on_opt_error == "pdb":
             pdb.post_mortem(sys.exc_info()[2])

-    def __init__(self, *opts, **kw):
+    def __init__(self, *opts, failure_callback=None):
         """
         Parameters
         ----------
@@ -195,8 +195,7 @@ class SeqOptimizer(GlobalOptimizer, UserList):
         super().__init__(opts)

-        self.failure_callback = kw.pop("failure_callback", None)
-        assert len(kw) == 0
+        self.failure_callback = failure_callback

     def apply(self, fgraph):
         """Applies each `GlobalOptimizer` in ``self.data`` to `fgraph`."""
@@ -1206,6 +1205,24 @@ def local_optimizer(
     inplace: bool = False,
     requirements: Optional[Tuple[type, ...]] = (),
 ):
+    r"""A decorator used to construct `FromFunctionLocalOptimizer` instances.
+
+    Parameters
+    ----------
+    tracks :
+        The `Op` types or instances to which this optimization applies.
+    inplace :
+        A boolean indicating whether or not the optimization works in-place.
+        If ``True``, a `DestroyHandler` `Feature` is automatically added to
+        the `FunctionGraph`\s applied to this optimization.
+    requirements :
+        `Feature` types required by this optimization.
+    """
+    if requirements is None:
+        requirements = ()
+
     def decorator(f):
         if tracks is not None:
             if len(tracks) == 0:
@@ -1257,7 +1274,7 @@ class LocalOptGroup(LocalOptimizer):
     to the outputs.

     """
-    def __init__(self, *optimizers, **kwargs):
+    def __init__(self, *optimizers, apply_all_opts=False, profile=False):
         if len(optimizers) == 1 and isinstance(optimizers[0], list):
             # This happens when created by LocalGroupDB.
             optimizers = tuple(optimizers[0])
@@ -1269,10 +1286,9 @@ class LocalOptGroup(LocalOptimizer):
             getattr(opt, "retains_inputs", False) for opt in optimizers
         )

-        self.apply_all_opts = kwargs.pop("apply_all_opts", False)
-        self.profile = kwargs.pop("profile", False)
+        self.apply_all_opts = apply_all_opts
+        self.profile = profile
         self.track_map = defaultdict(lambda: [])
-        assert len(kwargs) == 0

         if self.profile:
             self.time_opts = {}
             self.process_count = {}
...
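
With the explicit signatures, the supported options are visible directly in the constructors: `SeqOptimizer(*opts, failure_callback=None)` and `LocalOptGroup(*optimizers, apply_all_opts=False, profile=False)`. A hypothetical usage sketch (the optimizer instances and callback are placeholders, not from this commit):

    seq = SeqOptimizer(opt_a, opt_b, failure_callback=my_callback)  # accepted
    grp = LocalOptGroup(local_a, local_b, apply_all_opts=True)      # accepted
    SeqOptimizer(opt_a, failure_calback=my_callback)  # TypeError: typo caught at the call site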