提交 ba4868a3 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Create a Rewriter base class for GlobalOptimizer and LocalOptimizer

上级 3685fd8e
...@@ -41,7 +41,6 @@ from aesara.utils import flatten ...@@ -41,7 +41,6 @@ from aesara.utils import flatten
_logger = logging.getLogger("aesara.graph.opt") _logger = logging.getLogger("aesara.graph.opt")
_optimizer_idx = [0]
class LocalMetaOptimizerSkipAssertionError(AssertionError): class LocalMetaOptimizerSkipAssertionError(AssertionError):
...@@ -52,7 +51,25 @@ class LocalMetaOptimizerSkipAssertionError(AssertionError): ...@@ -52,7 +51,25 @@ class LocalMetaOptimizerSkipAssertionError(AssertionError):
""" """
class GlobalOptimizer(abc.ABC): class Rewriter(abc.ABC):
"""Abstract base class for graph/term rewriters."""
@abc.abstractmethod
def add_requirements(self, fgraph: FunctionGraph):
r"""Add `Feature`\s and other requirements to a `FunctionGraph`."""
@abc.abstractmethod
def print_summary(self, stream=sys.stdout, level=0, depth=-1):
"""Print a single-line, indented representation of the rewriter."""
def __eq__(self, other):
return self is other
def __hash__(self):
return id(self)
class GlobalOptimizer(Rewriter):
"""A optimizer that can be applied to a `FunctionGraph` in order to transform it. """A optimizer that can be applied to a `FunctionGraph` in order to transform it.
It can represent an optimization or, in general, any kind of transformation It can represent an optimization or, in general, any kind of transformation
...@@ -93,15 +110,7 @@ class GlobalOptimizer(abc.ABC): ...@@ -93,15 +110,7 @@ class GlobalOptimizer(abc.ABC):
return self.optimize(fgraph) return self.optimize(fgraph)
def add_requirements(self, fgraph): def add_requirements(self, fgraph):
"""Add features to `fgraph` that are required to apply the optimization. ...
For example::
fgraph.attach_feature(History())
fgraph.attach_feature(MyFeature())
# etc.
"""
def print_summary(self, stream=sys.stdout, level=0, depth=-1): def print_summary(self, stream=sys.stdout, level=0, depth=-1):
name = getattr(self, "name", None) name = getattr(self, "name", None)
...@@ -118,11 +127,48 @@ class GlobalOptimizer(abc.ABC): ...@@ -118,11 +127,48 @@ class GlobalOptimizer(abc.ABC):
" optimizer return profiling information." " optimizer return profiling information."
) )
def __hash__(self):
if not hasattr(self, "_optimizer_idx"): class LocalOptimizer(Rewriter):
self._optimizer_idx = _optimizer_idx[0] """A node-based optimizer."""
_optimizer_idx[0] += 1
return self._optimizer_idx def tracks(self):
"""Return the list of `Op` classes to which this optimization applies.
Returns ``None`` when the optimization applies to all nodes.
"""
return None
@abc.abstractmethod
def transform(
self, fgraph: FunctionGraph, node: Apply, *args, **kwargs
) -> Union[bool, List[Variable], Dict[Variable, Variable]]:
r"""Transform a subgraph whose output is `node`.
Subclasses should implement this function so that it returns one of the
following:
- ``False`` to indicate that no optimization can be applied to this `node`;
- A list of `Variable`\s to use in place of the `node`'s current outputs.
- A ``dict`` mapping old `Variable`\s to `Variable`\s.
Parameters
----------
fgraph :
A `FunctionGraph` containing `node`.
node :
An `Apply` node to be transformed.
"""
raise NotImplementedError()
def add_requirements(self, fgraph):
r"""Add required `Feature`\s to `fgraph`."""
def print_summary(self, stream=sys.stdout, level=0, depth=-1):
print(f"{' ' * level}{self.__class__.__name__} id={id(self)}", file=stream)
class FromFunctionOptimizer(GlobalOptimizer): class FromFunctionOptimizer(GlobalOptimizer):
...@@ -1016,55 +1062,6 @@ def pre_constant_merge(fgraph, variables): ...@@ -1016,55 +1062,6 @@ def pre_constant_merge(fgraph, variables):
return [recursive_merge(v) for v in variables] return [recursive_merge(v) for v in variables]
class LocalOptimizer(abc.ABC):
"""A node-based optimizer."""
def __hash__(self):
if not hasattr(self, "_optimizer_idx"):
self._optimizer_idx = _optimizer_idx[0]
_optimizer_idx[0] += 1
return self._optimizer_idx
def tracks(self):
"""Return the list of `Op` classes to which this optimization applies.
Returns ``None`` when the optimization applies to all nodes.
"""
return None
@abc.abstractmethod
def transform(
self, fgraph: FunctionGraph, node: Apply, *args, **kwargs
) -> Union[bool, List[Variable], Dict[Variable, Variable]]:
r"""Transform a subgraph whose output is `node`.
Subclasses should implement this function so that it returns one of the
following:
- ``False`` to indicate that no optimization can be applied to this `node`;
- A list of `Variable`\s to use in place of the `node`'s current outputs.
- A ``dict`` mapping old `Variable`\s to `Variable`\s.
Parameters
----------
fgraph :
A `FunctionGraph` containing `node`.
node :
An `Apply` node to be transformed.
"""
raise NotImplementedError()
def add_requirements(self, fgraph):
r"""Add required `Feature`\s to `fgraph`."""
def print_summary(self, stream=sys.stdout, level=0, depth=-1):
print(f"{' ' * level}{self.__class__.__name__} id={id(self)}", file=stream)
class LocalMetaOptimizer(LocalOptimizer): class LocalMetaOptimizer(LocalOptimizer):
r""" r"""
Base class for meta-optimizers that try a set of `LocalOptimizer`\s Base class for meta-optimizers that try a set of `LocalOptimizer`\s
......
...@@ -170,12 +170,6 @@ class OptimizationDatabase: ...@@ -170,12 +170,6 @@ class OptimizationDatabase:
print(" names", self._names, file=stream) print(" names", self._names, file=stream)
print(" db", self.__db__, file=stream) print(" db", self.__db__, file=stream)
def __hash__(self):
if not hasattr(self, "_optimizer_idx"):
self._optimizer_idx = opt._optimizer_idx[0]
opt._optimizer_idx[0] += 1
return self._optimizer_idx
# This is deprecated and will be removed. # This is deprecated and will be removed.
DB = OptimizationDatabase DB = OptimizationDatabase
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论