提交 652fa754 authored 作者: Brandon T. Willard's avatar Brandon T. Willard

Make optimizer classes extend abc.ABC and UserList

上级 b6804245
...@@ -8,6 +8,9 @@ class TestDB: ...@@ -8,6 +8,9 @@ class TestDB:
class Opt(opt.GlobalOptimizer): # inheritance buys __hash__ class Opt(opt.GlobalOptimizer): # inheritance buys __hash__
name = "blah" name = "blah"
def apply(self, fgraph):
pass
db = DB() db = DB()
db.register("a", Opt()) db.register("a", Opt())
......
...@@ -146,6 +146,9 @@ class AddFeatureOptimizer(gof.GlobalOptimizer): ...@@ -146,6 +146,9 @@ class AddFeatureOptimizer(gof.GlobalOptimizer):
super().add_requirements(fgraph) super().add_requirements(fgraph)
fgraph.attach_feature(self.feature) fgraph.attach_feature(self.feature)
def apply(self, fgraph):
pass
class PrintCurrentFunctionGraph(gof.GlobalOptimizer): class PrintCurrentFunctionGraph(gof.GlobalOptimizer):
""" """
......
...@@ -3,6 +3,7 @@ Defines the base class for optimizations as well as a certain ...@@ -3,6 +3,7 @@ Defines the base class for optimizations as well as a certain
amount of useful generic optimization tools. amount of useful generic optimization tools.
""" """
import abc
import contextlib import contextlib
import copy import copy
import inspect import inspect
...@@ -12,7 +13,7 @@ import sys ...@@ -12,7 +13,7 @@ import sys
import time import time
import traceback import traceback
import warnings import warnings
from collections import OrderedDict, defaultdict, deque from collections import OrderedDict, UserList, defaultdict, deque
from collections.abc import Iterable from collections.abc import Iterable
from functools import reduce from functools import reduce
...@@ -43,7 +44,7 @@ class LocalMetaOptimizerSkipAssertionError(AssertionError): ...@@ -43,7 +44,7 @@ class LocalMetaOptimizerSkipAssertionError(AssertionError):
""" """
class GlobalOptimizer: class GlobalOptimizer(abc.ABC):
""" """
A L{GlobalOptimizer} can be applied to an L{FunctionGraph} to transform it. A L{GlobalOptimizer} can be applied to an L{FunctionGraph} to transform it.
...@@ -52,22 +53,7 @@ class GlobalOptimizer: ...@@ -52,22 +53,7 @@ class GlobalOptimizer:
""" """
def __hash__(self): @abc.abstractmethod
if not hasattr(self, "_optimizer_idx"):
self._optimizer_idx = _optimizer_idx[0]
_optimizer_idx[0] += 1
return self._optimizer_idx
def __eq__(self, other):
# added to override the __eq__ implementation that may be inherited
# in subclasses from other bases.
return id(self) == id(other)
def __ne__(self, other):
# added to override the __ne__ implementation that may be inherited
# in subclasses from other bases.
return id(self) != id(other)
def apply(self, fgraph): def apply(self, fgraph):
""" """
...@@ -77,6 +63,7 @@ class GlobalOptimizer: ...@@ -77,6 +63,7 @@ class GlobalOptimizer:
L{InstanceFinder}, it can do so in its L{add_requirements} method. L{InstanceFinder}, it can do so in its L{add_requirements} method.
""" """
raise NotImplementedError()
def optimize(self, fgraph, *args, **kwargs): def optimize(self, fgraph, *args, **kwargs):
""" """
...@@ -108,6 +95,7 @@ class GlobalOptimizer: ...@@ -108,6 +95,7 @@ class GlobalOptimizer:
etc. etc.
""" """
pass
def print_summary(self, stream=sys.stdout, level=0, depth=-1): def print_summary(self, stream=sys.stdout, level=0, depth=-1):
name = getattr(self, "name", None) name = getattr(self, "name", None)
...@@ -124,14 +112,23 @@ class GlobalOptimizer: ...@@ -124,14 +112,23 @@ class GlobalOptimizer:
" optimizer return profiling information." " optimizer return profiling information."
) )
def __hash__(self):
if not hasattr(self, "_optimizer_idx"):
self._optimizer_idx = _optimizer_idx[0]
_optimizer_idx[0] += 1
return self._optimizer_idx
class FromFunctionOptimizer(GlobalOptimizer): class FromFunctionOptimizer(GlobalOptimizer):
"""A `GlobalOptimizer` constructed from a given function.""" """A `GlobalOptimizer` constructed from a given function."""
def __init__(self, fn, requirements=()): def __init__(self, fn, requirements=()):
self.apply = fn self.fn = fn
self.requirements = requirements self.requirements = requirements
def apply(self, *args, **kwargs):
return self.fn(*args, **kwargs)
def add_requirements(self, fgraph): def add_requirements(self, fgraph):
for req in self.requirements: for req in self.requirements:
req(fgraph) req(fgraph)
...@@ -168,7 +165,7 @@ def inplace_optimizer(f): ...@@ -168,7 +165,7 @@ def inplace_optimizer(f):
return rval return rval
class SeqOptimizer(GlobalOptimizer, list): class SeqOptimizer(GlobalOptimizer, UserList):
"""A `GlobalOptimizer` that applies a list of optimizers sequentially.""" """A `GlobalOptimizer` that applies a list of optimizers sequentially."""
@staticmethod @staticmethod
...@@ -198,7 +195,9 @@ class SeqOptimizer(GlobalOptimizer, list): ...@@ -198,7 +195,9 @@ class SeqOptimizer(GlobalOptimizer, list):
""" """
if len(opts) == 1 and isinstance(opts[0], (list, tuple)): if len(opts) == 1 and isinstance(opts[0], (list, tuple)):
opts = opts[0] opts = opts[0]
self[:] = opts
super().__init__(opts)
self.failure_callback = kw.pop("failure_callback", None) self.failure_callback = kw.pop("failure_callback", None)
assert len(kw) == 0 assert len(kw) == 0
...@@ -234,7 +233,7 @@ class SeqOptimizer(GlobalOptimizer, list): ...@@ -234,7 +233,7 @@ class SeqOptimizer(GlobalOptimizer, list):
{}, {},
) )
try: try:
for optimizer in self: for optimizer in self.data:
try: try:
nb_nodes_before = len(fgraph.apply_nodes) nb_nodes_before = len(fgraph.apply_nodes)
t0 = time.time() t0 = time.time()
...@@ -283,11 +282,8 @@ class SeqOptimizer(GlobalOptimizer, list): ...@@ -283,11 +282,8 @@ class SeqOptimizer(GlobalOptimizer, list):
) )
return self.pre_profile return self.pre_profile
def __str__(self):
return f"SeqOpt({list.__str__(self)})"
def __repr__(self): def __repr__(self):
return list.__repr__(self) return f"SeqOpt({self.data})"
def print_summary(self, stream=sys.stdout, level=0, depth=-1): def print_summary(self, stream=sys.stdout, level=0, depth=-1):
name = getattr(self, "name", None) name = getattr(self, "name", None)
...@@ -297,7 +293,7 @@ class SeqOptimizer(GlobalOptimizer, list): ...@@ -297,7 +293,7 @@ class SeqOptimizer(GlobalOptimizer, list):
# This way, -1 will do all depth # This way, -1 will do all depth
if depth != 0: if depth != 0:
depth -= 1 depth -= 1
for opt in self: for opt in self.data:
opt.print_summary(stream, level=(level + 2), depth=depth) opt.print_summary(stream, level=(level + 2), depth=depth)
@staticmethod @staticmethod
...@@ -1093,19 +1089,8 @@ def pre_constant_merge(vars): ...@@ -1093,19 +1089,8 @@ def pre_constant_merge(vars):
return list(map(recursive_merge, vars)) return list(map(recursive_merge, vars))
######################## class LocalOptimizer(abc.ABC):
# Local Optimizers # """A node-based optimizer."""
########################
class LocalOptimizer:
"""
A class for node-based optimizations.
Instances should implement the transform function,
and be passed to configure a fgraph-based Optimizer instance.
"""
def __hash__(self): def __hash__(self):
if not hasattr(self, "_optimizer_idx"): if not hasattr(self, "_optimizer_idx"):
...@@ -1122,7 +1107,8 @@ class LocalOptimizer: ...@@ -1122,7 +1107,8 @@ class LocalOptimizer:
""" """
return None return None
def transform(self, node): @abc.abstractmethod
def transform(self, node, *args, **kwargs):
""" """
Transform a subgraph whose output is `node`. Transform a subgraph whose output is `node`.
...@@ -1142,7 +1128,7 @@ class LocalOptimizer: ...@@ -1142,7 +1128,7 @@ class LocalOptimizer:
""" """
raise utils.MethodNotDefined("transform", type(self), self.__class__.__name__) raise NotImplementedError()
def add_requirements(self, fgraph): def add_requirements(self, fgraph):
""" """
...@@ -1150,8 +1136,7 @@ class LocalOptimizer: ...@@ -1150,8 +1136,7 @@ class LocalOptimizer:
fgraph, this is the place to do it. fgraph, this is the place to do it.
""" """
# Added by default pass
# fgraph.attach_feature(toolbox.ReplaceValidate())
def print_summary(self, stream=sys.stdout, level=0, depth=-1): def print_summary(self, stream=sys.stdout, level=0, depth=-1):
print(f"{' ' * level}{self.__class_.__name__} id={id(self)}", file=stream) print(f"{' ' * level}{self.__class_.__name__} id={id(self)}", file=stream)
...@@ -1277,16 +1262,16 @@ class LocalMetaOptimizer(LocalOptimizer): ...@@ -1277,16 +1262,16 @@ class LocalMetaOptimizer(LocalOptimizer):
class FromFunctionLocalOptimizer(LocalOptimizer): class FromFunctionLocalOptimizer(LocalOptimizer):
""" """An optimizer constructed from a given function."""
WRITEME
"""
def __init__(self, fn, tracks=None, requirements=()): def __init__(self, fn, tracks=None, requirements=()):
self.transform = fn self.fn = fn
self._tracks = tracks self._tracks = tracks
self.requirements = requirements self.requirements = requirements
def transform(self, *args, **kwargs):
return self.fn(*args, **kwargs)
def add_requirements(self, fgraph): def add_requirements(self, fgraph):
for req in self.requirements: for req in self.requirements:
req(fgraph) req(fgraph)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论