提交 652fa754 authored 作者: Brandon T. Willard's avatar Brandon T. Willard

Make optimizer classes extend abc.ABC and UserList

上级 b6804245
......@@ -8,6 +8,9 @@ class TestDB:
class Opt(opt.GlobalOptimizer): # inheritance buys __hash__
name = "blah"
def apply(self, fgraph):
pass
db = DB()
db.register("a", Opt())
......
......@@ -146,6 +146,9 @@ class AddFeatureOptimizer(gof.GlobalOptimizer):
super().add_requirements(fgraph)
fgraph.attach_feature(self.feature)
def apply(self, fgraph):
pass
class PrintCurrentFunctionGraph(gof.GlobalOptimizer):
"""
......
......@@ -3,6 +3,7 @@ Defines the base class for optimizations as well as a certain
amount of useful generic optimization tools.
"""
import abc
import contextlib
import copy
import inspect
......@@ -12,7 +13,7 @@ import sys
import time
import traceback
import warnings
from collections import OrderedDict, defaultdict, deque
from collections import OrderedDict, UserList, defaultdict, deque
from collections.abc import Iterable
from functools import reduce
......@@ -43,7 +44,7 @@ class LocalMetaOptimizerSkipAssertionError(AssertionError):
"""
class GlobalOptimizer:
class GlobalOptimizer(abc.ABC):
"""
A L{GlobalOptimizer} can be applied to an L{FunctionGraph} to transform it.
......@@ -52,22 +53,7 @@ class GlobalOptimizer:
"""
def __hash__(self):
if not hasattr(self, "_optimizer_idx"):
self._optimizer_idx = _optimizer_idx[0]
_optimizer_idx[0] += 1
return self._optimizer_idx
def __eq__(self, other):
# added to override the __eq__ implementation that may be inherited
# in subclasses from other bases.
return id(self) == id(other)
def __ne__(self, other):
# added to override the __ne__ implementation that may be inherited
# in subclasses from other bases.
return id(self) != id(other)
@abc.abstractmethod
def apply(self, fgraph):
"""
......@@ -77,6 +63,7 @@ class GlobalOptimizer:
L{InstanceFinder}, it can do so in its L{add_requirements} method.
"""
raise NotImplementedError()
def optimize(self, fgraph, *args, **kwargs):
"""
......@@ -108,6 +95,7 @@ class GlobalOptimizer:
etc.
"""
pass
def print_summary(self, stream=sys.stdout, level=0, depth=-1):
name = getattr(self, "name", None)
......@@ -124,14 +112,23 @@ class GlobalOptimizer:
" optimizer return profiling information."
)
def __hash__(self):
if not hasattr(self, "_optimizer_idx"):
self._optimizer_idx = _optimizer_idx[0]
_optimizer_idx[0] += 1
return self._optimizer_idx
class FromFunctionOptimizer(GlobalOptimizer):
"""A `GlobalOptimizer` constructed from a given function."""
def __init__(self, fn, requirements=()):
self.apply = fn
self.fn = fn
self.requirements = requirements
def apply(self, *args, **kwargs):
return self.fn(*args, **kwargs)
def add_requirements(self, fgraph):
for req in self.requirements:
req(fgraph)
......@@ -168,7 +165,7 @@ def inplace_optimizer(f):
return rval
class SeqOptimizer(GlobalOptimizer, list):
class SeqOptimizer(GlobalOptimizer, UserList):
"""A `GlobalOptimizer` that applies a list of optimizers sequentially."""
@staticmethod
......@@ -198,7 +195,9 @@ class SeqOptimizer(GlobalOptimizer, list):
"""
if len(opts) == 1 and isinstance(opts[0], (list, tuple)):
opts = opts[0]
self[:] = opts
super().__init__(opts)
self.failure_callback = kw.pop("failure_callback", None)
assert len(kw) == 0
......@@ -234,7 +233,7 @@ class SeqOptimizer(GlobalOptimizer, list):
{},
)
try:
for optimizer in self:
for optimizer in self.data:
try:
nb_nodes_before = len(fgraph.apply_nodes)
t0 = time.time()
......@@ -283,11 +282,8 @@ class SeqOptimizer(GlobalOptimizer, list):
)
return self.pre_profile
def __str__(self):
return f"SeqOpt({list.__str__(self)})"
def __repr__(self):
return list.__repr__(self)
return f"SeqOpt({self.data})"
def print_summary(self, stream=sys.stdout, level=0, depth=-1):
name = getattr(self, "name", None)
......@@ -297,7 +293,7 @@ class SeqOptimizer(GlobalOptimizer, list):
# This way, -1 will do all depth
if depth != 0:
depth -= 1
for opt in self:
for opt in self.data:
opt.print_summary(stream, level=(level + 2), depth=depth)
@staticmethod
......@@ -1093,19 +1089,8 @@ def pre_constant_merge(vars):
return list(map(recursive_merge, vars))
########################
# Local Optimizers #
########################
class LocalOptimizer:
"""
A class for node-based optimizations.
Instances should implement the transform function,
and be passed to configure a fgraph-based Optimizer instance.
"""
class LocalOptimizer(abc.ABC):
"""A node-based optimizer."""
def __hash__(self):
if not hasattr(self, "_optimizer_idx"):
......@@ -1122,7 +1107,8 @@ class LocalOptimizer:
"""
return None
def transform(self, node):
@abc.abstractmethod
def transform(self, node, *args, **kwargs):
"""
Transform a subgraph whose output is `node`.
......@@ -1142,7 +1128,7 @@ class LocalOptimizer:
"""
raise utils.MethodNotDefined("transform", type(self), self.__class__.__name__)
raise NotImplementedError()
def add_requirements(self, fgraph):
"""
......@@ -1150,8 +1136,7 @@ class LocalOptimizer:
fgraph, this is the place to do it.
"""
# Added by default
# fgraph.attach_feature(toolbox.ReplaceValidate())
pass
def print_summary(self, stream=sys.stdout, level=0, depth=-1):
print(f"{' ' * level}{self.__class_.__name__} id={id(self)}", file=stream)
......@@ -1277,16 +1262,16 @@ class LocalMetaOptimizer(LocalOptimizer):
class FromFunctionLocalOptimizer(LocalOptimizer):
"""
WRITEME
"""
"""An optimizer constructed from a given function."""
def __init__(self, fn, tracks=None, requirements=()):
self.transform = fn
self.fn = fn
self._tracks = tracks
self.requirements = requirements
def transform(self, *args, **kwargs):
return self.fn(*args, **kwargs)
def add_requirements(self, fgraph):
for req in self.requirements:
req(fgraph)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论