提交 550a6e98 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Rename LocalOptimizer to NodeRewriter

上级 214ef4cf
...@@ -24,7 +24,7 @@ from aesara.graph.basic import ( ...@@ -24,7 +24,7 @@ from aesara.graph.basic import (
from aesara.graph.fg import FunctionGraph from aesara.graph.fg import FunctionGraph
from aesara.graph.null_type import NullType from aesara.graph.null_type import NullType
from aesara.graph.op import HasInnerGraph, Op from aesara.graph.op import HasInnerGraph, Op
from aesara.graph.opt import in2out, local_optimizer from aesara.graph.opt import in2out, node_rewriter
from aesara.graph.utils import MissingInputError from aesara.graph.utils import MissingInputError
from aesara.tensor.basic_opt import ShapeFeature from aesara.tensor.basic_opt import ShapeFeature
...@@ -928,7 +928,7 @@ class OpFromGraph(Op, HasInnerGraph): ...@@ -928,7 +928,7 @@ class OpFromGraph(Op, HasInnerGraph):
output[0] = variable output[0] = variable
@local_optimizer([OpFromGraph]) @node_rewriter([OpFromGraph])
def inline_ofg_expansion(fgraph, node): def inline_ofg_expansion(fgraph, node):
""" """
This optimization expands internal graph of OpFromGraph. This optimization expands internal graph of OpFromGraph.
......
...@@ -13,7 +13,7 @@ from aesara.graph.basic import ( ...@@ -13,7 +13,7 @@ from aesara.graph.basic import (
from aesara.graph.op import Op from aesara.graph.op import Op
from aesara.graph.type import Type from aesara.graph.type import Type
from aesara.graph.fg import FunctionGraph from aesara.graph.fg import FunctionGraph
from aesara.graph.opt import local_optimizer, optimizer from aesara.graph.opt import node_rewriter, optimizer
from aesara.graph.opt_utils import optimize_graph from aesara.graph.opt_utils import optimize_graph
from aesara.graph.optdb import OptimizationQuery from aesara.graph.optdb import OptimizationQuery
......
...@@ -6,11 +6,11 @@ from unification import var ...@@ -6,11 +6,11 @@ from unification import var
from unification.variable import Var from unification.variable import Var
from aesara.graph.basic import Apply, Variable from aesara.graph.basic import Apply, Variable
from aesara.graph.opt import LocalOptimizer from aesara.graph.opt import NodeRewriter
from aesara.graph.unify import eval_if_etuple from aesara.graph.unify import eval_if_etuple
class KanrenRelationSub(LocalOptimizer): class KanrenRelationSub(NodeRewriter):
r"""A local optimizer that uses `kanren` to match and replace terms. r"""A local optimizer that uses `kanren` to match and replace terms.
See `kanren <https://github.com/pythological/kanren>`__ for more information See `kanren <https://github.com/pythological/kanren>`__ for more information
......
...@@ -48,7 +48,7 @@ FailureCallbackType = Callable[ ...@@ -48,7 +48,7 @@ FailureCallbackType = Callable[
Exception, Exception,
"NavigatorOptimizer", "NavigatorOptimizer",
List[Tuple[Variable, None]], List[Tuple[Variable, None]],
"LocalOptimizer", "NodeRewriter",
Apply, Apply,
], ],
None, None,
...@@ -142,13 +142,13 @@ class GraphRewriter(Rewriter): ...@@ -142,13 +142,13 @@ class GraphRewriter(Rewriter):
) )
class LocalOptimizer(Rewriter): class NodeRewriter(Rewriter):
"""A node-based optimizer.""" """A `Rewriter` that is applied to an `Apply` node."""
def tracks(self): def tracks(self) -> Optional[Sequence[Op]]:
"""Return the list of `Op` classes to which this optimization applies. """Return the list of `Op` classes to which this rewrite applies.
Returns ``None`` when the optimization applies to all nodes. Returns ``None`` when the rewrite applies to all nodes.
""" """
return None return None
...@@ -162,23 +162,22 @@ class LocalOptimizer(Rewriter): ...@@ -162,23 +162,22 @@ class LocalOptimizer(Rewriter):
Subclasses should implement this function so that it returns one of the Subclasses should implement this function so that it returns one of the
following: following:
- ``False`` to indicate that no optimization can be applied to this `node`; - ``False`` to indicate that this rewrite cannot be applied to `node`
- A list of `Variable`\s to use in place of the `node`'s current outputs. - A list of `Variable`\s to use in place of the `node`'s current outputs
- A ``dict`` mapping old `Variable`\s to `Variable`\s. - A ``dict`` mapping old `Variable`\s to new `Variable`\s
Parameters Parameters
---------- ----------
fgraph : fgraph
A `FunctionGraph` containing `node`. A `FunctionGraph` containing `node`.
node : node
An `Apply` node to be transformed. An `Apply` node to be rewritten.
""" """
raise NotImplementedError() raise NotImplementedError()
def add_requirements(self, fgraph): def add_requirements(self, fgraph: FunctionGraph):
r"""Add required `Feature`\s to `fgraph`.""" r"""Add required `Feature`\s to `fgraph`."""
def print_summary(self, stream=sys.stdout, level=0, depth=-1): def print_summary(self, stream=sys.stdout, level=0, depth=-1):
...@@ -939,9 +938,9 @@ def pre_constant_merge(fgraph, variables): ...@@ -939,9 +938,9 @@ def pre_constant_merge(fgraph, variables):
return [recursive_merge(v) for v in variables] return [recursive_merge(v) for v in variables]
class LocalMetaOptimizer(LocalOptimizer): class LocalMetaOptimizer(NodeRewriter):
r""" r"""
Base class for meta-optimizers that try a set of `LocalOptimizer`\s Base class for meta-optimizers that try a set of `NodeRewriter`\s
to replace a node and choose the one that executes the fastest. to replace a node and choose the one that executes the fastest.
If the error `MetaNodeRewriterSkip` is raised during If the error `MetaNodeRewriterSkip` is raised during
...@@ -1058,8 +1057,8 @@ class LocalMetaOptimizer(LocalOptimizer): ...@@ -1058,8 +1057,8 @@ class LocalMetaOptimizer(LocalOptimizer):
return time.time() - start return time.time() - start
class FromFunctionLocalOptimizer(LocalOptimizer): class FromFunctionLocalOptimizer(NodeRewriter):
"""A `LocalOptimizer` constructed from a function.""" """A `NodeRewriter` constructed from a function."""
def __init__(self, fn, tracks=None, requirements=()): def __init__(self, fn, tracks=None, requirements=()):
self.fn = fn self.fn = fn
...@@ -1095,7 +1094,7 @@ class FromFunctionLocalOptimizer(LocalOptimizer): ...@@ -1095,7 +1094,7 @@ class FromFunctionLocalOptimizer(LocalOptimizer):
print(f"{' ' * level}{self.transform} id={id(self)}", file=stream) print(f"{' ' * level}{self.transform} id={id(self)}", file=stream)
def local_optimizer( def node_rewriter(
tracks: Optional[Sequence[Union[Op, type]]], tracks: Optional[Sequence[Union[Op, type]]],
inplace: bool = False, inplace: bool = False,
requirements: Optional[Tuple[type, ...]] = (), requirements: Optional[Tuple[type, ...]] = (),
...@@ -1150,12 +1149,12 @@ class LocalOptTracker: ...@@ -1150,12 +1149,12 @@ class LocalOptTracker:
r"""A container that maps rewrites to `Op` instances and `Op`-type inheritance.""" r"""A container that maps rewrites to `Op` instances and `Op`-type inheritance."""
def __init__(self): def __init__(self):
self.tracked_instances: Dict[Op, List[LocalOptimizer]] = {} self.tracked_instances: Dict[Op, List[NodeRewriter]] = {}
self.tracked_types: Dict[type, List[LocalOptimizer]] = {} self.tracked_types: Dict[type, List[NodeRewriter]] = {}
self.untracked_opts: List[LocalOptimizer] = [] self.untracked_opts: List[NodeRewriter] = []
def add_tracker(self, rw: LocalOptimizer): def add_tracker(self, rw: NodeRewriter):
"""Add a `LocalOptimizer` to be keyed by its `LocalOptimizer.tracks` or applied generally.""" """Add a `NodeRewriter` to be keyed by its `NodeRewriter.tracks` or applied generally."""
tracks = rw.tracks() tracks = rw.tracks()
if tracks is None: if tracks is None:
...@@ -1167,8 +1166,8 @@ class LocalOptTracker: ...@@ -1167,8 +1166,8 @@ class LocalOptTracker:
else: else:
self.tracked_instances.setdefault(c, []).append(rw) self.tracked_instances.setdefault(c, []).append(rw)
def _find_impl(self, cls) -> List[LocalOptimizer]: def _find_impl(self, cls) -> List[NodeRewriter]:
r"""Returns the `LocalOptimizer`\s that apply to `cls` based on inheritance. r"""Returns the `NodeRewriter`\s that apply to `cls` based on inheritance.
This is based on `functools._find_impl`. This is based on `functools._find_impl`.
""" """
...@@ -1181,7 +1180,7 @@ class LocalOptTracker: ...@@ -1181,7 +1180,7 @@ class LocalOptTracker:
return matches return matches
@functools.lru_cache() @functools.lru_cache()
def get_trackers(self, op: Op) -> List[LocalOptimizer]: def get_trackers(self, op: Op) -> List[NodeRewriter]:
"""Get all the rewrites applicable to `op`.""" """Get all the rewrites applicable to `op`."""
return ( return (
self._find_impl(type(op)) self._find_impl(type(op))
...@@ -1198,8 +1197,8 @@ class LocalOptTracker: ...@@ -1198,8 +1197,8 @@ class LocalOptTracker:
) )
class LocalOptGroup(LocalOptimizer): class LocalOptGroup(NodeRewriter):
r"""An optimizer that applies a list of `LocalOptimizer`\s to a node. r"""An optimizer that applies a list of `NodeRewriter`\s to a node.
Attributes Attributes
---------- ----------
...@@ -1390,7 +1389,7 @@ class LocalOptGroup(LocalOptimizer): ...@@ -1390,7 +1389,7 @@ class LocalOptGroup(LocalOptimizer):
opt.add_requirements(fgraph) opt.add_requirements(fgraph)
class OpSub(LocalOptimizer): class OpSub(NodeRewriter):
""" """
Replaces the application of a certain `Op` by the application of Replaces the application of a certain `Op` by the application of
...@@ -1440,7 +1439,7 @@ class OpSub(LocalOptimizer): ...@@ -1440,7 +1439,7 @@ class OpSub(LocalOptimizer):
return f"{self.op1} -> {self.op2}" return f"{self.op1} -> {self.op2}"
class OpRemove(LocalOptimizer): class OpRemove(NodeRewriter):
""" """
Removes all applications of an `Op` by transferring each of its Removes all applications of an `Op` by transferring each of its
outputs to the corresponding input. outputs to the corresponding input.
...@@ -1473,7 +1472,7 @@ class OpRemove(LocalOptimizer): ...@@ -1473,7 +1472,7 @@ class OpRemove(LocalOptimizer):
) )
class PatternSub(LocalOptimizer): class PatternSub(NodeRewriter):
"""Replace all occurrences of an input pattern with an output pattern. """Replace all occurrences of an input pattern with an output pattern.
The input and output patterns have the following syntax: The input and output patterns have the following syntax:
...@@ -1719,44 +1718,20 @@ class Updater(Feature): ...@@ -1719,44 +1718,20 @@ class Updater(Feature):
class NavigatorOptimizer(GraphRewriter): class NavigatorOptimizer(GraphRewriter):
r"""An optimizer that applies a `LocalOptimizer` with considerations for the new nodes it creates. r"""An optimizer that applies a `NodeRewriter` with considerations for the new nodes it creates.
This optimizer also allows the `LocalOptimizer` to use a special ``"remove"`` value This optimizer also allows the `NodeRewriter` to use a special ``"remove"`` value
in the ``dict``\s returned by :meth:`LocalOptimizer`. `Variable`\s mapped to this in the ``dict``\s returned by :meth:`NodeRewriter`. `Variable`\s mapped to this
value are removed from the `FunctionGraph`. value are removed from the `FunctionGraph`.
Parameters
----------
local_opt :
A `LocalOptimizer` to apply over a `FunctionGraph` (or ``None``).
ignore_newtrees :
- ``True``: new subgraphs returned by an optimization are not a
candidate for optimization.
- ``False``: new subgraphs returned by an optimization is a candidate
for optimization.
- ``'auto'``: let the `local_opt` set this parameter via its :attr:`reentrant`
attribute.
failure_callback
A function with the signature ``(exception, navigator, [(old, new),
(old,new),...])`` that is called when there's an exception.
If the exception is raised in ``local_opt.transform``, the ``new`` variables
will be ``None``.
If the exception is raised during validation (e.g. the new types don't
match) then the new variables will be the ones created by ``self.transform``.
If this parameter is ``None``, then exceptions are not caught here and
are raised normally.
""" """
@staticmethod @staticmethod
def warn(exc, nav, repl_pairs, local_opt, node): def warn(exc, nav, repl_pairs, node_rewriter, node):
"""A failure callback that prints a traceback.""" """A failure callback that prints a traceback."""
if config.on_opt_error != "ignore": if config.on_opt_error != "ignore":
_logger.error(f"Optimization failure due to: {local_opt}") _logger.error(f"Optimization failure due to: {node_rewriter}")
_logger.error(f"node: {node}") _logger.error(f"node: {node}")
_logger.error("TRACEBACK:") _logger.error("TRACEBACK:")
_logger.error(traceback.format_exc()) _logger.error(traceback.format_exc())
...@@ -1768,30 +1743,59 @@ class NavigatorOptimizer(GraphRewriter): ...@@ -1768,30 +1743,59 @@ class NavigatorOptimizer(GraphRewriter):
raise exc raise exc
@staticmethod @staticmethod
def warn_inplace(exc, nav, repl_pairs, local_opt, node): def warn_inplace(exc, nav, repl_pairs, node_rewriter, node):
r"""A failure callback that ignores ``InconsistencyError``\s and prints a traceback. r"""A failure callback that ignores `InconsistencyError`\s and prints a traceback.
If the error occurred during replacement, ``repl_pairs`` is set; If the error occurred during replacement, `repl_pairs` is set;
otherwise, its value is ``None``. otherwise, its value is ``None``.
""" """
if isinstance(exc, InconsistencyError): if isinstance(exc, InconsistencyError):
return return
return NavigatorOptimizer.warn(exc, nav, repl_pairs, local_opt, node) return NavigatorOptimizer.warn(exc, nav, repl_pairs, node_rewriter, node)
@staticmethod @staticmethod
def warn_ignore(exc, nav, repl_pairs, local_opt, node): def warn_ignore(exc, nav, repl_pairs, node_rewriter, node):
"""A failure callback that ignores all errors.""" """A failure callback that ignores all errors."""
def __init__( def __init__(
self, self,
local_opt: LocalOptimizer, node_rewriter: Optional[NodeRewriter],
ignore_newtrees: Literal[True, False, "auto"], ignore_newtrees: Literal[True, False, "auto"],
failure_callback: Optional[FailureCallbackType] = None, failure_callback: Optional[FailureCallbackType] = None,
): ):
self.local_opt = local_opt """
Parameters
----------
node_rewriter
A `NodeRewriter` to apply over a `FunctionGraph` (or ``None``).
ignore_newtrees
- ``True``: new subgraphs returned by an optimization are not a
candidate for optimization.
- ``False``: new subgraphs returned by an optimization is a
candidate for optimization.
- ``'auto'``: let the `node_rewriter` set this parameter via its
:attr:`reentrant` attribute.
failure_callback
A function with the signature
``(exception, navigator, [(old, new), (old,new),...])``
that is called when there's an exception.
If the exception is raised in `node_rewriter.transform`, the
``new`` variables will be ``None``.
If the exception is raised during validation (e.g. the new types
don't match) then the new variables will be the ones created by
``self.transform``.
If this parameter is ``None``, then exceptions are not caught here
and are raised normally.
"""
self.node_rewriter = node_rewriter
if ignore_newtrees == "auto": if ignore_newtrees == "auto":
self.ignore_newtrees = not getattr(local_opt, "reentrant", True) self.ignore_newtrees = not getattr(node_rewriter, "reentrant", True)
else: else:
self.ignore_newtrees = ignore_newtrees self.ignore_newtrees = ignore_newtrees
self.failure_callback = failure_callback self.failure_callback = failure_callback
...@@ -1865,7 +1869,7 @@ class NavigatorOptimizer(GraphRewriter): ...@@ -1865,7 +1869,7 @@ class NavigatorOptimizer(GraphRewriter):
node : node :
An `Apply` instance in `fgraph` An `Apply` instance in `fgraph`
lopt : lopt :
A `LocalOptimizer` instance that may have a better idea for A `NodeRewriter` instance that may have a better idea for
how to compute node's outputs. how to compute node's outputs.
Returns Returns
...@@ -1874,7 +1878,7 @@ class NavigatorOptimizer(GraphRewriter): ...@@ -1874,7 +1878,7 @@ class NavigatorOptimizer(GraphRewriter):
``True`` iff the `node`'s outputs were replaced in the `fgraph`. ``True`` iff the `node`'s outputs were replaced in the `fgraph`.
""" """
lopt = lopt or self.local_opt lopt = lopt or self.node_rewriter
try: try:
replacements = lopt.transform(fgraph, node) replacements = lopt.transform(fgraph, node)
except Exception as e: except Exception as e:
...@@ -1896,19 +1900,17 @@ class NavigatorOptimizer(GraphRewriter): ...@@ -1896,19 +1900,17 @@ class NavigatorOptimizer(GraphRewriter):
replacements = list(replacements.values()) replacements = list(replacements.values())
elif not isinstance(replacements, (tuple, list)): elif not isinstance(replacements, (tuple, list)):
raise TypeError( raise TypeError(
f"Local optimizer {lopt} gave wrong type of replacement. " f"Node rewriter {lopt} gave wrong type of replacement. "
f"Expected list or tuple; got {replacements}" f"Expected list or tuple; got {replacements}"
) )
if len(old_vars) != len(replacements): if len(old_vars) != len(replacements):
raise ValueError( raise ValueError(f"Node rewriter {lopt} gave wrong number of replacements")
f"Local optimizer {lopt} gave wrong number of replacements"
)
# None in the replacement means that this variable isn't used # None in the replacement means that this variable isn't used
# and we want to remove it # and we want to remove it
for r, rnew in zip(old_vars, replacements): for r, rnew in zip(old_vars, replacements):
if rnew is None and len(fgraph.clients[r]) > 0: if rnew is None and len(fgraph.clients[r]) > 0:
raise ValueError( raise ValueError(
f"Local optimizer {lopt} tried to remove a variable" f"Node rewriter {lopt} tried to remove a variable"
f" that is being used: {r}" f" that is being used: {r}"
) )
# If an output would be replaced by itself, no need to perform # If an output would be replaced by itself, no need to perform
...@@ -1939,21 +1941,23 @@ class NavigatorOptimizer(GraphRewriter): ...@@ -1939,21 +1941,23 @@ class NavigatorOptimizer(GraphRewriter):
super().add_requirements(fgraph) super().add_requirements(fgraph)
# Added by default # Added by default
# fgraph.attach_feature(ReplaceValidate()) # fgraph.attach_feature(ReplaceValidate())
if self.local_opt: if self.node_rewriter:
self.local_opt.add_requirements(fgraph) self.node_rewriter.add_requirements(fgraph)
def print_summary(self, stream=sys.stdout, level=0, depth=-1): def print_summary(self, stream=sys.stdout, level=0, depth=-1):
print(f"{' ' * level}{self.__class__.__name__} id={id(self)}", file=stream) print(f"{' ' * level}{self.__class__.__name__} id={id(self)}", file=stream)
if depth != 0: if depth != 0:
self.local_opt.print_summary(stream, level=(level + 2), depth=(depth - 1)) self.node_rewriter.print_summary(
stream, level=(level + 2), depth=(depth - 1)
)
class TopoOptimizer(NavigatorOptimizer): class TopoOptimizer(NavigatorOptimizer):
"""An optimizer that applies a single `LocalOptimizer` to each node in topological order (or reverse).""" """An optimizer that applies a single `NodeRewriter` to each node in topological order (or reverse)."""
def __init__( def __init__(
self, self,
local_opt: LocalOptimizer, node_rewriter: NodeRewriter,
order: Literal["out_to_in", "in_to_out"] = "in_to_out", order: Literal["out_to_in", "in_to_out"] = "in_to_out",
ignore_newtrees: bool = False, ignore_newtrees: bool = False,
failure_callback: Optional[FailureCallbackType] = None, failure_callback: Optional[FailureCallbackType] = None,
...@@ -1961,7 +1965,7 @@ class TopoOptimizer(NavigatorOptimizer): ...@@ -1961,7 +1965,7 @@ class TopoOptimizer(NavigatorOptimizer):
if order not in ("out_to_in", "in_to_out"): if order not in ("out_to_in", "in_to_out"):
raise ValueError("order must be 'out_to_in' or 'in_to_out'") raise ValueError("order must be 'out_to_in' or 'in_to_out'")
self.order = order self.order = order
super().__init__(local_opt, ignore_newtrees, failure_callback) super().__init__(node_rewriter, ignore_newtrees, failure_callback)
def apply(self, fgraph, start_from=None): def apply(self, fgraph, start_from=None):
if start_from is None: if start_from is None:
...@@ -2005,7 +2009,7 @@ class TopoOptimizer(NavigatorOptimizer): ...@@ -2005,7 +2009,7 @@ class TopoOptimizer(NavigatorOptimizer):
io_t, io_t,
loop_t, loop_t,
callback_time, callback_time,
self.local_opt, self.node_rewriter,
) )
@staticmethod @staticmethod
...@@ -2061,22 +2065,26 @@ class TopoOptimizer(NavigatorOptimizer): ...@@ -2061,22 +2065,26 @@ class TopoOptimizer(NavigatorOptimizer):
def topogroup_optimizer( def topogroup_optimizer(
order, *local_opts, name=None, failure_callback=TopoOptimizer.warn_inplace, **kwargs order,
*node_rewriters,
name=None,
failure_callback=TopoOptimizer.warn_inplace,
**kwargs,
): ):
"""Apply `local_opts` from the input/output nodes to the output/input nodes of a graph. """Apply `node_rewriters` from the input/output nodes to the output/input nodes of a graph.
This constructs `TopoOptimizer`s, and uses a `LocalOptGroup` when there's This constructs `TopoOptimizer`s, and uses a `LocalOptGroup` when there's
more than one entry in `local_opts`. more than one entry in `node_rewriters`.
""" """
if len(local_opts) > 1: if len(node_rewriters) > 1:
# Don't wrap it uselessly if there is only 1 optimization. # Don't wrap it uselessly if there is only 1 optimization.
local_opts = LocalOptGroup(*local_opts) node_rewriters = LocalOptGroup(*node_rewriters)
else: else:
(local_opts,) = local_opts (node_rewriters,) = node_rewriters
if not name: if not name:
name = local_opts.__name__ name = node_rewriters.__name__
ret = TopoOptimizer( ret = TopoOptimizer(
local_opts, node_rewriters,
order=order, order=order,
failure_callback=failure_callback, failure_callback=failure_callback,
**kwargs, **kwargs,
...@@ -2091,9 +2099,9 @@ out2in = partial(topogroup_optimizer, "out_to_in") ...@@ -2091,9 +2099,9 @@ out2in = partial(topogroup_optimizer, "out_to_in")
class OpKeyOptimizer(NavigatorOptimizer): class OpKeyOptimizer(NavigatorOptimizer):
r"""An optimizer that applies a `LocalOptimizer` to specific `Op`\s. r"""An optimizer that applies a `NodeRewriter` to specific `Op`\s.
The `Op`\s are provided by a :meth:`LocalOptimizer.op_key` method (either The `Op`\s are provided by a :meth:`NodeRewriter.op_key` method (either
as a list of `Op`\s or a single `Op`), and discovered within a as a list of `Op`\s or a single `Op`), and discovered within a
`FunctionGraph` using the `NodeFinder` `Feature`. `FunctionGraph` using the `NodeFinder` `Feature`.
...@@ -2101,13 +2109,13 @@ class OpKeyOptimizer(NavigatorOptimizer): ...@@ -2101,13 +2109,13 @@ class OpKeyOptimizer(NavigatorOptimizer):
""" """
def __init__(self, local_opt, ignore_newtrees=False, failure_callback=None): def __init__(self, node_rewriter, ignore_newtrees=False, failure_callback=None):
if not hasattr(local_opt, "op_key"): if not hasattr(node_rewriter, "op_key"):
raise TypeError(f"{local_opt} must have an `op_key` method.") raise TypeError(f"{node_rewriter} must have an `op_key` method.")
super().__init__(local_opt, ignore_newtrees, failure_callback) super().__init__(node_rewriter, ignore_newtrees, failure_callback)
def apply(self, fgraph): def apply(self, fgraph):
op = self.local_opt.op_key() op = self.node_rewriter.op_key()
if isinstance(op, (list, tuple)): if isinstance(op, (list, tuple)):
q = reduce(list.__iadd__, map(fgraph.get_nodes, op)) q = reduce(list.__iadd__, map(fgraph.get_nodes, op))
else: else:
...@@ -2175,68 +2183,86 @@ def merge_dict(d1, d2): ...@@ -2175,68 +2183,86 @@ def merge_dict(d1, d2):
class EquilibriumOptimizer(NavigatorOptimizer): class EquilibriumOptimizer(NavigatorOptimizer):
"""An optimizer that applies an optimization until a fixed-point/equilibrium is reached. """A `Rewriter` that applies an optimization until a fixed-point/equilibrium is reached."""
def __init__(
self,
optimizers: Sequence[Rewriter],
failure_callback: Optional[FailureCallbackType] = None,
ignore_newtrees: bool = True,
tracks_on_change_inputs: bool = False,
max_use_ratio: Optional[float] = None,
final_optimizers: Optional[Sequence[GraphRewriter]] = None,
cleanup_optimizers: Optional[Sequence[GraphRewriter]] = None,
):
"""
Parameters Parameters
---------- ----------
optimizers : list or set optimizers
Local or global optimizations to apply until equilibrium. Node or graph rewriters to apply until equilibrium.
The global optimizer will be run at the start of each iteration before The global optimizer will be run at the start of each iteration before
the local optimizer. the node rewriter.
max_use_ratio : int or float failure_callback
Each optimizer can be applied at most ``(size of graph * this number)`` See :attr:`NavigatorOptimizer.failure_callback`.
ignore_newtrees
See :attr:`NavigatorOptimizer.ignore_newtrees`.
tracks_on_change_inputs
See :attr:`NavigatorOptimizer.tracks_on_change_inputs`.
max_use_ratio
Each rewriter can be applied at most ``(size_of_graph * max_use_ratio)``
times. times.
ignore_newtrees : final_optimizers
See :attr:`EquilibriumDB.ignore_newtrees`. Rewriters that will be run after each iteration.
final_optimizers : cleanup_optimizers
Global optimizers that will be run after each iteration. Rewriters applied after all graph rewriters, then when one
cleanup_optimizers : `NodeRewriter` is applied, then after all final rewriters.
Global optimizers that apply a list of pre determined optimization. They should not traverse the entire graph, since they are called
They must not traverse the graph as they are called very frequently. very frequently. The `MergeOptimizer` is one example of a rewriter
The MergeOptimizer is one example of optimization that respects this. that respects this.
They are applied after all global optimizers, then when one local
optimizer is applied, then after all final optimizers.
""" """
def __init__(
self,
optimizers,
failure_callback=None,
ignore_newtrees=True,
tracks_on_change_inputs=False,
max_use_ratio=None,
final_optimizers=None,
cleanup_optimizers=None,
):
super().__init__( super().__init__(
None, ignore_newtrees=ignore_newtrees, failure_callback=failure_callback None, ignore_newtrees=ignore_newtrees, failure_callback=failure_callback
) )
self.global_optimizers = [] self.global_optimizers: List[GraphRewriter] = []
self.final_optimizers = []
self.cleanup_optimizers = []
self.tracks_on_change_inputs = tracks_on_change_inputs self.tracks_on_change_inputs = tracks_on_change_inputs
self.local_tracker = LocalOptTracker() self.local_tracker = LocalOptTracker()
for opt in optimizers: for opt in optimizers:
if isinstance(opt, LocalOptimizer): if isinstance(opt, NodeRewriter):
self.local_tracker.add_tracker(opt) self.local_tracker.add_tracker(opt)
else: else:
assert isinstance(opt, GraphRewriter)
self.global_optimizers.append(opt) self.global_optimizers.append(opt)
if final_optimizers: if final_optimizers:
self.final_optimizers = final_optimizers self.final_optimizers = list(final_optimizers)
else:
self.final_optimizers = []
if cleanup_optimizers: if cleanup_optimizers:
self.cleanup_optimizers = cleanup_optimizers self.cleanup_optimizers = list(cleanup_optimizers)
else:
self.cleanup_optimizers = []
self.max_use_ratio = max_use_ratio self.max_use_ratio = max_use_ratio
def get_local_optimizers(self): def get_node_rewriters(self):
yield from self.local_tracker.get_rewriters() yield from self.local_tracker.get_rewriters()
def get_local_optimizers(self):
warnings.warn(
"`get_local_optimizers` is deprecated; use `get_node_rewriters` instead.",
DeprecationWarning,
stacklevel=2,
)
yield from self.get_node_rewriters()
def add_requirements(self, fgraph): def add_requirements(self, fgraph):
super().add_requirements(fgraph) super().add_requirements(fgraph)
for opt in self.get_local_optimizers(): for opt in self.get_node_rewriters():
opt.add_requirements(fgraph) opt.add_requirements(fgraph)
for opt in self.global_optimizers: for opt in self.global_optimizers:
opt.add_requirements(fgraph) opt.add_requirements(fgraph)
...@@ -2274,7 +2300,7 @@ class EquilibriumOptimizer(NavigatorOptimizer): ...@@ -2274,7 +2300,7 @@ class EquilibriumOptimizer(NavigatorOptimizer):
cleanup_sub_profs = [] cleanup_sub_profs = []
for opt in ( for opt in (
self.global_optimizers self.global_optimizers
+ list(self.get_local_optimizers()) + list(self.get_node_rewriters())
+ self.final_optimizers + self.final_optimizers
+ self.cleanup_optimizers + self.cleanup_optimizers
): ):
...@@ -2468,7 +2494,7 @@ class EquilibriumOptimizer(NavigatorOptimizer): ...@@ -2468,7 +2494,7 @@ class EquilibriumOptimizer(NavigatorOptimizer):
f"{' ' * level}{self.__class__.__name__} {name} id={id(self)}", file=stream f"{' ' * level}{self.__class__.__name__} {name} id={id(self)}", file=stream
) )
if depth != 0: if depth != 0:
for lopt in self.get_local_optimizers(): for lopt in self.get_node_rewriters():
lopt.print_summary(stream, level=(level + 2), depth=(depth - 1)) lopt.print_summary(stream, level=(level + 2), depth=(depth - 1))
@staticmethod @staticmethod
...@@ -2502,7 +2528,7 @@ class EquilibriumOptimizer(NavigatorOptimizer): ...@@ -2502,7 +2528,7 @@ class EquilibriumOptimizer(NavigatorOptimizer):
file=stream, file=stream,
) )
print(blanc, f" time io_toposort {sum(io_toposort_timing):.3f}s", file=stream) print(blanc, f" time io_toposort {sum(io_toposort_timing):.3f}s", file=stream)
s = sum(time_opts[o] for o in opt.get_local_optimizers()) s = sum(time_opts[o] for o in opt.get_node_rewriters())
print(blanc, f" time in local optimizers {s:.3f}s", file=stream) print(blanc, f" time in local optimizers {s:.3f}s", file=stream)
s = sum(time_opts[o] for o in opt.global_optimizers) s = sum(time_opts[o] for o in opt.global_optimizers)
print(blanc, f" time in global optimizers {s:.3f}s", file=stream) print(blanc, f" time in global optimizers {s:.3f}s", file=stream)
...@@ -2534,7 +2560,7 @@ class EquilibriumOptimizer(NavigatorOptimizer): ...@@ -2534,7 +2560,7 @@ class EquilibriumOptimizer(NavigatorOptimizer):
process_count = {} process_count = {}
for o in ( for o in (
opt.global_optimizers opt.global_optimizers
+ list(opt.get_local_optimizers()) + list(opt.get_node_rewriters())
+ list(opt.final_optimizers) + list(opt.final_optimizers)
+ list(opt.cleanup_optimizers) + list(opt.cleanup_optimizers)
): ):
...@@ -2605,8 +2631,8 @@ class EquilibriumOptimizer(NavigatorOptimizer): ...@@ -2605,8 +2631,8 @@ class EquilibriumOptimizer(NavigatorOptimizer):
def merge_profile(prof1, prof2): def merge_profile(prof1, prof2):
# (opt, loop_timing, loop_process_count, max_nb_nodes, # (opt, loop_timing, loop_process_count, max_nb_nodes,
# global_opt_timing, nb_nodes, time_opts, io_toposort_timing) = prof1 # global_opt_timing, nb_nodes, time_opts, io_toposort_timing) = prof1
local_optimizers = OrderedSet(prof1[0].get_local_optimizers()).union( node_rewriters = OrderedSet(prof1[0].get_node_rewriters()).union(
prof2[0].get_local_optimizers() prof2[0].get_node_rewriters()
) )
global_optimizers = OrderedSet(prof1[0].global_optimizers).union( global_optimizers = OrderedSet(prof1[0].global_optimizers).union(
prof2[0].global_optimizers prof2[0].global_optimizers
...@@ -2618,7 +2644,7 @@ class EquilibriumOptimizer(NavigatorOptimizer): ...@@ -2618,7 +2644,7 @@ class EquilibriumOptimizer(NavigatorOptimizer):
OrderedSet(prof1[0].cleanup_optimizers).union(prof2[0].cleanup_optimizers) OrderedSet(prof1[0].cleanup_optimizers).union(prof2[0].cleanup_optimizers)
) )
new_opt = EquilibriumOptimizer( new_opt = EquilibriumOptimizer(
local_optimizers.union(global_optimizers), node_rewriters.union(global_optimizers),
max_use_ratio=1, max_use_ratio=1,
final_optimizers=final_optimizers, final_optimizers=final_optimizers,
cleanup_optimizers=cleanup_optimizers, cleanup_optimizers=cleanup_optimizers,
...@@ -2758,7 +2784,7 @@ def check_chain(r, *chain): ...@@ -2758,7 +2784,7 @@ def check_chain(r, *chain):
return _check_chain(r, reduce(list.__iadd__, ([x, 0] for x in chain))) return _check_chain(r, reduce(list.__iadd__, ([x, 0] for x in chain)))
def pre_greedy_local_optimizer(fgraph, optimizations, out): def pre_greedy_node_rewriter(fgraph, optimizations, out):
"""Apply local optimizations to a graph. """Apply local optimizations to a graph.
This function traverses the computation graph in the graph before the This function traverses the computation graph in the graph before the
...@@ -2786,7 +2812,7 @@ def pre_greedy_local_optimizer(fgraph, optimizations, out): ...@@ -2786,7 +2812,7 @@ def pre_greedy_local_optimizer(fgraph, optimizations, out):
---------- ----------
fgraph : FunctionGraph fgraph : FunctionGraph
The graph used to avoid/filter nodes. The graph used to avoid/filter nodes.
optimizations : list of LocalOptimizer optimizations : list of NodeRewriter
The list of local optimizations to apply The list of local optimizations to apply
out : Variable out : Variable
A `Variable` specifying the graph to optimize. A `Variable` specifying the graph to optimize.
...@@ -3065,6 +3091,21 @@ DEPRECATED_NAMES = [ ...@@ -3065,6 +3091,21 @@ DEPRECATED_NAMES = [
"`GlobalOptimizer` is deprecated: use `GraphRewriter` instead.", "`GlobalOptimizer` is deprecated: use `GraphRewriter` instead.",
GraphRewriter, GraphRewriter,
), ),
(
"LocalOptimizer",
"`LocalOptimizer` is deprecated: use `NodeRewriter` instead.",
NodeRewriter,
),
(
"local_optimizer",
"`local_optimizer` is deprecated: use `node_rewriter` instead.",
node_rewriter,
),
(
"pre_greedy_local_optimizer",
"`pre_greedy_local_optimizer` is deprecated: use `pre_greedy_node_rewriter` instead.",
pre_greedy_node_rewriter,
),
] ]
......
...@@ -11,14 +11,14 @@ from aesara.misc.ordered_set import OrderedSet ...@@ -11,14 +11,14 @@ from aesara.misc.ordered_set import OrderedSet
from aesara.utils import DefaultOrderedDict from aesara.utils import DefaultOrderedDict
OptimizersType = Union[aesara_opt.GraphRewriter, aesara_opt.LocalOptimizer] OptimizersType = Union[aesara_opt.GraphRewriter, aesara_opt.NodeRewriter]
class OptimizationDatabase: class OptimizationDatabase:
r"""A class that represents a collection/database of optimizations. r"""A class that represents a collection/database of optimizations.
These databases are used to logically organize collections of optimizers These databases are used to logically organize collections of optimizers
(i.e. `GraphRewriter`\s and `LocalOptimizer`). (i.e. `GraphRewriter`\s and `NodeRewriter`).
""" """
def __init__(self): def __init__(self):
...@@ -62,7 +62,7 @@ class OptimizationDatabase: ...@@ -62,7 +62,7 @@ class OptimizationDatabase:
( (
OptimizationDatabase, OptimizationDatabase,
aesara_opt.GraphRewriter, aesara_opt.GraphRewriter,
aesara_opt.LocalOptimizer, aesara_opt.NodeRewriter,
), ),
): ):
raise TypeError(f"{optimizer} is not a valid optimizer type.") raise TypeError(f"{optimizer} is not a valid optimizer type.")
...@@ -311,7 +311,7 @@ class EquilibriumDB(OptimizationDatabase): ...@@ -311,7 +311,7 @@ class EquilibriumDB(OptimizationDatabase):
Notes Notes
----- -----
We can use `LocalOptimizer` and `GraphRewriter` since `EquilibriumOptimizer` We can use `NodeRewriter` and `GraphRewriter` since `EquilibriumOptimizer`
supports both. supports both.
It is probably not a good idea to have ignore_newtrees=False and It is probably not a good idea to have ignore_newtrees=False and
...@@ -474,24 +474,18 @@ class SequenceDB(OptimizationDatabase): ...@@ -474,24 +474,18 @@ class SequenceDB(OptimizationDatabase):
class LocalGroupDB(SequenceDB): class LocalGroupDB(SequenceDB):
""" r"""A database that generates `NodeRewriter`\s of type `LocalOptGroup`."""
Generate a local optimizer of type LocalOptGroup instead
of a global optimizer.
It supports the tracks, to only get applied to some Op.
"""
def __init__( def __init__(
self, self,
apply_all_opts: bool = False, apply_all_opts: bool = False,
profile: bool = False, profile: bool = False,
local_opt=aesara_opt.LocalOptGroup, node_rewriter=aesara_opt.LocalOptGroup,
): ):
super().__init__(failure_callback=None) super().__init__(failure_callback=None)
self.apply_all_opts = apply_all_opts self.apply_all_opts = apply_all_opts
self.profile = profile self.profile = profile
self.local_opt = local_opt self.node_rewriter = node_rewriter
self.__name__: str = "" self.__name__: str = ""
def register(self, name, obj, *tags, position="last", **kwargs): def register(self, name, obj, *tags, position="last", **kwargs):
...@@ -499,7 +493,7 @@ class LocalGroupDB(SequenceDB): ...@@ -499,7 +493,7 @@ class LocalGroupDB(SequenceDB):
def query(self, *tags, **kwtags): def query(self, *tags, **kwtags):
opts = list(super().query(*tags, **kwtags)) opts = list(super().query(*tags, **kwtags))
ret = self.local_opt( ret = self.node_rewriter(
*opts, apply_all_opts=self.apply_all_opts, profile=self.profile *opts, apply_all_opts=self.apply_all_opts, profile=self.profile
) )
return ret return ret
......
...@@ -22,7 +22,7 @@ from aesara.compile import optdb ...@@ -22,7 +22,7 @@ from aesara.compile import optdb
from aesara.configdefaults import config from aesara.configdefaults import config
from aesara.graph.basic import Apply, Variable, clone_replace, is_in_ancestors from aesara.graph.basic import Apply, Variable, clone_replace, is_in_ancestors
from aesara.graph.op import _NoPythonOp from aesara.graph.op import _NoPythonOp
from aesara.graph.opt import GraphRewriter, in2out, local_optimizer from aesara.graph.opt import GraphRewriter, in2out, node_rewriter
from aesara.graph.type import HasDataType, HasShape from aesara.graph.type import HasDataType, HasShape
from aesara.tensor.shape import Reshape, Shape, SpecifyShape, Unbroadcast from aesara.tensor.shape import Reshape, Shape, SpecifyShape, Unbroadcast
...@@ -404,7 +404,7 @@ def ifelse( ...@@ -404,7 +404,7 @@ def ifelse(
return tuple(rval) return tuple(rval)
@local_optimizer([IfElse]) @node_rewriter([IfElse])
def cond_make_inplace(fgraph, node): def cond_make_inplace(fgraph, node):
op = node.op op = node.op
if ( if (
...@@ -482,7 +482,7 @@ acceptable_ops = ( ...@@ -482,7 +482,7 @@ acceptable_ops = (
) )
@local_optimizer(acceptable_ops) @node_rewriter(acceptable_ops)
def ifelse_lift_single_if_through_acceptable_ops(fgraph, main_node): def ifelse_lift_single_if_through_acceptable_ops(fgraph, main_node):
"""This optimization lifts up certain ifelse instances. """This optimization lifts up certain ifelse instances.
...@@ -529,7 +529,7 @@ def ifelse_lift_single_if_through_acceptable_ops(fgraph, main_node): ...@@ -529,7 +529,7 @@ def ifelse_lift_single_if_through_acceptable_ops(fgraph, main_node):
return nw_outs return nw_outs
@local_optimizer([IfElse]) @node_rewriter([IfElse])
def cond_merge_ifs_true(fgraph, node): def cond_merge_ifs_true(fgraph, node):
op = node.op op = node.op
if not isinstance(op, IfElse): if not isinstance(op, IfElse):
...@@ -556,7 +556,7 @@ def cond_merge_ifs_true(fgraph, node): ...@@ -556,7 +556,7 @@ def cond_merge_ifs_true(fgraph, node):
return op(*old_ins, return_list=True) return op(*old_ins, return_list=True)
@local_optimizer([IfElse]) @node_rewriter([IfElse])
def cond_merge_ifs_false(fgraph, node): def cond_merge_ifs_false(fgraph, node):
op = node.op op = node.op
if not isinstance(op, IfElse): if not isinstance(op, IfElse):
...@@ -635,7 +635,7 @@ class CondMerge(GraphRewriter): ...@@ -635,7 +635,7 @@ class CondMerge(GraphRewriter):
fgraph.replace_all_validate(pairs, reason="cond_merge") fgraph.replace_all_validate(pairs, reason="cond_merge")
@local_optimizer([IfElse]) @node_rewriter([IfElse])
def cond_remove_identical(fgraph, node): def cond_remove_identical(fgraph, node):
op = node.op op = node.op
...@@ -681,7 +681,7 @@ def cond_remove_identical(fgraph, node): ...@@ -681,7 +681,7 @@ def cond_remove_identical(fgraph, node):
return rval return rval
@local_optimizer([IfElse]) @node_rewriter([IfElse])
def cond_merge_random_op(fgraph, main_node): def cond_merge_random_op(fgraph, main_node):
if isinstance(main_node.op, IfElse): if isinstance(main_node.op, IfElse):
return False return False
......
import logging import logging
from aesara.graph.opt import local_optimizer from aesara.graph.opt import node_rewriter
from aesara.tensor import basic as at from aesara.tensor import basic as at
from aesara.tensor.basic_opt import ( from aesara.tensor.basic_opt import (
register_canonicalize, register_canonicalize,
...@@ -20,7 +20,7 @@ logger = logging.getLogger(__name__) ...@@ -20,7 +20,7 @@ logger = logging.getLogger(__name__)
@register_canonicalize @register_canonicalize
@local_optimizer([DimShuffle]) @node_rewriter([DimShuffle])
def transinv_to_invtrans(fgraph, node): def transinv_to_invtrans(fgraph, node):
if isinstance(node.op, DimShuffle): if isinstance(node.op, DimShuffle):
if node.op.new_order == (1, 0): if node.op.new_order == (1, 0):
...@@ -32,7 +32,7 @@ def transinv_to_invtrans(fgraph, node): ...@@ -32,7 +32,7 @@ def transinv_to_invtrans(fgraph, node):
@register_stabilize @register_stabilize
@local_optimizer([Dot, Dot22]) @node_rewriter([Dot, Dot22])
def inv_as_solve(fgraph, node): def inv_as_solve(fgraph, node):
""" """
This utilizes a boolean `symmetric` tag on the matrices. This utilizes a boolean `symmetric` tag on the matrices.
...@@ -51,7 +51,7 @@ def inv_as_solve(fgraph, node): ...@@ -51,7 +51,7 @@ def inv_as_solve(fgraph, node):
@register_stabilize @register_stabilize
@register_canonicalize @register_canonicalize
@local_optimizer([Solve]) @node_rewriter([Solve])
def tag_solve_triangular(fgraph, node): def tag_solve_triangular(fgraph, node):
""" """
If a general solve() is applied to the output of a cholesky op, then If a general solve() is applied to the output of a cholesky op, then
...@@ -82,7 +82,7 @@ def tag_solve_triangular(fgraph, node): ...@@ -82,7 +82,7 @@ def tag_solve_triangular(fgraph, node):
@register_canonicalize @register_canonicalize
@register_stabilize @register_stabilize
@register_specialize @register_specialize
@local_optimizer([DimShuffle]) @node_rewriter([DimShuffle])
def no_transpose_symmetric(fgraph, node): def no_transpose_symmetric(fgraph, node):
if isinstance(node.op, DimShuffle): if isinstance(node.op, DimShuffle):
x = node.inputs[0] x = node.inputs[0]
...@@ -92,7 +92,7 @@ def no_transpose_symmetric(fgraph, node): ...@@ -92,7 +92,7 @@ def no_transpose_symmetric(fgraph, node):
@register_stabilize @register_stabilize
@local_optimizer([Solve]) @node_rewriter([Solve])
def psd_solve_with_chol(fgraph, node): def psd_solve_with_chol(fgraph, node):
""" """
This utilizes a boolean `psd` tag on matrices. This utilizes a boolean `psd` tag on matrices.
...@@ -111,7 +111,7 @@ def psd_solve_with_chol(fgraph, node): ...@@ -111,7 +111,7 @@ def psd_solve_with_chol(fgraph, node):
@register_stabilize @register_stabilize
@register_specialize @register_specialize
@local_optimizer([Det]) @node_rewriter([Det])
def local_det_chol(fgraph, node): def local_det_chol(fgraph, node):
""" """
If we have det(X) and there is already an L=cholesky(X) If we have det(X) and there is already an L=cholesky(X)
...@@ -129,7 +129,7 @@ def local_det_chol(fgraph, node): ...@@ -129,7 +129,7 @@ def local_det_chol(fgraph, node):
@register_canonicalize @register_canonicalize
@register_stabilize @register_stabilize
@register_specialize @register_specialize
@local_optimizer([log]) @node_rewriter([log])
def local_log_prod_sqr(fgraph, node): def local_log_prod_sqr(fgraph, node):
""" """
This utilizes a boolean `positive` tag on matrices. This utilizes a boolean `positive` tag on matrices.
......
...@@ -25,7 +25,7 @@ from aesara.compile import optdb ...@@ -25,7 +25,7 @@ from aesara.compile import optdb
from aesara.configdefaults import config from aesara.configdefaults import config
from aesara.gradient import undefined_grad from aesara.gradient import undefined_grad
from aesara.graph.basic import Apply, Constant, Variable from aesara.graph.basic import Apply, Constant, Variable
from aesara.graph.opt import in2out, local_optimizer from aesara.graph.opt import in2out, node_rewriter
from aesara.link.c.op import COp, Op from aesara.link.c.op import COp, Op
from aesara.link.c.params_type import ParamsType from aesara.link.c.params_type import ParamsType
from aesara.sandbox import multinomial from aesara.sandbox import multinomial
...@@ -1343,7 +1343,7 @@ def _check_size(size): ...@@ -1343,7 +1343,7 @@ def _check_size(size):
return at.as_tensor_variable(size, ndim=1) return at.as_tensor_variable(size, ndim=1)
@local_optimizer((mrg_uniform_base,)) @node_rewriter((mrg_uniform_base,))
def mrg_random_make_inplace(fgraph, node): def mrg_random_make_inplace(fgraph, node):
op = node.op op = node.op
......
...@@ -28,7 +28,7 @@ from aesara.graph.destroyhandler import DestroyHandler ...@@ -28,7 +28,7 @@ from aesara.graph.destroyhandler import DestroyHandler
from aesara.graph.features import ReplaceValidate from aesara.graph.features import ReplaceValidate
from aesara.graph.fg import FunctionGraph from aesara.graph.fg import FunctionGraph
from aesara.graph.op import compute_test_value from aesara.graph.op import compute_test_value
from aesara.graph.opt import GraphRewriter, in2out, local_optimizer from aesara.graph.opt import GraphRewriter, in2out, node_rewriter
from aesara.graph.optdb import EquilibriumDB, SequenceDB from aesara.graph.optdb import EquilibriumDB, SequenceDB
from aesara.graph.type import HasShape from aesara.graph.type import HasShape
from aesara.graph.utils import InconsistencyError from aesara.graph.utils import InconsistencyError
...@@ -67,7 +67,7 @@ list_opt_slice = [ ...@@ -67,7 +67,7 @@ list_opt_slice = [
] ]
@local_optimizer([Scan]) @node_rewriter([Scan])
def remove_constants_and_unused_inputs_scan(fgraph, node): def remove_constants_and_unused_inputs_scan(fgraph, node):
"""Move constants into the inner graph, and remove unused inputs. """Move constants into the inner graph, and remove unused inputs.
...@@ -192,7 +192,7 @@ def remove_constants_and_unused_inputs_scan(fgraph, node): ...@@ -192,7 +192,7 @@ def remove_constants_and_unused_inputs_scan(fgraph, node):
return False return False
@local_optimizer([Scan]) @node_rewriter([Scan])
def push_out_non_seq_scan(fgraph, node): def push_out_non_seq_scan(fgraph, node):
r"""Push out the variables inside the `Scan` that depend only on non-sequences. r"""Push out the variables inside the `Scan` that depend only on non-sequences.
...@@ -400,7 +400,7 @@ def push_out_non_seq_scan(fgraph, node): ...@@ -400,7 +400,7 @@ def push_out_non_seq_scan(fgraph, node):
return False return False
@local_optimizer([Scan]) @node_rewriter([Scan])
def push_out_seq_scan(fgraph, node): def push_out_seq_scan(fgraph, node):
r"""Push out the variables inside the `Scan` that depend only on constants and sequences. r"""Push out the variables inside the `Scan` that depend only on constants and sequences.
...@@ -812,7 +812,7 @@ def add_nitsot_outputs( ...@@ -812,7 +812,7 @@ def add_nitsot_outputs(
return new_scan_node, {} return new_scan_node, {}
@local_optimizer([Scan]) @node_rewriter([Scan])
def push_out_add_scan(fgraph, node): def push_out_add_scan(fgraph, node):
r"""Push `Add` operations performed at the end of the inner graph to the outside. r"""Push `Add` operations performed at the end of the inner graph to the outside.
...@@ -1113,7 +1113,7 @@ def sanitize(x): ...@@ -1113,7 +1113,7 @@ def sanitize(x):
return at.as_tensor_variable(x) return at.as_tensor_variable(x)
@local_optimizer([Scan]) @node_rewriter([Scan])
def save_mem_new_scan(fgraph, node): def save_mem_new_scan(fgraph, node):
r"""Graph optimizer that reduces scan memory consumption. r"""Graph optimizer that reduces scan memory consumption.
...@@ -1950,7 +1950,7 @@ def make_equiv(lo, li): ...@@ -1950,7 +1950,7 @@ def make_equiv(lo, li):
return left, right return left, right
@local_optimizer([Scan]) @node_rewriter([Scan])
def scan_merge_inouts(fgraph, node): def scan_merge_inouts(fgraph, node):
""" """
This optimization attempts to merge a `Scan` `Op`'s identical outer inputs as well This optimization attempts to merge a `Scan` `Op`'s identical outer inputs as well
...@@ -2154,7 +2154,7 @@ def scan_merge_inouts(fgraph, node): ...@@ -2154,7 +2154,7 @@ def scan_merge_inouts(fgraph, node):
return na.outer_outputs return na.outer_outputs
@local_optimizer([Scan]) @node_rewriter([Scan])
def push_out_dot1_scan(fgraph, node): def push_out_dot1_scan(fgraph, node):
r""" r"""
This is another optimization that attempts to detect certain patterns of This is another optimization that attempts to detect certain patterns of
......
...@@ -4,7 +4,7 @@ import aesara ...@@ -4,7 +4,7 @@ import aesara
import aesara.scalar as aes import aesara.scalar as aes
from aesara.configdefaults import config from aesara.configdefaults import config
from aesara.graph.basic import Apply from aesara.graph.basic import Apply
from aesara.graph.opt import PatternSub, TopoOptimizer, local_optimizer from aesara.graph.opt import PatternSub, TopoOptimizer, node_rewriter
from aesara.link.c.op import COp, _NoPythonCOp from aesara.link.c.op import COp, _NoPythonCOp
from aesara.misc.safe_asarray import _asarray from aesara.misc.safe_asarray import _asarray
from aesara.sparse import basic as sparse from aesara.sparse import basic as sparse
...@@ -32,7 +32,7 @@ _is_dense = sparse._is_dense ...@@ -32,7 +32,7 @@ _is_dense = sparse._is_dense
# This is tested in tests/test_opt.py:test_local_csm_properties_csm # This is tested in tests/test_opt.py:test_local_csm_properties_csm
@local_optimizer([csm_properties]) @node_rewriter([csm_properties])
def local_csm_properties_csm(fgraph, node): def local_csm_properties_csm(fgraph, node):
""" """
If we find csm_properties(CSM(*args)), then we can replace that with the If we find csm_properties(CSM(*args)), then we can replace that with the
...@@ -51,7 +51,7 @@ register_specialize(local_csm_properties_csm) ...@@ -51,7 +51,7 @@ register_specialize(local_csm_properties_csm)
# This is tested in tests/test_basic.py:test_remove0 # This is tested in tests/test_basic.py:test_remove0
@local_optimizer([sparse.Remove0]) @node_rewriter([sparse.Remove0])
def local_inplace_remove0(fgraph, node): def local_inplace_remove0(fgraph, node):
""" """
Optimization to insert inplace versions of Remove0. Optimization to insert inplace versions of Remove0.
...@@ -188,7 +188,7 @@ class AddSD_ccode(_NoPythonCOp): ...@@ -188,7 +188,7 @@ class AddSD_ccode(_NoPythonCOp):
return (2,) return (2,)
@local_optimizer([sparse.AddSD]) @node_rewriter([sparse.AddSD])
def local_inplace_addsd_ccode(fgraph, node): def local_inplace_addsd_ccode(fgraph, node):
""" """
Optimization to insert inplace versions of AddSD. Optimization to insert inplace versions of AddSD.
...@@ -218,7 +218,7 @@ aesara.compile.optdb.register( ...@@ -218,7 +218,7 @@ aesara.compile.optdb.register(
@register_canonicalize("fast_compile") @register_canonicalize("fast_compile")
@register_specialize @register_specialize
@local_optimizer([sparse.DenseFromSparse]) @node_rewriter([sparse.DenseFromSparse])
def local_dense_from_sparse_sparse_from_dense(fgraph, node): def local_dense_from_sparse_sparse_from_dense(fgraph, node):
if isinstance(node.op, sparse.DenseFromSparse): if isinstance(node.op, sparse.DenseFromSparse):
inp = node.inputs[0] inp = node.inputs[0]
...@@ -226,7 +226,7 @@ def local_dense_from_sparse_sparse_from_dense(fgraph, node): ...@@ -226,7 +226,7 @@ def local_dense_from_sparse_sparse_from_dense(fgraph, node):
return inp.owner.inputs return inp.owner.inputs
@local_optimizer([sparse.AddSD]) @node_rewriter([sparse.AddSD])
def local_addsd_ccode(fgraph, node): def local_addsd_ccode(fgraph, node):
""" """
Convert AddSD to faster AddSD_ccode. Convert AddSD to faster AddSD_ccode.
...@@ -638,7 +638,7 @@ sd_csr = StructuredDotCSR() ...@@ -638,7 +638,7 @@ sd_csr = StructuredDotCSR()
# register a specialization to replace StructuredDot -> StructuredDotCSx # register a specialization to replace StructuredDot -> StructuredDotCSx
# This is tested in tests/test_basic.py:792 # This is tested in tests/test_basic.py:792
@local_optimizer([sparse._structured_dot]) @node_rewriter([sparse._structured_dot])
def local_structured_dot(fgraph, node): def local_structured_dot(fgraph, node):
if node.op == sparse._structured_dot: if node.op == sparse._structured_dot:
a, b = node.inputs a, b = node.inputs
...@@ -950,7 +950,7 @@ register_specialize(local_usmm, name="local_usmm") ...@@ -950,7 +950,7 @@ register_specialize(local_usmm, name="local_usmm")
# register a specialization to replace usmm_csc_dense -> usmm_csc_dense_inplace # register a specialization to replace usmm_csc_dense -> usmm_csc_dense_inplace
# This is tested in tests/test_basic.py:UsmmTests # This is tested in tests/test_basic.py:UsmmTests
@local_optimizer([usmm_csc_dense]) @node_rewriter([usmm_csc_dense])
def local_usmm_csc_dense_inplace(fgraph, node): def local_usmm_csc_dense_inplace(fgraph, node):
if node.op == usmm_csc_dense: if node.op == usmm_csc_dense:
return [usmm_csc_dense_inplace(*node.inputs)] return [usmm_csc_dense_inplace(*node.inputs)]
...@@ -960,7 +960,7 @@ register_specialize(local_usmm_csc_dense_inplace, "cxx_only", "inplace") ...@@ -960,7 +960,7 @@ register_specialize(local_usmm_csc_dense_inplace, "cxx_only", "inplace")
# This is tested in tests/test_basic.py:UsmmTests # This is tested in tests/test_basic.py:UsmmTests
@local_optimizer([usmm]) @node_rewriter([usmm])
def local_usmm_csx(fgraph, node): def local_usmm_csx(fgraph, node):
""" """
usmm -> usmm_csc_dense usmm -> usmm_csc_dense
...@@ -1120,7 +1120,7 @@ csm_grad_c = CSMGradC() ...@@ -1120,7 +1120,7 @@ csm_grad_c = CSMGradC()
# register a specialization to replace csm_grad -> csm_grad_c # register a specialization to replace csm_grad -> csm_grad_c
# This is tested in tests/test_opt.py:test_local_csm_grad_c # This is tested in tests/test_opt.py:test_local_csm_grad_c
@local_optimizer([csm_grad(None)]) @node_rewriter([csm_grad(None)])
def local_csm_grad_c(fgraph, node): def local_csm_grad_c(fgraph, node):
""" """
csm_grad(None) -> csm_grad_c csm_grad(None) -> csm_grad_c
...@@ -1404,7 +1404,7 @@ mul_s_d_csr = MulSDCSR() ...@@ -1404,7 +1404,7 @@ mul_s_d_csr = MulSDCSR()
# register a specialization to replace MulSD -> MulSDCSX # register a specialization to replace MulSD -> MulSDCSX
@local_optimizer([sparse.mul_s_d]) @node_rewriter([sparse.mul_s_d])
def local_mul_s_d(fgraph, node): def local_mul_s_d(fgraph, node):
if node.op == sparse.mul_s_d: if node.op == sparse.mul_s_d:
x, y = node.inputs x, y = node.inputs
...@@ -1584,7 +1584,7 @@ mul_s_v_csr = MulSVCSR() ...@@ -1584,7 +1584,7 @@ mul_s_v_csr = MulSVCSR()
# register a specialization to replace MulSV -> MulSVCSR # register a specialization to replace MulSV -> MulSVCSR
@local_optimizer([sparse.mul_s_v]) @node_rewriter([sparse.mul_s_v])
def local_mul_s_v(fgraph, node): def local_mul_s_v(fgraph, node):
if node.op == sparse.mul_s_v: if node.op == sparse.mul_s_v:
x, y = node.inputs x, y = node.inputs
...@@ -1762,7 +1762,7 @@ structured_add_s_v_csr = StructuredAddSVCSR() ...@@ -1762,7 +1762,7 @@ structured_add_s_v_csr = StructuredAddSVCSR()
# register a specialization to replace # register a specialization to replace
# structured_add_s_v -> structured_add_s_v_csr # structured_add_s_v -> structured_add_s_v_csr
@local_optimizer([sparse.structured_add_s_v]) @node_rewriter([sparse.structured_add_s_v])
def local_structured_add_s_v(fgraph, node): def local_structured_add_s_v(fgraph, node):
if node.op == sparse.structured_add_s_v: if node.op == sparse.structured_add_s_v:
x, y = node.inputs x, y = node.inputs
...@@ -2051,7 +2051,7 @@ sampling_dot_csr = SamplingDotCSR() ...@@ -2051,7 +2051,7 @@ sampling_dot_csr = SamplingDotCSR()
# register a specialization to replace SamplingDot -> SamplingDotCsr # register a specialization to replace SamplingDot -> SamplingDotCsr
@local_optimizer([sparse.sampling_dot]) @node_rewriter([sparse.sampling_dot])
def local_sampling_dot_csr(fgraph, node): def local_sampling_dot_csr(fgraph, node):
if not config.blas__ldflags: if not config.blas__ldflags:
# The C implementation of SamplingDotCsr relies on BLAS routines # The C implementation of SamplingDotCsr relies on BLAS routines
......
...@@ -32,7 +32,7 @@ from aesara.graph.opt import ( ...@@ -32,7 +32,7 @@ from aesara.graph.opt import (
check_chain, check_chain,
copy_stack_trace, copy_stack_trace,
in2out, in2out,
local_optimizer, node_rewriter,
) )
from aesara.graph.optdb import SequenceDB from aesara.graph.optdb import SequenceDB
from aesara.graph.utils import ( from aesara.graph.utils import (
...@@ -605,7 +605,7 @@ def is_dimshuffle_useless(new_order, input): ...@@ -605,7 +605,7 @@ def is_dimshuffle_useless(new_order, input):
@register_canonicalize @register_canonicalize
@register_specialize @register_specialize
@local_optimizer([DimShuffle]) @node_rewriter([DimShuffle])
def local_dimshuffle_lift(fgraph, node): def local_dimshuffle_lift(fgraph, node):
""" """
"Lifts" DimShuffle through Elemwise operations and merges "Lifts" DimShuffle through Elemwise operations and merges
...@@ -651,7 +651,7 @@ def local_dimshuffle_lift(fgraph, node): ...@@ -651,7 +651,7 @@ def local_dimshuffle_lift(fgraph, node):
@register_canonicalize @register_canonicalize
@register_specialize @register_specialize
@local_optimizer([DimShuffle]) @node_rewriter([DimShuffle])
def local_useless_dimshuffle_makevector(fgraph, node): def local_useless_dimshuffle_makevector(fgraph, node):
r"""Remove `DimShuffle`\s that drop one dimensional broadcastable `MakeVector`s. r"""Remove `DimShuffle`\s that drop one dimensional broadcastable `MakeVector`s.
...@@ -680,7 +680,7 @@ def local_useless_dimshuffle_makevector(fgraph, node): ...@@ -680,7 +680,7 @@ def local_useless_dimshuffle_makevector(fgraph, node):
@register_canonicalize @register_canonicalize
@local_optimizer([Reshape]) @node_rewriter([Reshape])
def local_useless_dimshuffle_in_reshape(fgraph, node): def local_useless_dimshuffle_in_reshape(fgraph, node):
""" """
Removes useless DimShuffle operation inside Reshape: Removes useless DimShuffle operation inside Reshape:
...@@ -720,7 +720,7 @@ def local_useless_dimshuffle_in_reshape(fgraph, node): ...@@ -720,7 +720,7 @@ def local_useless_dimshuffle_in_reshape(fgraph, node):
@register_canonicalize @register_canonicalize
@register_specialize @register_specialize
@local_optimizer([TensorFromScalar]) @node_rewriter([TensorFromScalar])
def local_tensor_scalar_tensor(fgraph, node): def local_tensor_scalar_tensor(fgraph, node):
"""tensor_from_scalar(scalar_from_tensor(x)) -> x""" """tensor_from_scalar(scalar_from_tensor(x)) -> x"""
if isinstance(node.op, TensorFromScalar): if isinstance(node.op, TensorFromScalar):
...@@ -734,7 +734,7 @@ def local_tensor_scalar_tensor(fgraph, node): ...@@ -734,7 +734,7 @@ def local_tensor_scalar_tensor(fgraph, node):
@register_canonicalize @register_canonicalize
@register_specialize @register_specialize
@local_optimizer([ScalarFromTensor]) @node_rewriter([ScalarFromTensor])
def local_scalar_tensor_scalar(fgraph, node): def local_scalar_tensor_scalar(fgraph, node):
"""scalar_from_tensor(tensor_from_scalar(x)) -> x""" """scalar_from_tensor(tensor_from_scalar(x)) -> x"""
if isinstance(node.op, ScalarFromTensor): if isinstance(node.op, ScalarFromTensor):
...@@ -1474,7 +1474,7 @@ aesara.compile.mode.optdb.register("UnShapeOpt", UnShapeOptimizer(), position=10 ...@@ -1474,7 +1474,7 @@ aesara.compile.mode.optdb.register("UnShapeOpt", UnShapeOptimizer(), position=10
@register_specialize("local_alloc_elemwise") @register_specialize("local_alloc_elemwise")
@local_optimizer([Elemwise]) @node_rewriter([Elemwise])
def local_elemwise_alloc(fgraph, node): def local_elemwise_alloc(fgraph, node):
r"""Remove unnecessary `Alloc`\s that occur as inputs of `Elemwise` `Op`\s. r"""Remove unnecessary `Alloc`\s that occur as inputs of `Elemwise` `Op`\s.
...@@ -1595,7 +1595,7 @@ def local_elemwise_alloc(fgraph, node): ...@@ -1595,7 +1595,7 @@ def local_elemwise_alloc(fgraph, node):
@register_canonicalize @register_canonicalize
@local_optimizer([Elemwise]) @node_rewriter([Elemwise])
def local_fill_sink(fgraph, node): def local_fill_sink(fgraph, node):
""" """
f(fill(a, b), fill(c, d), e) -> fill(c, fill(a, f(b, d, e))) f(fill(a, b), fill(c, d), e) -> fill(c, fill(a, f(b, d, e)))
...@@ -1647,7 +1647,7 @@ def local_fill_sink(fgraph, node): ...@@ -1647,7 +1647,7 @@ def local_fill_sink(fgraph, node):
@register_specialize @register_specialize
@register_stabilize @register_stabilize
@local_optimizer([fill]) @node_rewriter([fill])
def local_fill_to_alloc(fgraph, node): def local_fill_to_alloc(fgraph, node):
r"""Remove `fill`\s or replace them with `Alloc`\s. r"""Remove `fill`\s or replace them with `Alloc`\s.
...@@ -1698,7 +1698,7 @@ compile.optdb.register( ...@@ -1698,7 +1698,7 @@ compile.optdb.register(
@register_canonicalize("fast_compile") @register_canonicalize("fast_compile")
@register_useless @register_useless
@local_optimizer([fill]) @node_rewriter([fill])
def local_useless_fill(fgraph, node): def local_useless_fill(fgraph, node):
"""fill(s,v) -> v """fill(s,v) -> v
...@@ -1721,7 +1721,7 @@ def local_useless_fill(fgraph, node): ...@@ -1721,7 +1721,7 @@ def local_useless_fill(fgraph, node):
@register_stabilize @register_stabilize
@register_canonicalize @register_canonicalize
@register_useless @register_useless
@local_optimizer([Alloc]) @node_rewriter([Alloc])
def local_useless_alloc(fgraph, node): def local_useless_alloc(fgraph, node):
""" """
If the input type is the same as the output type (dtype and broadcast) If the input type is the same as the output type (dtype and broadcast)
...@@ -1751,7 +1751,7 @@ def local_useless_alloc(fgraph, node): ...@@ -1751,7 +1751,7 @@ def local_useless_alloc(fgraph, node):
@register_specialize @register_specialize
@register_stabilize @register_stabilize
@register_canonicalize @register_canonicalize
@local_optimizer([Alloc]) @node_rewriter([Alloc])
def local_alloc_sink_dimshuffle(fgraph, node): def local_alloc_sink_dimshuffle(fgraph, node):
r"""Convert broadcastable leading dimensions in an `Alloc` to `DimShuffle`\s.""" r"""Convert broadcastable leading dimensions in an `Alloc` to `DimShuffle`\s."""
op = node.op op = node.op
...@@ -1785,7 +1785,7 @@ def local_alloc_sink_dimshuffle(fgraph, node): ...@@ -1785,7 +1785,7 @@ def local_alloc_sink_dimshuffle(fgraph, node):
return [DimShuffle(inner.type.broadcastable, dimshuffle_new_order)(inner)] return [DimShuffle(inner.type.broadcastable, dimshuffle_new_order)(inner)]
@local_optimizer([AllocEmpty]) @node_rewriter([AllocEmpty])
def local_alloc_empty_to_zeros(fgraph, node): def local_alloc_empty_to_zeros(fgraph, node):
"""This convert AllocEmpty to Alloc of 0. """This convert AllocEmpty to Alloc of 0.
...@@ -1808,7 +1808,7 @@ compile.optdb.register( ...@@ -1808,7 +1808,7 @@ compile.optdb.register(
@register_specialize @register_specialize
@register_canonicalize @register_canonicalize
@local_optimizer([Shape]) @node_rewriter([Shape])
def local_shape_to_shape_i(fgraph, node): def local_shape_to_shape_i(fgraph, node):
if isinstance(node.op, Shape): if isinstance(node.op, Shape):
# This optimization needs ShapeOpt and fgraph.shape_feature # This optimization needs ShapeOpt and fgraph.shape_feature
...@@ -1824,7 +1824,7 @@ def local_shape_to_shape_i(fgraph, node): ...@@ -1824,7 +1824,7 @@ def local_shape_to_shape_i(fgraph, node):
@register_specialize @register_specialize
@register_canonicalize @register_canonicalize
@local_optimizer([Shape_i]) @node_rewriter([Shape_i])
def local_track_shape_i(fgraph, node): def local_track_shape_i(fgraph, node):
if not isinstance(node.op, Shape_i): if not isinstance(node.op, Shape_i):
return False return False
...@@ -1847,7 +1847,7 @@ def local_track_shape_i(fgraph, node): ...@@ -1847,7 +1847,7 @@ def local_track_shape_i(fgraph, node):
@register_useless @register_useless
@register_canonicalize("fast_compile") @register_canonicalize("fast_compile")
@register_specialize @register_specialize
@local_optimizer([Elemwise]) @node_rewriter([Elemwise])
def local_useless_elemwise(fgraph, node): def local_useless_elemwise(fgraph, node):
""" """
eq(x, x) -> 1 eq(x, x) -> 1
...@@ -1952,7 +1952,7 @@ def local_useless_elemwise(fgraph, node): ...@@ -1952,7 +1952,7 @@ def local_useless_elemwise(fgraph, node):
@register_specialize @register_specialize
@local_optimizer([Elemwise]) @node_rewriter([Elemwise])
def local_alloc_unary(fgraph, node): def local_alloc_unary(fgraph, node):
"""unary(alloc(x, shp)) -> alloc(unary(x), shp)""" """unary(alloc(x, shp)) -> alloc(unary(x), shp)"""
if isinstance(node.op, Elemwise) and len(node.inputs) == 1: if isinstance(node.op, Elemwise) and len(node.inputs) == 1:
...@@ -1974,7 +1974,7 @@ def local_alloc_unary(fgraph, node): ...@@ -1974,7 +1974,7 @@ def local_alloc_unary(fgraph, node):
@register_canonicalize @register_canonicalize
@register_specialize @register_specialize
@local_optimizer([Elemwise]) @node_rewriter([Elemwise])
def local_cast_cast(fgraph, node): def local_cast_cast(fgraph, node):
"""cast(cast(x, dtype1), dtype2) """cast(cast(x, dtype1), dtype2)
...@@ -2052,7 +2052,7 @@ def is_an_upcast(type1, type2): ...@@ -2052,7 +2052,7 @@ def is_an_upcast(type1, type2):
@register_useless @register_useless
@register_specialize @register_specialize
@local_optimizer(None) @node_rewriter(None)
def local_remove_useless_assert(fgraph, node): def local_remove_useless_assert(fgraph, node):
if not isinstance(node.op, CheckAndRaise): if not isinstance(node.op, CheckAndRaise):
return False return False
...@@ -2079,7 +2079,7 @@ def local_remove_useless_assert(fgraph, node): ...@@ -2079,7 +2079,7 @@ def local_remove_useless_assert(fgraph, node):
return [new_var] return [new_var]
@local_optimizer([Assert]) @node_rewriter([Assert])
def local_remove_all_assert(fgraph, node): def local_remove_all_assert(fgraph, node):
"""An optimization disabled by default that removes all asserts from """An optimization disabled by default that removes all asserts from
the graph. the graph.
...@@ -2122,7 +2122,7 @@ compile.optdb["useless"].register( ...@@ -2122,7 +2122,7 @@ compile.optdb["useless"].register(
@register_canonicalize @register_canonicalize
@local_optimizer([Elemwise]) @node_rewriter([Elemwise])
def local_upcast_elemwise_constant_inputs(fgraph, node): def local_upcast_elemwise_constant_inputs(fgraph, node):
"""This explicitly upcasts constant inputs to elemwise Ops, when """This explicitly upcasts constant inputs to elemwise Ops, when
those Ops do implicit upcasting anyway. those Ops do implicit upcasting anyway.
...@@ -2197,7 +2197,7 @@ def local_upcast_elemwise_constant_inputs(fgraph, node): ...@@ -2197,7 +2197,7 @@ def local_upcast_elemwise_constant_inputs(fgraph, node):
@register_useless @register_useless
@register_canonicalize @register_canonicalize
@register_specialize @register_specialize
@local_optimizer([Unbroadcast]) @node_rewriter([Unbroadcast])
def local_useless_unbroadcast(fgraph, node): def local_useless_unbroadcast(fgraph, node):
"""Remove `Unbroadcast` if it does not actually change the broadcasting pattern. """Remove `Unbroadcast` if it does not actually change the broadcasting pattern.
...@@ -2225,7 +2225,7 @@ def local_useless_unbroadcast(fgraph, node): ...@@ -2225,7 +2225,7 @@ def local_useless_unbroadcast(fgraph, node):
@register_canonicalize @register_canonicalize
@register_specialize @register_specialize
@local_optimizer([Unbroadcast]) @node_rewriter([Unbroadcast])
def local_unbroadcast_lift(fgraph, node): def local_unbroadcast_lift(fgraph, node):
""" """
Lifts `Unbroadcast` through unary Elemwise operations, Lifts `Unbroadcast` through unary Elemwise operations,
...@@ -2271,7 +2271,7 @@ def local_unbroadcast_lift(fgraph, node): ...@@ -2271,7 +2271,7 @@ def local_unbroadcast_lift(fgraph, node):
@register_specialize @register_specialize
@register_canonicalize @register_canonicalize
@register_useless @register_useless
@local_optimizer([Join]) @node_rewriter([Join])
def local_join_1(fgraph, node): def local_join_1(fgraph, node):
"""Join(i, x) => x """Join(i, x) => x
...@@ -2291,7 +2291,7 @@ def local_join_1(fgraph, node): ...@@ -2291,7 +2291,7 @@ def local_join_1(fgraph, node):
@register_useless @register_useless
@register_specialize @register_specialize
@register_canonicalize @register_canonicalize
@local_optimizer([Join]) @node_rewriter([Join])
def local_join_empty(fgraph, node): def local_join_empty(fgraph, node):
"""Join(i, x, y, empty) => Join(i, x, y) """Join(i, x, y, empty) => Join(i, x, y)
...@@ -2338,7 +2338,7 @@ def local_join_empty(fgraph, node): ...@@ -2338,7 +2338,7 @@ def local_join_empty(fgraph, node):
@register_specialize @register_specialize
@register_canonicalize @register_canonicalize
@register_useless @register_useless
@local_optimizer([Join]) @node_rewriter([Join])
def local_join_make_vector(fgraph, node): def local_join_make_vector(fgraph, node):
r"""Merge `MakeVector` inputs within a `Join`. r"""Merge `MakeVector` inputs within a `Join`.
...@@ -2385,7 +2385,7 @@ def local_join_make_vector(fgraph, node): ...@@ -2385,7 +2385,7 @@ def local_join_make_vector(fgraph, node):
@register_useless("local_remove_switch_const_cond") @register_useless("local_remove_switch_const_cond")
@register_canonicalize("fast_compile", "local_remove_switch_const_cond") @register_canonicalize("fast_compile", "local_remove_switch_const_cond")
@register_specialize @register_specialize
@local_optimizer([Elemwise]) @node_rewriter([Elemwise])
def local_useless_switch(fgraph, node): def local_useless_switch(fgraph, node):
""" """
This optimization makes the following changes in the graph: This optimization makes the following changes in the graph:
...@@ -2462,7 +2462,7 @@ def local_useless_switch(fgraph, node): ...@@ -2462,7 +2462,7 @@ def local_useless_switch(fgraph, node):
@register_canonicalize @register_canonicalize
@local_optimizer([Elemwise]) @node_rewriter([Elemwise])
def local_merge_switch_same_cond(fgraph, node): def local_merge_switch_same_cond(fgraph, node):
""" """
Merge add/sub/mul/div/minimum/maximum/... of switches sharing the same Merge add/sub/mul/div/minimum/maximum/... of switches sharing the same
...@@ -2499,7 +2499,7 @@ def local_merge_switch_same_cond(fgraph, node): ...@@ -2499,7 +2499,7 @@ def local_merge_switch_same_cond(fgraph, node):
@register_useless @register_useless
@register_canonicalize @register_canonicalize
@register_specialize @register_specialize
@local_optimizer([Split]) @node_rewriter([Split])
def local_useless_split(fgraph, node): def local_useless_split(fgraph, node):
"""Split{n_splits=1}(x, y) -> x """Split{n_splits=1}(x, y) -> x
...@@ -2520,7 +2520,7 @@ def local_useless_split(fgraph, node): ...@@ -2520,7 +2520,7 @@ def local_useless_split(fgraph, node):
def local_reshape_chain(op): def local_reshape_chain(op):
@local_optimizer([op]) @node_rewriter([op])
def f(fgraph, node): def f(fgraph, node):
""" """
Reshape(Reshape(shape1),shape2) -> Reshape(shape2) Reshape(Reshape(shape1),shape2) -> Reshape(shape2)
...@@ -2560,7 +2560,7 @@ register_canonicalize(local_reshape_chain(Reshape), name="local_reshape_chain") ...@@ -2560,7 +2560,7 @@ register_canonicalize(local_reshape_chain(Reshape), name="local_reshape_chain")
@register_useless @register_useless
@register_canonicalize @register_canonicalize
@register_stabilize @register_stabilize
@local_optimizer([Reshape]) @node_rewriter([Reshape])
def local_useless_reshape(fgraph, node): def local_useless_reshape(fgraph, node):
""" """
Remove two kinds of useless reshape. Remove two kinds of useless reshape.
...@@ -2658,7 +2658,7 @@ def local_useless_reshape(fgraph, node): ...@@ -2658,7 +2658,7 @@ def local_useless_reshape(fgraph, node):
@register_canonicalize @register_canonicalize
@local_optimizer([Reshape]) @node_rewriter([Reshape])
def local_reshape_to_dimshuffle(fgraph, node): def local_reshape_to_dimshuffle(fgraph, node):
""" """
Broadcastable dimensions in Reshape are replaced with dimshuffle. Broadcastable dimensions in Reshape are replaced with dimshuffle.
...@@ -2706,7 +2706,7 @@ def local_reshape_to_dimshuffle(fgraph, node): ...@@ -2706,7 +2706,7 @@ def local_reshape_to_dimshuffle(fgraph, node):
@register_canonicalize @register_canonicalize
@register_stabilize @register_stabilize
@local_optimizer([Reshape]) @node_rewriter([Reshape])
def local_reshape_lift(fgraph, node): def local_reshape_lift(fgraph, node):
""" """
Reshape(UnaryElemwise(x)) -> UnaryElemwise(Reshape(x)) Reshape(UnaryElemwise(x)) -> UnaryElemwise(Reshape(x))
...@@ -2736,7 +2736,7 @@ def local_reshape_lift(fgraph, node): ...@@ -2736,7 +2736,7 @@ def local_reshape_lift(fgraph, node):
register_canonicalize(OpRemove(tensor_copy), name="remove_tensor_copy") register_canonicalize(OpRemove(tensor_copy), name="remove_tensor_copy")
@local_optimizer(None) @node_rewriter(None)
def constant_folding(fgraph, node): def constant_folding(fgraph, node):
if not node.op.do_constant_folding(fgraph, node): if not node.op.do_constant_folding(fgraph, node):
...@@ -3092,9 +3092,9 @@ class FusionOptimizer(GraphRewriter): ...@@ -3092,9 +3092,9 @@ class FusionOptimizer(GraphRewriter):
""" """
def __init__(self, local_optimizer): def __init__(self, node_rewriter):
super().__init__() super().__init__()
self.optimizer = local_optimizer self.optimizer = node_rewriter
def add_requirements(self, fgraph): def add_requirements(self, fgraph):
fgraph.attach_feature(ReplaceValidate()) fgraph.attach_feature(ReplaceValidate())
...@@ -3206,7 +3206,7 @@ else: ...@@ -3206,7 +3206,7 @@ else:
@register_canonicalize @register_canonicalize
@local_optimizer([Elemwise]) @node_rewriter([Elemwise])
def local_useless_composite(fgraph, node): def local_useless_composite(fgraph, node):
"""For elemwise Composite that have multiple outputs, remove the """For elemwise Composite that have multiple outputs, remove the
outputs that are not used. outputs that are not used.
...@@ -3227,7 +3227,7 @@ def local_useless_composite(fgraph, node): ...@@ -3227,7 +3227,7 @@ def local_useless_composite(fgraph, node):
@register_canonicalize("fast_compile") @register_canonicalize("fast_compile")
@register_useless("fast_compile") @register_useless("fast_compile")
@local_optimizer(None) @node_rewriter(None)
def local_view_op(fgraph, node): def local_view_op(fgraph, node):
if isinstance(node.op, ViewOp): if isinstance(node.op, ViewOp):
return node.inputs return node.inputs
...@@ -3237,7 +3237,7 @@ def local_view_op(fgraph, node): ...@@ -3237,7 +3237,7 @@ def local_view_op(fgraph, node):
@register_canonicalize @register_canonicalize
@register_stabilize @register_stabilize
@register_specialize @register_specialize
@local_optimizer([Alloc]) @node_rewriter([Alloc])
def local_merge_alloc(fgraph, node): def local_merge_alloc(fgraph, node):
# This opt takes care of several cases: # This opt takes care of several cases:
# Alloc(Alloc(m, x, 1, 1, 1), x, y, z, w) -> Alloc(m, x, y, z, w) # Alloc(Alloc(m, x, 1, 1, 1), x, y, z, w) -> Alloc(m, x, y, z, w)
...@@ -3274,7 +3274,7 @@ def local_merge_alloc(fgraph, node): ...@@ -3274,7 +3274,7 @@ def local_merge_alloc(fgraph, node):
@register_useless("fast_compile") @register_useless("fast_compile")
@local_optimizer([TopKOp]) @node_rewriter([TopKOp])
def local_useless_topk(fgraph, node): def local_useless_topk(fgraph, node):
""" """
TopKOp generates two outputs by default TopKOp generates two outputs by default
...@@ -3310,7 +3310,7 @@ def local_useless_topk(fgraph, node): ...@@ -3310,7 +3310,7 @@ def local_useless_topk(fgraph, node):
@register_useless @register_useless
@register_canonicalize @register_canonicalize
@local_optimizer([SpecifyShape]) @node_rewriter([SpecifyShape])
def local_merge_consecutive_specify_shape(fgraph, node): def local_merge_consecutive_specify_shape(fgraph, node):
"""Replace ``specify_shape(specify_shape(x, s1), s2)`` with ``specify_shape(x, s3)``, """Replace ``specify_shape(specify_shape(x, s1), s2)`` with ``specify_shape(x, s3)``,
where s3 is the union of specified dimensions in s1 and s2, with preference given to s2. where s3 is the union of specified dimensions in s1 and s2, with preference given to s2.
...@@ -3336,7 +3336,7 @@ def local_merge_consecutive_specify_shape(fgraph, node): ...@@ -3336,7 +3336,7 @@ def local_merge_consecutive_specify_shape(fgraph, node):
@register_useless @register_useless
@register_canonicalize @register_canonicalize
@local_optimizer([Shape]) @node_rewriter([Shape])
def local_Shape_of_SpecifyShape(fgraph, node): def local_Shape_of_SpecifyShape(fgraph, node):
"""Replace ``specify_shape(x, s).shape`` with ``s``.""" """Replace ``specify_shape(x, s).shape`` with ``s``."""
...@@ -3360,7 +3360,7 @@ def local_Shape_of_SpecifyShape(fgraph, node): ...@@ -3360,7 +3360,7 @@ def local_Shape_of_SpecifyShape(fgraph, node):
@register_useless @register_useless
@register_canonicalize @register_canonicalize
@local_optimizer([Shape_i]) @node_rewriter([Shape_i])
def local_Shape_i_of_broadcastable(fgraph, node): def local_Shape_i_of_broadcastable(fgraph, node):
"""Replace ``shape_i(x, i)`` with ``1`` when ``x.broadcastable[i]`` is ``True``.""" """Replace ``shape_i(x, i)`` with ``1`` when ``x.broadcastable[i]`` is ``True``."""
...@@ -3378,7 +3378,7 @@ def local_Shape_i_of_broadcastable(fgraph, node): ...@@ -3378,7 +3378,7 @@ def local_Shape_i_of_broadcastable(fgraph, node):
@register_useless @register_useless
@register_canonicalize @register_canonicalize
@local_optimizer([Unique]) @node_rewriter([Unique])
def local_Unique_scalar(fgraph, node): def local_Unique_scalar(fgraph, node):
"""Convert ``unique(x)`` to ``x`` when ``x`` is a scalar.""" """Convert ``unique(x)`` to ``x`` when ``x`` is a scalar."""
if not isinstance(node.op, Unique): if not isinstance(node.op, Unique):
...@@ -3399,7 +3399,7 @@ def local_Unique_scalar(fgraph, node): ...@@ -3399,7 +3399,7 @@ def local_Unique_scalar(fgraph, node):
@register_useless @register_useless
@register_canonicalize @register_canonicalize
@local_optimizer([Unique]) @node_rewriter([Unique])
def local_Unique_Alloc_lift(fgraph, node): def local_Unique_Alloc_lift(fgraph, node):
"""Convert ``unique(alloc(x, ...), axis=None)`` to ``unique(x, axis=None)``. """Convert ``unique(alloc(x, ...), axis=None)`` to ``unique(x, axis=None)``.
...@@ -3432,7 +3432,7 @@ def local_Unique_Alloc_lift(fgraph, node): ...@@ -3432,7 +3432,7 @@ def local_Unique_Alloc_lift(fgraph, node):
@register_useless @register_useless
@register_canonicalize @register_canonicalize
@local_optimizer([Unique]) @node_rewriter([Unique])
def local_Unique_BroadcastTo_lift(fgraph, node): def local_Unique_BroadcastTo_lift(fgraph, node):
"""Convert ``unique(broadcast_to(x, ...), axis=None)`` to ``unique(x, axis=None)``. """Convert ``unique(broadcast_to(x, ...), axis=None)`` to ``unique(x, axis=None)``.
...@@ -3465,7 +3465,7 @@ def local_Unique_BroadcastTo_lift(fgraph, node): ...@@ -3465,7 +3465,7 @@ def local_Unique_BroadcastTo_lift(fgraph, node):
@register_useless @register_useless
@register_canonicalize @register_canonicalize
@local_optimizer([Unique]) @node_rewriter([Unique])
def local_Unique_Repeat_lift(fgraph, node): def local_Unique_Repeat_lift(fgraph, node):
"""Convert ``unique(repeat(x, ...), axis=None)`` to ``unique(x, axis=None)``. """Convert ``unique(repeat(x, ...), axis=None)`` to ``unique(x, axis=None)``.
...@@ -3498,7 +3498,7 @@ def local_Unique_Repeat_lift(fgraph, node): ...@@ -3498,7 +3498,7 @@ def local_Unique_Repeat_lift(fgraph, node):
@register_useless @register_useless
@register_canonicalize @register_canonicalize
@local_optimizer([Unique]) @node_rewriter([Unique])
def local_Unique_second(fgraph, node): def local_Unique_second(fgraph, node):
"""Convert ``unique(second(x, ...), axis=None)`` to ``second(x, axis=None)``. """Convert ``unique(second(x, ...), axis=None)`` to ``second(x, axis=None)``.
...@@ -3535,7 +3535,7 @@ def local_Unique_second(fgraph, node): ...@@ -3535,7 +3535,7 @@ def local_Unique_second(fgraph, node):
@register_useless @register_useless
@register_canonicalize @register_canonicalize
@local_optimizer([BroadcastTo]) @node_rewriter([BroadcastTo])
def local_remove_scalar_BroadcastTo(fgraph, node): def local_remove_scalar_BroadcastTo(fgraph, node):
bcast_shape = node.inputs[1:] bcast_shape = node.inputs[1:]
......
...@@ -150,7 +150,7 @@ from aesara.graph.opt import ( ...@@ -150,7 +150,7 @@ from aesara.graph.opt import (
GraphRewriter, GraphRewriter,
copy_stack_trace, copy_stack_trace,
in2out, in2out,
local_optimizer, node_rewriter,
) )
from aesara.graph.optdb import SequenceDB from aesara.graph.optdb import SequenceDB
from aesara.graph.utils import InconsistencyError, MethodNotDefined, TestValueError from aesara.graph.utils import InconsistencyError, MethodNotDefined, TestValueError
...@@ -1733,7 +1733,7 @@ class Dot22(GemmRelated): ...@@ -1733,7 +1733,7 @@ class Dot22(GemmRelated):
_dot22 = Dot22() _dot22 = Dot22()
@local_optimizer([Dot]) @node_rewriter([Dot])
def local_dot_to_dot22(fgraph, node): def local_dot_to_dot22(fgraph, node):
# This works for tensor.outer too because basic.outer is a macro that # This works for tensor.outer too because basic.outer is a macro that
# produces a dot(dimshuffle,dimshuffle) of form 4 below # produces a dot(dimshuffle,dimshuffle) of form 4 below
...@@ -1766,7 +1766,7 @@ def local_dot_to_dot22(fgraph, node): ...@@ -1766,7 +1766,7 @@ def local_dot_to_dot22(fgraph, node):
_logger.info(f"Not optimizing dot with inputs {x} {y} {x.type} {y.type}") _logger.info(f"Not optimizing dot with inputs {x} {y} {x.type} {y.type}")
@local_optimizer([gemm_no_inplace], inplace=True) @node_rewriter([gemm_no_inplace], inplace=True)
def local_inplace_gemm(fgraph, node): def local_inplace_gemm(fgraph, node):
if node.op == gemm_no_inplace: if node.op == gemm_no_inplace:
new_out = [gemm_inplace(*node.inputs)] new_out = [gemm_inplace(*node.inputs)]
...@@ -1774,7 +1774,7 @@ def local_inplace_gemm(fgraph, node): ...@@ -1774,7 +1774,7 @@ def local_inplace_gemm(fgraph, node):
return new_out return new_out
@local_optimizer([gemv_no_inplace], inplace=True) @node_rewriter([gemv_no_inplace], inplace=True)
def local_inplace_gemv(fgraph, node): def local_inplace_gemv(fgraph, node):
if node.op == gemv_no_inplace: if node.op == gemv_no_inplace:
new_out = [gemv_inplace(*node.inputs)] new_out = [gemv_inplace(*node.inputs)]
...@@ -1782,7 +1782,7 @@ def local_inplace_gemv(fgraph, node): ...@@ -1782,7 +1782,7 @@ def local_inplace_gemv(fgraph, node):
return new_out return new_out
@local_optimizer([ger], inplace=True) @node_rewriter([ger], inplace=True)
def local_inplace_ger(fgraph, node): def local_inplace_ger(fgraph, node):
if node.op == ger: if node.op == ger:
new_out = [ger_destructive(*node.inputs)] new_out = [ger_destructive(*node.inputs)]
...@@ -1790,7 +1790,7 @@ def local_inplace_ger(fgraph, node): ...@@ -1790,7 +1790,7 @@ def local_inplace_ger(fgraph, node):
return new_out return new_out
@local_optimizer([gemm_no_inplace]) @node_rewriter([gemm_no_inplace])
def local_gemm_to_gemv(fgraph, node): def local_gemm_to_gemv(fgraph, node):
"""GEMM acting on row or column matrices -> GEMV.""" """GEMM acting on row or column matrices -> GEMV."""
if node.op == gemm_no_inplace: if node.op == gemm_no_inplace:
...@@ -1807,7 +1807,7 @@ def local_gemm_to_gemv(fgraph, node): ...@@ -1807,7 +1807,7 @@ def local_gemm_to_gemv(fgraph, node):
return new_out return new_out
@local_optimizer([gemm_no_inplace]) @node_rewriter([gemm_no_inplace])
def local_gemm_to_ger(fgraph, node): def local_gemm_to_ger(fgraph, node):
"""GEMM computing an outer-product -> GER.""" """GEMM computing an outer-product -> GER."""
if node.op == gemm_no_inplace: if node.op == gemm_no_inplace:
...@@ -1839,7 +1839,7 @@ def local_gemm_to_ger(fgraph, node): ...@@ -1839,7 +1839,7 @@ def local_gemm_to_ger(fgraph, node):
# TODO: delete this optimization when we have the proper dot->gemm->ger pipeline # TODO: delete this optimization when we have the proper dot->gemm->ger pipeline
# working # working
@local_optimizer([_dot22]) @node_rewriter([_dot22])
def local_dot22_to_ger_or_gemv(fgraph, node): def local_dot22_to_ger_or_gemv(fgraph, node):
"""dot22 computing an outer-product -> GER.""" """dot22 computing an outer-product -> GER."""
if node.op == _dot22: if node.op == _dot22:
...@@ -2033,7 +2033,7 @@ class Dot22Scalar(GemmRelated): ...@@ -2033,7 +2033,7 @@ class Dot22Scalar(GemmRelated):
_dot22scalar = Dot22Scalar() _dot22scalar = Dot22Scalar()
@local_optimizer([mul]) @node_rewriter([mul])
def local_dot22_to_dot22scalar(fgraph, node): def local_dot22_to_dot22scalar(fgraph, node):
""" """
Notes Notes
...@@ -2651,7 +2651,7 @@ _batched_dot = BatchedDot() ...@@ -2651,7 +2651,7 @@ _batched_dot = BatchedDot()
# from opt import register_specialize, register_canonicalize # from opt import register_specialize, register_canonicalize
# @register_specialize # @register_specialize
@local_optimizer([sub, add]) @node_rewriter([sub, add])
def local_print_as_we_go_along(fgraph, node): def local_print_as_we_go_along(fgraph, node):
if node.op in (sub, add): if node.op in (sub, add):
debugprint(node) debugprint(node)
......
...@@ -15,7 +15,7 @@ from aesara.tensor.blas import ( ...@@ -15,7 +15,7 @@ from aesara.tensor.blas import (
ger, ger,
ger_destructive, ger_destructive,
ldflags, ldflags,
local_optimizer, node_rewriter,
optdb, optdb,
) )
...@@ -344,7 +344,7 @@ cger_inplace = CGer(True) ...@@ -344,7 +344,7 @@ cger_inplace = CGer(True)
cger_no_inplace = CGer(False) cger_no_inplace = CGer(False)
@local_optimizer([ger, ger_destructive]) @node_rewriter([ger, ger_destructive])
def use_c_ger(fgraph, node): def use_c_ger(fgraph, node):
if not config.blas__ldflags: if not config.blas__ldflags:
return return
...@@ -355,7 +355,7 @@ def use_c_ger(fgraph, node): ...@@ -355,7 +355,7 @@ def use_c_ger(fgraph, node):
return [CGer(True)(*node.inputs)] return [CGer(True)(*node.inputs)]
@local_optimizer([CGer(False)]) @node_rewriter([CGer(False)])
def make_c_ger_destructive(fgraph, node): def make_c_ger_destructive(fgraph, node):
if isinstance(node.op, CGer) and not node.op.destructive: if isinstance(node.op, CGer) and not node.op.destructive:
return [cger_inplace(*node.inputs)] return [cger_inplace(*node.inputs)]
...@@ -699,7 +699,7 @@ int main() { ...@@ -699,7 +699,7 @@ int main() {
check_force_gemv_init._force_init_beta = None check_force_gemv_init._force_init_beta = None
@local_optimizer([gemv_inplace, gemv_no_inplace]) @node_rewriter([gemv_inplace, gemv_no_inplace])
def use_c_gemv(fgraph, node): def use_c_gemv(fgraph, node):
if not config.blas__ldflags: if not config.blas__ldflags:
return return
...@@ -710,7 +710,7 @@ def use_c_gemv(fgraph, node): ...@@ -710,7 +710,7 @@ def use_c_gemv(fgraph, node):
return [cgemv_inplace(*node.inputs)] return [cgemv_inplace(*node.inputs)]
@local_optimizer([CGemv(inplace=False)]) @node_rewriter([CGemv(inplace=False)])
def make_c_gemv_destructive(fgraph, node): def make_c_gemv_destructive(fgraph, node):
if isinstance(node.op, CGemv) and not node.op.inplace: if isinstance(node.op, CGemv) and not node.op.inplace:
inputs = list(node.inputs) inputs = list(node.inputs)
......
...@@ -11,7 +11,7 @@ from aesara.tensor.blas import ( ...@@ -11,7 +11,7 @@ from aesara.tensor.blas import (
ger, ger,
ger_destructive, ger_destructive,
have_fblas, have_fblas,
local_optimizer, node_rewriter,
optdb, optdb,
) )
...@@ -58,13 +58,13 @@ scipy_ger_no_inplace = ScipyGer(False) ...@@ -58,13 +58,13 @@ scipy_ger_no_inplace = ScipyGer(False)
scipy_ger_inplace = ScipyGer(True) scipy_ger_inplace = ScipyGer(True)
@local_optimizer([ger, ger_destructive]) @node_rewriter([ger, ger_destructive])
def use_scipy_ger(fgraph, node): def use_scipy_ger(fgraph, node):
if node.op == ger: if node.op == ger:
return [scipy_ger_no_inplace(*node.inputs)] return [scipy_ger_no_inplace(*node.inputs)]
@local_optimizer([scipy_ger_no_inplace]) @node_rewriter([scipy_ger_no_inplace])
def make_ger_destructive(fgraph, node): def make_ger_destructive(fgraph, node):
if node.op == scipy_ger_no_inplace: if node.op == scipy_ger_no_inplace:
return [scipy_ger_inplace(*node.inputs)] return [scipy_ger_inplace(*node.inputs)]
......
...@@ -11,11 +11,11 @@ import aesara.scalar.math as aes_math ...@@ -11,11 +11,11 @@ import aesara.scalar.math as aes_math
from aesara.graph.basic import Constant, Variable from aesara.graph.basic import Constant, Variable
from aesara.graph.opt import ( from aesara.graph.opt import (
LocalOptGroup, LocalOptGroup,
LocalOptimizer, NodeRewriter,
PatternSub, PatternSub,
copy_stack_trace, copy_stack_trace,
in2out, in2out,
local_optimizer, node_rewriter,
) )
from aesara.graph.opt_utils import get_clients_at_depth from aesara.graph.opt_utils import get_clients_at_depth
from aesara.misc.safe_asarray import _asarray from aesara.misc.safe_asarray import _asarray
...@@ -148,7 +148,7 @@ def fill_chain(new_out, orig_inputs): ...@@ -148,7 +148,7 @@ def fill_chain(new_out, orig_inputs):
@register_canonicalize @register_canonicalize
@register_stabilize @register_stabilize
@local_optimizer([Dot]) @node_rewriter([Dot])
def local_0_dot_x(fgraph, node): def local_0_dot_x(fgraph, node):
if not isinstance(node.op, Dot): if not isinstance(node.op, Dot):
return False return False
...@@ -185,7 +185,7 @@ def local_0_dot_x(fgraph, node): ...@@ -185,7 +185,7 @@ def local_0_dot_x(fgraph, node):
@register_canonicalize @register_canonicalize
@local_optimizer([DimShuffle]) @node_rewriter([DimShuffle])
def local_lift_transpose_through_dot(fgraph, node): def local_lift_transpose_through_dot(fgraph, node):
"""Perform the rewrite ``dot(x,y).T -> dot(y.T, x.T)`` """Perform the rewrite ``dot(x,y).T -> dot(y.T, x.T)``
...@@ -229,7 +229,7 @@ def is_inverse_pair(node_op, prev_op, inv_pair): ...@@ -229,7 +229,7 @@ def is_inverse_pair(node_op, prev_op, inv_pair):
@register_canonicalize @register_canonicalize
@register_specialize @register_specialize
@local_optimizer([Elemwise]) @node_rewriter([Elemwise])
def local_func_inv(fgraph, node): def local_func_inv(fgraph, node):
""" """
Check for two consecutive operations that are functional inverses Check for two consecutive operations that are functional inverses
...@@ -271,7 +271,7 @@ def local_func_inv(fgraph, node): ...@@ -271,7 +271,7 @@ def local_func_inv(fgraph, node):
@register_canonicalize @register_canonicalize
@register_specialize @register_specialize
@local_optimizer([Elemwise]) @node_rewriter([Elemwise])
def local_exp_log(fgraph, node): def local_exp_log(fgraph, node):
x = node.inputs[0] x = node.inputs[0]
...@@ -313,7 +313,7 @@ def local_exp_log(fgraph, node): ...@@ -313,7 +313,7 @@ def local_exp_log(fgraph, node):
@register_specialize @register_specialize
@local_optimizer([Elemwise]) @node_rewriter([Elemwise])
def local_exp_log_nan_switch(fgraph, node): def local_exp_log_nan_switch(fgraph, node):
# Rewrites of the kind exp(log...(x)) that require a `nan` switch # Rewrites of the kind exp(log...(x)) that require a `nan` switch
x = node.inputs[0] x = node.inputs[0]
...@@ -371,7 +371,7 @@ def local_exp_log_nan_switch(fgraph, node): ...@@ -371,7 +371,7 @@ def local_exp_log_nan_switch(fgraph, node):
@register_canonicalize @register_canonicalize
@register_specialize @register_specialize
@local_optimizer([Sum]) @node_rewriter([Sum])
def local_sumsqr2dot(fgraph, node): def local_sumsqr2dot(fgraph, node):
""" """
This optimization detects This optimization detects
...@@ -418,7 +418,7 @@ def local_sumsqr2dot(fgraph, node): ...@@ -418,7 +418,7 @@ def local_sumsqr2dot(fgraph, node):
@register_stabilize @register_stabilize
@register_specialize @register_specialize
@register_canonicalize @register_canonicalize
@local_optimizer([Elemwise]) @node_rewriter([Elemwise])
def local_expm1(fgraph, node): def local_expm1(fgraph, node):
""" """
This optimization detects exp(a)-1 and converts this to expm1(a). This optimization detects exp(a)-1 and converts this to expm1(a).
...@@ -446,7 +446,7 @@ def local_expm1(fgraph, node): ...@@ -446,7 +446,7 @@ def local_expm1(fgraph, node):
@register_specialize @register_specialize
@register_canonicalize @register_canonicalize
@local_optimizer([mul]) @node_rewriter([mul])
def local_mul_switch_sink(fgraph, node): def local_mul_switch_sink(fgraph, node):
""" """
This optimization makes the following changes in the graph: This optimization makes the following changes in the graph:
...@@ -540,7 +540,7 @@ def local_mul_switch_sink(fgraph, node): ...@@ -540,7 +540,7 @@ def local_mul_switch_sink(fgraph, node):
@register_canonicalize @register_canonicalize
@local_optimizer([true_div, int_div]) @node_rewriter([true_div, int_div])
def local_div_switch_sink(fgraph, node): def local_div_switch_sink(fgraph, node):
""" """
This optimization makes the following changes in the graph: This optimization makes the following changes in the graph:
...@@ -616,33 +616,33 @@ def local_div_switch_sink(fgraph, node): ...@@ -616,33 +616,33 @@ def local_div_switch_sink(fgraph, node):
return False return False
class AlgebraicCanonizer(LocalOptimizer): class AlgebraicCanonizer(NodeRewriter):
r"""Simplification tool. r"""A `Rewriter` that rewrites algebraic expressions.
The variable is a ``local_optimizer``. It is best used The variable is a `node_rewriter`. It is best used
with a ``TopoOptimizer`` in ``in_to_out`` order. with a `TopoOptimizer` in in-to-out order.
Usage: ``AlgebraicCanonizer(main, inverse, reciprocal, calculate)`` Usage: ``AlgebraicCanonizer(main, inverse, reciprocal, calculate)``
Parameters Parameters
---------- ----------
main main
A suitable ``Op`` class that is commutative, associative and A suitable `Op` class that is commutative, associative and
takes one to an arbitrary number of inputs, e.g. add or takes one to an arbitrary number of inputs, e.g. add or
mul mul
inverse inverse
An ``Op`` class such that ``inverse(main(x, y), y) == x`` An `Op` class such that ``inverse(main(x, y), y) == x``
e.g. ``sub`` or true_div (e.g. `sub` or `true_div`).
reciprocal reciprocal
A function such that ``main(x, reciprocal(y)) == inverse(x, y)`` A function such that ``main(x, reciprocal(y)) == inverse(x, y)``
e.g. ``neg`` or ``reciprocal`` (e.g. `neg` or `reciprocal`).
calculate calculate
Function that takes a list of numpy.ndarray instances Function that takes a list of `numpy.ndarray` instances
for the numerator, another list for the denumerator, for the numerator, another list for the denumerator,
and calculates ``inverse(main(\*num), main(\*denum))``. It and calculates ``inverse(main(\*num), main(\*denum))``. It
takes a keyword argument, aslist. If True, the value takes a keyword argument, `aslist`. If ``True``, the value
should be returned as a list of one element, unless should be returned as a list of one element, unless
the value is such that value = main(). In that case, the value is such that ``value = main()``. In that case,
the return value should be an empty list. the return value should be an empty list.
Examples Examples
...@@ -654,7 +654,7 @@ class AlgebraicCanonizer(LocalOptimizer): ...@@ -654,7 +654,7 @@ class AlgebraicCanonizer(LocalOptimizer):
>>> mul_canonizer = AlgebraicCanonizer(mul, true_div, inv, \\ >>> mul_canonizer = AlgebraicCanonizer(mul, true_div, inv, \\
... lambda n, d: prod(n) / prod(d)) ... lambda n, d: prod(n) / prod(d))
Examples of optimizations ``mul_canonizer`` can perform: Examples of optimizations `mul_canonizer` can perform:
| x / x -> 1 | x / x -> 1
| (x * y) / x -> y | (x * y) / x -> y
...@@ -1082,14 +1082,14 @@ register_canonicalize(local_mul_canonizer, name="local_mul_canonizer") ...@@ -1082,14 +1082,14 @@ register_canonicalize(local_mul_canonizer, name="local_mul_canonizer")
@register_canonicalize @register_canonicalize
@local_optimizer([neg]) @node_rewriter([neg])
def local_neg_to_mul(fgraph, node): def local_neg_to_mul(fgraph, node):
if node.op == neg: if node.op == neg:
return [mul(np.array(-1, dtype=node.inputs[0].dtype), node.inputs[0])] return [mul(np.array(-1, dtype=node.inputs[0].dtype), node.inputs[0])]
@register_specialize @register_specialize
@local_optimizer([Sum, Prod]) @node_rewriter([Sum, Prod])
def local_sum_prod_mul_by_scalar(fgraph, node): def local_sum_prod_mul_by_scalar(fgraph, node):
""" """
sum(scalar * smth) -> scalar * sum(smth) sum(scalar * smth) -> scalar * sum(smth)
...@@ -1175,7 +1175,7 @@ def local_sum_prod_mul_by_scalar(fgraph, node): ...@@ -1175,7 +1175,7 @@ def local_sum_prod_mul_by_scalar(fgraph, node):
@register_specialize @register_specialize
@local_optimizer([Elemwise]) @node_rewriter([Elemwise])
def local_elemwise_sub_zeros(fgraph, node): def local_elemwise_sub_zeros(fgraph, node):
""" """
Elemwise{sub}(X,X) -> zeros_like(X) Elemwise{sub}(X,X) -> zeros_like(X)
...@@ -1197,7 +1197,7 @@ def local_elemwise_sub_zeros(fgraph, node): ...@@ -1197,7 +1197,7 @@ def local_elemwise_sub_zeros(fgraph, node):
@register_specialize @register_specialize
@register_stabilize @register_stabilize
@register_canonicalize @register_canonicalize
@local_optimizer([Elemwise]) @node_rewriter([Elemwise])
def local_useless_elemwise_comparison(fgraph, node): def local_useless_elemwise_comparison(fgraph, node):
"""... """...
...@@ -1407,7 +1407,7 @@ def local_useless_elemwise_comparison(fgraph, node): ...@@ -1407,7 +1407,7 @@ def local_useless_elemwise_comparison(fgraph, node):
@register_canonicalize @register_canonicalize
@register_specialize @register_specialize
@local_optimizer([Sum, Prod]) @node_rewriter([Sum, Prod])
def local_sum_prod_div_dimshuffle(fgraph, node): def local_sum_prod_div_dimshuffle(fgraph, node):
""" """
sum(a / dimshuffle{...}(b), axis=l) -> sum(a, axis={...}) / b, sum(a / dimshuffle{...}(b), axis=l) -> sum(a, axis={...}) / b,
...@@ -1499,7 +1499,7 @@ def local_sum_prod_div_dimshuffle(fgraph, node): ...@@ -1499,7 +1499,7 @@ def local_sum_prod_div_dimshuffle(fgraph, node):
@register_canonicalize @register_canonicalize
@local_optimizer([Sum, Prod]) @node_rewriter([Sum, Prod])
def local_sum_prod_all_to_none(fgraph, node): def local_sum_prod_all_to_none(fgraph, node):
""" """
Sum{0,1,...N} -> Sum{} or Sum{0,1,...N} -> Sum{} or
...@@ -1517,7 +1517,7 @@ def local_sum_prod_all_to_none(fgraph, node): ...@@ -1517,7 +1517,7 @@ def local_sum_prod_all_to_none(fgraph, node):
@register_canonicalize @register_canonicalize
@local_optimizer([Sum, Prod]) @node_rewriter([Sum, Prod])
def local_op_of_op(fgraph, node): def local_op_of_op(fgraph, node):
""" """
Prod(Prod()) -> single Prod() Prod(Prod()) -> single Prod()
...@@ -1573,7 +1573,7 @@ ALL_REDUCE = ( ...@@ -1573,7 +1573,7 @@ ALL_REDUCE = (
@register_canonicalize @register_canonicalize
@register_uncanonicalize # Needed for MaxAndArgmax -> CAReduce @register_uncanonicalize # Needed for MaxAndArgmax -> CAReduce
@local_optimizer(ALL_REDUCE) @node_rewriter(ALL_REDUCE)
def local_reduce_join(fgraph, node): def local_reduce_join(fgraph, node):
""" """
CAReduce{scalar.op}(Join(axis=0, a, b), axis=0) -> Elemwise{scalar.op}(a, b) CAReduce{scalar.op}(Join(axis=0, a, b), axis=0) -> Elemwise{scalar.op}(a, b)
...@@ -1645,7 +1645,7 @@ def local_reduce_join(fgraph, node): ...@@ -1645,7 +1645,7 @@ def local_reduce_join(fgraph, node):
@register_canonicalize("fast_compile", "local_cut_useless_reduce") @register_canonicalize("fast_compile", "local_cut_useless_reduce")
@register_useless("local_cut_useless_reduce") @register_useless("local_cut_useless_reduce")
@local_optimizer(ALL_REDUCE) @node_rewriter(ALL_REDUCE)
def local_useless_reduce(fgraph, node): def local_useless_reduce(fgraph, node):
"""Sum(a, axis=[]) -> a""" """Sum(a, axis=[]) -> a"""
if isinstance(node.op, CAReduce): if isinstance(node.op, CAReduce):
...@@ -1658,7 +1658,7 @@ def local_useless_reduce(fgraph, node): ...@@ -1658,7 +1658,7 @@ def local_useless_reduce(fgraph, node):
@register_canonicalize @register_canonicalize
@register_uncanonicalize @register_uncanonicalize
@register_specialize @register_specialize
@local_optimizer(ALL_REDUCE) @node_rewriter(ALL_REDUCE)
def local_reduce_broadcastable(fgraph, node): def local_reduce_broadcastable(fgraph, node):
"""Remove reduction over broadcastable dimensions.""" """Remove reduction over broadcastable dimensions."""
if isinstance(node.op, CAReduce): if isinstance(node.op, CAReduce):
...@@ -1700,7 +1700,7 @@ def local_reduce_broadcastable(fgraph, node): ...@@ -1700,7 +1700,7 @@ def local_reduce_broadcastable(fgraph, node):
@register_specialize @register_specialize
@local_optimizer([Sum, Prod]) @node_rewriter([Sum, Prod])
def local_opt_alloc(fgraph, node): def local_opt_alloc(fgraph, node):
""" """
sum(alloc(constant,shapes...)) => constant*prod(shapes) sum(alloc(constant,shapes...)) => constant*prod(shapes)
...@@ -1764,7 +1764,7 @@ def local_opt_alloc(fgraph, node): ...@@ -1764,7 +1764,7 @@ def local_opt_alloc(fgraph, node):
@register_specialize @register_specialize
@local_optimizer([neg]) @node_rewriter([neg])
def local_neg_div_neg(fgraph, node): def local_neg_div_neg(fgraph, node):
""" """
- (-a / b) -> a / b - (-a / b) -> a / b
...@@ -1788,7 +1788,7 @@ def local_neg_div_neg(fgraph, node): ...@@ -1788,7 +1788,7 @@ def local_neg_div_neg(fgraph, node):
@register_canonicalize @register_canonicalize
@local_optimizer([mul]) @node_rewriter([mul])
def local_mul_zero(fgraph, node): def local_mul_zero(fgraph, node):
""" """
As part of canonicalization, we replace multiplication by zero As part of canonicalization, we replace multiplication by zero
...@@ -1811,7 +1811,7 @@ def local_mul_zero(fgraph, node): ...@@ -1811,7 +1811,7 @@ def local_mul_zero(fgraph, node):
# TODO: Add this to the canonicalization to reduce redundancy. # TODO: Add this to the canonicalization to reduce redundancy.
@register_specialize @register_specialize
@local_optimizer([true_div]) @node_rewriter([true_div])
def local_div_to_reciprocal(fgraph, node): def local_div_to_reciprocal(fgraph, node):
if node.op == true_div and np.all(get_constant(node.inputs[0]) == 1.0): if node.op == true_div and np.all(get_constant(node.inputs[0]) == 1.0):
out = node.outputs[0] out = node.outputs[0]
...@@ -1828,7 +1828,7 @@ def local_div_to_reciprocal(fgraph, node): ...@@ -1828,7 +1828,7 @@ def local_div_to_reciprocal(fgraph, node):
@register_canonicalize @register_canonicalize
@local_optimizer([reciprocal]) @node_rewriter([reciprocal])
def local_reciprocal_canon(fgraph, node): def local_reciprocal_canon(fgraph, node):
if node.op == reciprocal: if node.op == reciprocal:
return [at_pow(node.inputs[0], -1.0)] return [at_pow(node.inputs[0], -1.0)]
...@@ -1837,7 +1837,7 @@ def local_reciprocal_canon(fgraph, node): ...@@ -1837,7 +1837,7 @@ def local_reciprocal_canon(fgraph, node):
@register_canonicalize @register_canonicalize
@local_optimizer([at_pow]) @node_rewriter([at_pow])
def local_pow_canonicalize(fgraph, node): def local_pow_canonicalize(fgraph, node):
if node.op == at_pow: if node.op == at_pow:
cst = get_constant(node.inputs[1]) cst = get_constant(node.inputs[1])
...@@ -1850,7 +1850,7 @@ def local_pow_canonicalize(fgraph, node): ...@@ -1850,7 +1850,7 @@ def local_pow_canonicalize(fgraph, node):
@register_specialize @register_specialize
@local_optimizer([mul]) @node_rewriter([mul])
def local_mul_to_sqr(fgraph, node): def local_mul_to_sqr(fgraph, node):
""" """
x*x -> sqr(x) x*x -> sqr(x)
...@@ -1862,7 +1862,7 @@ def local_mul_to_sqr(fgraph, node): ...@@ -1862,7 +1862,7 @@ def local_mul_to_sqr(fgraph, node):
@register_canonicalize @register_canonicalize
@local_optimizer([int_div]) @node_rewriter([int_div])
def local_intdiv_by_one(fgraph, node): def local_intdiv_by_one(fgraph, node):
"""x // 1 -> x""" """x // 1 -> x"""
if node.op in [int_div]: if node.op in [int_div]:
...@@ -1874,7 +1874,7 @@ def local_intdiv_by_one(fgraph, node): ...@@ -1874,7 +1874,7 @@ def local_intdiv_by_one(fgraph, node):
@register_canonicalize @register_canonicalize
@register_specialize @register_specialize
@local_optimizer([int_div, true_div]) @node_rewriter([int_div, true_div])
def local_zero_div(fgraph, node): def local_zero_div(fgraph, node):
"""0 / x -> 0""" """0 / x -> 0"""
if isinstance(node.op, Elemwise) and isinstance( if isinstance(node.op, Elemwise) and isinstance(
...@@ -1887,7 +1887,7 @@ def local_zero_div(fgraph, node): ...@@ -1887,7 +1887,7 @@ def local_zero_div(fgraph, node):
@register_specialize @register_specialize
@local_optimizer([at_pow]) @node_rewriter([at_pow])
def local_pow_specialize(fgraph, node): def local_pow_specialize(fgraph, node):
# here, we are past the point of canonicalization, so we don't want # here, we are past the point of canonicalization, so we don't want
# to put in un-necessary fills. # to put in un-necessary fills.
...@@ -1925,7 +1925,7 @@ def local_pow_specialize(fgraph, node): ...@@ -1925,7 +1925,7 @@ def local_pow_specialize(fgraph, node):
@register_specialize_device @register_specialize_device
@local_optimizer([at_pow]) @node_rewriter([at_pow])
def local_pow_specialize_device(fgraph, node): def local_pow_specialize_device(fgraph, node):
""" """
This optimization is not the same on all device. We do it only on cpu here. This optimization is not the same on all device. We do it only on cpu here.
...@@ -1992,7 +1992,7 @@ def local_pow_specialize_device(fgraph, node): ...@@ -1992,7 +1992,7 @@ def local_pow_specialize_device(fgraph, node):
@register_specialize @register_specialize
@local_optimizer([mul]) @node_rewriter([mul])
def local_mul_specialize(fgraph, node): def local_mul_specialize(fgraph, node):
""" """
Remove special-case constants from mul arguments and useless neg in inputs. Remove special-case constants from mul arguments and useless neg in inputs.
...@@ -2068,7 +2068,7 @@ def local_mul_specialize(fgraph, node): ...@@ -2068,7 +2068,7 @@ def local_mul_specialize(fgraph, node):
@register_specialize @register_specialize
@local_optimizer([add]) @node_rewriter([add])
def local_add_specialize(fgraph, node): def local_add_specialize(fgraph, node):
"""Remove zeros from ``add``s. """Remove zeros from ``add``s.
...@@ -2147,7 +2147,7 @@ local_mul_canonizer.add_simplifier(check_for_x_over_absX, "X_over_absX") ...@@ -2147,7 +2147,7 @@ local_mul_canonizer.add_simplifier(check_for_x_over_absX, "X_over_absX")
@register_canonicalize @register_canonicalize
@local_optimizer([at_abs]) @node_rewriter([at_abs])
def local_abs_lift(fgraph, node): def local_abs_lift(fgraph, node):
""" """
Move the abs toward the input. Move the abs toward the input.
...@@ -2165,7 +2165,7 @@ def local_abs_lift(fgraph, node): ...@@ -2165,7 +2165,7 @@ def local_abs_lift(fgraph, node):
@register_specialize @register_specialize
@local_optimizer([mul, true_div]) @node_rewriter([mul, true_div])
def local_abs_merge(fgraph, node): def local_abs_merge(fgraph, node):
""" """
Merge abs generated by local_abs_lift when the canonizer don't Merge abs generated by local_abs_lift when the canonizer don't
...@@ -2201,7 +2201,7 @@ def local_abs_merge(fgraph, node): ...@@ -2201,7 +2201,7 @@ def local_abs_merge(fgraph, node):
@register_stabilize @register_stabilize
@register_specialize @register_specialize
@local_optimizer([log]) @node_rewriter([log])
def local_log1p(fgraph, node): def local_log1p(fgraph, node):
# log(1+x) -> log1p(x) # log(1+x) -> log1p(x)
# log(1-x) -> log1p(-x) # log(1-x) -> log1p(-x)
...@@ -2234,7 +2234,7 @@ def local_log1p(fgraph, node): ...@@ -2234,7 +2234,7 @@ def local_log1p(fgraph, node):
@register_stabilize @register_stabilize
@register_specialize @register_specialize
@local_optimizer([log]) @node_rewriter([log])
def local_log_add_exp(fgraph, node): def local_log_add_exp(fgraph, node):
""" """
``log(exp(x)+exp(y)+exp(z)) = max + log(x-max, y-max, z-max)`` ``log(exp(x)+exp(y)+exp(z)) = max + log(x-max, y-max, z-max)``
...@@ -2266,7 +2266,7 @@ def local_log_add_exp(fgraph, node): ...@@ -2266,7 +2266,7 @@ def local_log_add_exp(fgraph, node):
@register_stabilize @register_stabilize
@register_specialize @register_specialize
@local_optimizer([log]) @node_rewriter([log])
def local_log_sum_exp(fgraph, node): def local_log_sum_exp(fgraph, node):
# log(sum_i(exp(x_i))) = x_max + log(sum_i(exp(x_i - x_max))) # log(sum_i(exp(x_i))) = x_max + log(sum_i(exp(x_i - x_max)))
...@@ -2435,7 +2435,7 @@ def attempt_distribution(factor, num, denum, out_type): ...@@ -2435,7 +2435,7 @@ def attempt_distribution(factor, num, denum, out_type):
@register_canonicalize @register_canonicalize
@register_stabilize @register_stabilize
@local_optimizer([mul, true_div, reciprocal]) @node_rewriter([mul, true_div, reciprocal])
def local_greedy_distributor(fgraph, node): def local_greedy_distributor(fgraph, node):
""" """
Optimize by reducing the number of multiplications and/or divisions. Optimize by reducing the number of multiplications and/or divisions.
...@@ -2609,7 +2609,7 @@ register_specialize(local_erf_neg_minus_one) ...@@ -2609,7 +2609,7 @@ register_specialize(local_erf_neg_minus_one)
@register_stabilize @register_stabilize
@register_specialize @register_specialize
@local_optimizer([log]) @node_rewriter([log])
def local_log_erfc(fgraph, node): def local_log_erfc(fgraph, node):
"""Stability optimization for `log(erfc(x))`. """Stability optimization for `log(erfc(x))`.
...@@ -2652,7 +2652,7 @@ def local_log_erfc(fgraph, node): ...@@ -2652,7 +2652,7 @@ def local_log_erfc(fgraph, node):
@register_stabilize @register_stabilize
@register_specialize @register_specialize
@local_optimizer([true_div]) @node_rewriter([true_div])
def local_grad_log_erfc_neg(fgraph, node): def local_grad_log_erfc_neg(fgraph, node):
"""Stability optimization for the grad of `log(erfc(x))`. """Stability optimization for the grad of `log(erfc(x))`.
...@@ -3093,7 +3093,7 @@ def is_neg(var): ...@@ -3093,7 +3093,7 @@ def is_neg(var):
@register_stabilize @register_stabilize
@local_optimizer([true_div]) @node_rewriter([true_div])
def local_exp_over_1_plus_exp(fgraph, node): def local_exp_over_1_plus_exp(fgraph, node):
""" """
exp(x)/(1+exp(x)) -> sigm(x) exp(x)/(1+exp(x)) -> sigm(x)
...@@ -3447,7 +3447,7 @@ def perform_sigm_times_exp( ...@@ -3447,7 +3447,7 @@ def perform_sigm_times_exp(
@register_stabilize @register_stabilize
@local_optimizer([mul]) @node_rewriter([mul])
def local_sigm_times_exp(fgraph, node): def local_sigm_times_exp(fgraph, node):
""" """
exp(x) * sigm(-x) -> sigm(x) exp(x) * sigm(-x) -> sigm(x)
...@@ -3476,7 +3476,7 @@ def local_sigm_times_exp(fgraph, node): ...@@ -3476,7 +3476,7 @@ def local_sigm_times_exp(fgraph, node):
@register_stabilize @register_stabilize
@local_optimizer([reciprocal]) @node_rewriter([reciprocal])
def local_reciprocal_1_plus_exp(fgraph, node): def local_reciprocal_1_plus_exp(fgraph, node):
"""``reciprocal(1+exp(x)) -> sigm(-x)`` """``reciprocal(1+exp(x)) -> sigm(-x)``
...@@ -3558,7 +3558,7 @@ register_specialize(local_sigmoid_logit) ...@@ -3558,7 +3558,7 @@ register_specialize(local_sigmoid_logit)
@register_canonicalize @register_canonicalize
@register_useless @register_useless
@local_optimizer([_conj]) @node_rewriter([_conj])
def local_useless_conj(fgraph, node): def local_useless_conj(fgraph, node):
r"""Remove `conj` `Op`\s applied to non-imaginary variable types.""" r"""Remove `conj` `Op`\s applied to non-imaginary variable types."""
x = node.inputs[0] x = node.inputs[0]
......
...@@ -18,7 +18,7 @@ from aesara.compile import optdb ...@@ -18,7 +18,7 @@ from aesara.compile import optdb
from aesara.gradient import DisconnectedType, grad_not_implemented from aesara.gradient import DisconnectedType, grad_not_implemented
from aesara.graph.basic import Apply from aesara.graph.basic import Apply
from aesara.graph.op import Op from aesara.graph.op import Op
from aesara.graph.opt import copy_stack_trace, local_optimizer, optimizer from aesara.graph.opt import copy_stack_trace, node_rewriter, optimizer
from aesara.link.c.op import COp from aesara.link.c.op import COp
from aesara.raise_op import Assert from aesara.raise_op import Assert
from aesara.scalar import UnaryScalarOp from aesara.scalar import UnaryScalarOp
...@@ -1046,7 +1046,7 @@ class LogSoftmax(COp): ...@@ -1046,7 +1046,7 @@ class LogSoftmax(COp):
# This is not registered in stabilize, as it cause some crossentropy # This is not registered in stabilize, as it cause some crossentropy
# optimization to not be inserted. # optimization to not be inserted.
@register_specialize("stabilize", "fast_compile") @register_specialize("stabilize", "fast_compile")
@local_optimizer([Elemwise]) @node_rewriter([Elemwise])
def local_logsoftmax(fgraph, node): def local_logsoftmax(fgraph, node):
""" """
Detect Log(Softmax(x)) and replace it with LogSoftmax(x) Detect Log(Softmax(x)) and replace it with LogSoftmax(x)
...@@ -1071,7 +1071,7 @@ def local_logsoftmax(fgraph, node): ...@@ -1071,7 +1071,7 @@ def local_logsoftmax(fgraph, node):
# This is not registered in stabilize, as it cause some crossentropy # This is not registered in stabilize, as it cause some crossentropy
# optimization to not be inserted. # optimization to not be inserted.
@register_specialize("stabilize", "fast_compile") @register_specialize("stabilize", "fast_compile")
@local_optimizer([SoftmaxGrad]) @node_rewriter([SoftmaxGrad])
def local_logsoftmax_grad(fgraph, node): def local_logsoftmax_grad(fgraph, node):
""" """
Detect Log(Softmax(x))'s grad and replace it with LogSoftmax(x)'s grad Detect Log(Softmax(x))'s grad and replace it with LogSoftmax(x)'s grad
...@@ -1150,7 +1150,7 @@ def logsoftmax(c, axis=UNSET_AXIS): ...@@ -1150,7 +1150,7 @@ def logsoftmax(c, axis=UNSET_AXIS):
@register_specialize("fast_compile") @register_specialize("fast_compile")
@local_optimizer([softmax_legacy]) @node_rewriter([softmax_legacy])
def local_softmax_with_bias(fgraph, node): def local_softmax_with_bias(fgraph, node):
""" """
Try to turn softmax(sum_of_stuff) -> softmax_w_bias(matrix, bias). Try to turn softmax(sum_of_stuff) -> softmax_w_bias(matrix, bias).
...@@ -1954,7 +1954,7 @@ optdb.register( ...@@ -1954,7 +1954,7 @@ optdb.register(
@register_specialize( @register_specialize(
"fast_compile", "local_crossentropy_to_crossentropy_with_softmax_grad" "fast_compile", "local_crossentropy_to_crossentropy_with_softmax_grad"
) # old name ) # old name
@local_optimizer([softmax_grad_legacy]) @node_rewriter([softmax_grad_legacy])
def local_softmax_grad_to_crossentropy_with_softmax_grad(fgraph, node): def local_softmax_grad_to_crossentropy_with_softmax_grad(fgraph, node):
if node.op == softmax_grad_legacy and node.inputs[1].ndim == 2: if node.op == softmax_grad_legacy and node.inputs[1].ndim == 2:
g_coding_dist, coding_dist = node.inputs g_coding_dist, coding_dist = node.inputs
...@@ -1971,7 +1971,7 @@ def local_softmax_grad_to_crossentropy_with_softmax_grad(fgraph, node): ...@@ -1971,7 +1971,7 @@ def local_softmax_grad_to_crossentropy_with_softmax_grad(fgraph, node):
@register_specialize("fast_compile") @register_specialize("fast_compile")
@local_optimizer([MaxAndArgmax]) @node_rewriter([MaxAndArgmax])
def local_argmax_pushdown(fgraph, node): def local_argmax_pushdown(fgraph, node):
if ( if (
isinstance(node.op, MaxAndArgmax) isinstance(node.op, MaxAndArgmax)
...@@ -2060,7 +2060,7 @@ def _is_const(z, val, approx=False): ...@@ -2060,7 +2060,7 @@ def _is_const(z, val, approx=False):
@register_specialize("fast_compile") @register_specialize("fast_compile")
@local_optimizer([AdvancedSubtensor, log]) @node_rewriter([AdvancedSubtensor, log])
def local_advanced_indexing_crossentropy_onehot(fgraph, node): def local_advanced_indexing_crossentropy_onehot(fgraph, node):
log_op = None log_op = None
sm = None sm = None
...@@ -2108,7 +2108,7 @@ def local_advanced_indexing_crossentropy_onehot(fgraph, node): ...@@ -2108,7 +2108,7 @@ def local_advanced_indexing_crossentropy_onehot(fgraph, node):
@register_specialize("fast_compile") @register_specialize("fast_compile")
@local_optimizer([softmax_grad_legacy]) @node_rewriter([softmax_grad_legacy])
def local_advanced_indexing_crossentropy_onehot_grad(fgraph, node): def local_advanced_indexing_crossentropy_onehot_grad(fgraph, node):
if not (node.op == softmax_grad_legacy and node.inputs[1].ndim == 2): if not (node.op == softmax_grad_legacy and node.inputs[1].ndim == 2):
return return
...@@ -2323,7 +2323,7 @@ def local_advanced_indexing_crossentropy_onehot_grad(fgraph, node): ...@@ -2323,7 +2323,7 @@ def local_advanced_indexing_crossentropy_onehot_grad(fgraph, node):
@register_specialize("fast_compile") @register_specialize("fast_compile")
@local_optimizer([softmax_with_bias]) @node_rewriter([softmax_with_bias])
def graph_merge_softmax_with_crossentropy_softmax(fgraph, node): def graph_merge_softmax_with_crossentropy_softmax(fgraph, node):
if node.op == softmax_with_bias: if node.op == softmax_with_bias:
x, b = node.inputs x, b = node.inputs
...@@ -2340,7 +2340,7 @@ def graph_merge_softmax_with_crossentropy_softmax(fgraph, node): ...@@ -2340,7 +2340,7 @@ def graph_merge_softmax_with_crossentropy_softmax(fgraph, node):
@register_specialize @register_specialize
@register_stabilize @register_stabilize
@register_canonicalize @register_canonicalize
@local_optimizer([CrossentropySoftmax1HotWithBiasDx]) @node_rewriter([CrossentropySoftmax1HotWithBiasDx])
def local_useless_crossentropy_softmax_1hot_with_bias_dx_alloc(fgraph, node): def local_useless_crossentropy_softmax_1hot_with_bias_dx_alloc(fgraph, node):
""" """
Replace a CrossentropySoftmax1HotWithBiasDx op, whose incoming gradient is Replace a CrossentropySoftmax1HotWithBiasDx op, whose incoming gradient is
......
...@@ -4,7 +4,7 @@ import aesara ...@@ -4,7 +4,7 @@ import aesara
from aesara.configdefaults import config from aesara.configdefaults import config
from aesara.graph.basic import Apply from aesara.graph.basic import Apply
from aesara.graph.op import Op from aesara.graph.op import Op
from aesara.graph.opt import copy_stack_trace, local_optimizer from aesara.graph.opt import copy_stack_trace, node_rewriter
from aesara.scalar import Composite, add, as_common_dtype, mul, sub, true_div from aesara.scalar import Composite, add, as_common_dtype, mul, sub, true_div
from aesara.tensor import basic as at from aesara.tensor import basic as at
from aesara.tensor.basic import as_tensor_variable from aesara.tensor.basic import as_tensor_variable
...@@ -778,7 +778,7 @@ class AbstractBatchNormTrainGrad(Op): ...@@ -778,7 +778,7 @@ class AbstractBatchNormTrainGrad(Op):
output_storage[2][0] = g_wrt_bias output_storage[2][0] = g_wrt_bias
@local_optimizer([AbstractBatchNormTrain]) @node_rewriter([AbstractBatchNormTrain])
def local_abstract_batch_norm_train(fgraph, node): def local_abstract_batch_norm_train(fgraph, node):
if not isinstance(node.op, AbstractBatchNormTrain): if not isinstance(node.op, AbstractBatchNormTrain):
return None return None
...@@ -832,7 +832,7 @@ def local_abstract_batch_norm_train(fgraph, node): ...@@ -832,7 +832,7 @@ def local_abstract_batch_norm_train(fgraph, node):
return results return results
@local_optimizer([AbstractBatchNormTrainGrad]) @node_rewriter([AbstractBatchNormTrainGrad])
def local_abstract_batch_norm_train_grad(fgraph, node): def local_abstract_batch_norm_train_grad(fgraph, node):
if not isinstance(node.op, AbstractBatchNormTrainGrad): if not isinstance(node.op, AbstractBatchNormTrainGrad):
return None return None
...@@ -866,7 +866,7 @@ def local_abstract_batch_norm_train_grad(fgraph, node): ...@@ -866,7 +866,7 @@ def local_abstract_batch_norm_train_grad(fgraph, node):
return results return results
@local_optimizer([AbstractBatchNormInference]) @node_rewriter([AbstractBatchNormInference])
def local_abstract_batch_norm_inference(fgraph, node): def local_abstract_batch_norm_inference(fgraph, node):
if not isinstance(node.op, AbstractBatchNormInference): if not isinstance(node.op, AbstractBatchNormInference):
return None return None
......
...@@ -3,7 +3,7 @@ from aesara import tensor as at ...@@ -3,7 +3,7 @@ from aesara import tensor as at
from aesara.gradient import DisconnectedType from aesara.gradient import DisconnectedType
from aesara.graph.basic import Apply from aesara.graph.basic import Apply
from aesara.graph.op import Op from aesara.graph.op import Op
from aesara.graph.opt import TopoOptimizer, copy_stack_trace, local_optimizer from aesara.graph.opt import TopoOptimizer, copy_stack_trace, node_rewriter
def get_diagonal_subtensor_view(x, i0, i1): def get_diagonal_subtensor_view(x, i0, i1):
...@@ -296,7 +296,7 @@ def conv3d( ...@@ -296,7 +296,7 @@ def conv3d(
return out_5d return out_5d
@local_optimizer([DiagonalSubtensor, IncDiagonalSubtensor]) @node_rewriter([DiagonalSubtensor, IncDiagonalSubtensor])
def local_inplace_DiagonalSubtensor(fgraph, node): def local_inplace_DiagonalSubtensor(fgraph, node):
"""Also work for IncDiagonalSubtensor.""" """Also work for IncDiagonalSubtensor."""
if ( if (
......
...@@ -5,7 +5,7 @@ import aesara.tensor as at ...@@ -5,7 +5,7 @@ import aesara.tensor as at
from aesara.configdefaults import config from aesara.configdefaults import config
from aesara.gradient import grad_undefined from aesara.gradient import grad_undefined
from aesara.graph.basic import Apply from aesara.graph.basic import Apply
from aesara.graph.opt import local_optimizer from aesara.graph.opt import node_rewriter
from aesara.link.c.cmodule import GCC_compiler from aesara.link.c.cmodule import GCC_compiler
from aesara.link.c.op import ExternalCOp, OpenMPOp from aesara.link.c.op import ExternalCOp, OpenMPOp
from aesara.tensor.basic_opt import register_canonicalize from aesara.tensor.basic_opt import register_canonicalize
...@@ -249,7 +249,7 @@ def ctc(activations, labels, input_lengths): ...@@ -249,7 +249,7 @@ def ctc(activations, labels, input_lengths):
# Disable gradient computation if not needed # Disable gradient computation if not needed
@register_canonicalize("fast_compile") @register_canonicalize("fast_compile")
@local_optimizer([ConnectionistTemporalClassification]) @node_rewriter([ConnectionistTemporalClassification])
def local_ctc_no_grad(fgraph, node): def local_ctc_no_grad(fgraph, node):
if isinstance(node.op, ConnectionistTemporalClassification): if isinstance(node.op, ConnectionistTemporalClassification):
if len(node.outputs) > 1: if len(node.outputs) > 1:
......
...@@ -11,7 +11,7 @@ from aesara.graph.opt import ( ...@@ -11,7 +11,7 @@ from aesara.graph.opt import (
TopoOptimizer, TopoOptimizer,
copy_stack_trace, copy_stack_trace,
in2out, in2out,
local_optimizer, node_rewriter,
) )
from aesara.tensor.basic_opt import register_specialize_device from aesara.tensor.basic_opt import register_specialize_device
from aesara.tensor.nnet.abstract_conv import ( from aesara.tensor.nnet.abstract_conv import (
...@@ -37,7 +37,7 @@ from aesara.tensor.nnet.corr3d import Corr3dMM, Corr3dMMGradInputs, Corr3dMMGrad ...@@ -37,7 +37,7 @@ from aesara.tensor.nnet.corr3d import Corr3dMM, Corr3dMMGradInputs, Corr3dMMGrad
from aesara.tensor.type import TensorType from aesara.tensor.type import TensorType
@local_optimizer([SparseBlockGemv], inplace=True) @node_rewriter([SparseBlockGemv], inplace=True)
def local_inplace_sparse_block_gemv(fgraph, node): def local_inplace_sparse_block_gemv(fgraph, node):
""" """
SparseBlockGemv(inplace=False) -> SparseBlockGemv(inplace=True) SparseBlockGemv(inplace=False) -> SparseBlockGemv(inplace=True)
...@@ -60,7 +60,7 @@ compile.optdb.register( ...@@ -60,7 +60,7 @@ compile.optdb.register(
) # DEBUG ) # DEBUG
@local_optimizer([SparseBlockOuter], inplace=True) @node_rewriter([SparseBlockOuter], inplace=True)
def local_inplace_sparse_block_outer(fgraph, node): def local_inplace_sparse_block_outer(fgraph, node):
""" """
SparseBlockOuter(inplace=False) -> SparseBlockOuter(inplace=True) SparseBlockOuter(inplace=False) -> SparseBlockOuter(inplace=True)
...@@ -85,7 +85,7 @@ compile.optdb.register( ...@@ -85,7 +85,7 @@ compile.optdb.register(
# Conv opts # Conv opts
@local_optimizer([AbstractConv2d]) @node_rewriter([AbstractConv2d])
def local_abstractconv_gemm(fgraph, node): def local_abstractconv_gemm(fgraph, node):
# If config.blas__ldflags is empty, Aesara will use # If config.blas__ldflags is empty, Aesara will use
# a NumPy C implementation of [sd]gemm_. # a NumPy C implementation of [sd]gemm_.
...@@ -113,7 +113,7 @@ def local_abstractconv_gemm(fgraph, node): ...@@ -113,7 +113,7 @@ def local_abstractconv_gemm(fgraph, node):
return [rval] return [rval]
@local_optimizer([AbstractConv3d]) @node_rewriter([AbstractConv3d])
def local_abstractconv3d_gemm(fgraph, node): def local_abstractconv3d_gemm(fgraph, node):
# If config.blas__ldflags is empty, Aesara will use # If config.blas__ldflags is empty, Aesara will use
# a NumPy C implementation of [sd]gemm_. # a NumPy C implementation of [sd]gemm_.
...@@ -139,7 +139,7 @@ def local_abstractconv3d_gemm(fgraph, node): ...@@ -139,7 +139,7 @@ def local_abstractconv3d_gemm(fgraph, node):
return [rval] return [rval]
@local_optimizer([AbstractConv2d_gradWeights]) @node_rewriter([AbstractConv2d_gradWeights])
def local_abstractconv_gradweight_gemm(fgraph, node): def local_abstractconv_gradweight_gemm(fgraph, node):
# If config.blas__ldflags is empty, Aesara will use # If config.blas__ldflags is empty, Aesara will use
# a NumPy C implementation of [sd]gemm_. # a NumPy C implementation of [sd]gemm_.
...@@ -169,7 +169,7 @@ def local_abstractconv_gradweight_gemm(fgraph, node): ...@@ -169,7 +169,7 @@ def local_abstractconv_gradweight_gemm(fgraph, node):
return [rval] return [rval]
@local_optimizer([AbstractConv3d_gradWeights]) @node_rewriter([AbstractConv3d_gradWeights])
def local_abstractconv3d_gradweight_gemm(fgraph, node): def local_abstractconv3d_gradweight_gemm(fgraph, node):
# If config.blas__ldflags is empty, Aesara will use # If config.blas__ldflags is empty, Aesara will use
# a NumPy C implementation of [sd]gemm_. # a NumPy C implementation of [sd]gemm_.
...@@ -197,7 +197,7 @@ def local_abstractconv3d_gradweight_gemm(fgraph, node): ...@@ -197,7 +197,7 @@ def local_abstractconv3d_gradweight_gemm(fgraph, node):
return [rval] return [rval]
@local_optimizer([AbstractConv2d_gradInputs]) @node_rewriter([AbstractConv2d_gradInputs])
def local_abstractconv_gradinputs_gemm(fgraph, node): def local_abstractconv_gradinputs_gemm(fgraph, node):
# If config.blas__ldflags is empty, Aesara will use # If config.blas__ldflags is empty, Aesara will use
# a NumPy C implementation of [sd]gemm_. # a NumPy C implementation of [sd]gemm_.
...@@ -227,7 +227,7 @@ def local_abstractconv_gradinputs_gemm(fgraph, node): ...@@ -227,7 +227,7 @@ def local_abstractconv_gradinputs_gemm(fgraph, node):
return [rval] return [rval]
@local_optimizer([AbstractConv3d_gradInputs]) @node_rewriter([AbstractConv3d_gradInputs])
def local_abstractconv3d_gradinputs_gemm(fgraph, node): def local_abstractconv3d_gradinputs_gemm(fgraph, node):
# If config.blas__ldflags is empty, Aesara will use # If config.blas__ldflags is empty, Aesara will use
# a NumPy C implementation of [sd]gemm_. # a NumPy C implementation of [sd]gemm_.
...@@ -255,7 +255,7 @@ def local_abstractconv3d_gradinputs_gemm(fgraph, node): ...@@ -255,7 +255,7 @@ def local_abstractconv3d_gradinputs_gemm(fgraph, node):
return [rval] return [rval]
@local_optimizer([AbstractConv2d]) @node_rewriter([AbstractConv2d])
def local_conv2d_cpu(fgraph, node): def local_conv2d_cpu(fgraph, node):
if not isinstance(node.op, AbstractConv2d) or node.inputs[0].dtype == "float16": if not isinstance(node.op, AbstractConv2d) or node.inputs[0].dtype == "float16":
...@@ -287,7 +287,7 @@ def local_conv2d_cpu(fgraph, node): ...@@ -287,7 +287,7 @@ def local_conv2d_cpu(fgraph, node):
return [rval] return [rval]
@local_optimizer([AbstractConv2d_gradWeights]) @node_rewriter([AbstractConv2d_gradWeights])
def local_conv2d_gradweight_cpu(fgraph, node): def local_conv2d_gradweight_cpu(fgraph, node):
if ( if (
not isinstance(node.op, AbstractConv2d_gradWeights) not isinstance(node.op, AbstractConv2d_gradWeights)
...@@ -396,7 +396,7 @@ def local_conv2d_gradweight_cpu(fgraph, node): ...@@ -396,7 +396,7 @@ def local_conv2d_gradweight_cpu(fgraph, node):
return [res] return [res]
@local_optimizer([AbstractConv2d_gradInputs]) @node_rewriter([AbstractConv2d_gradInputs])
def local_conv2d_gradinputs_cpu(fgraph, node): def local_conv2d_gradinputs_cpu(fgraph, node):
if ( if (
not isinstance(node.op, AbstractConv2d_gradInputs) not isinstance(node.op, AbstractConv2d_gradInputs)
...@@ -561,7 +561,7 @@ conv_groupopt.register( ...@@ -561,7 +561,7 @@ conv_groupopt.register(
# Verify that no AbstractConv are present in the graph # Verify that no AbstractConv are present in the graph
@local_optimizer( @node_rewriter(
[ [
AbstractConv2d, AbstractConv2d,
AbstractConv2d_gradWeights, AbstractConv2d_gradWeights,
......
...@@ -9,7 +9,7 @@ stability. ...@@ -9,7 +9,7 @@ stability.
import aesara import aesara
from aesara import printing from aesara import printing
from aesara import scalar as aes from aesara import scalar as aes
from aesara.graph.opt import copy_stack_trace, local_optimizer from aesara.graph.opt import copy_stack_trace, node_rewriter
from aesara.printing import pprint from aesara.printing import pprint
from aesara.scalar import sigmoid as scalar_sigmoid from aesara.scalar import sigmoid as scalar_sigmoid
from aesara.scalar.math import Sigmoid from aesara.scalar.math import Sigmoid
...@@ -99,7 +99,7 @@ pprint.assign(ultra_fast_sigmoid, printing.FunctionPrinter(["ultra_fast_sigmoid" ...@@ -99,7 +99,7 @@ pprint.assign(ultra_fast_sigmoid, printing.FunctionPrinter(["ultra_fast_sigmoid"
# @opt.register_uncanonicalize # @opt.register_uncanonicalize
@local_optimizer(None) @node_rewriter(None)
def local_ultra_fast_sigmoid(fgraph, node): def local_ultra_fast_sigmoid(fgraph, node):
""" """
When enabled, change all sigmoid to ultra_fast_sigmoid. When enabled, change all sigmoid to ultra_fast_sigmoid.
...@@ -159,7 +159,7 @@ def hard_sigmoid(x): ...@@ -159,7 +159,7 @@ def hard_sigmoid(x):
# @opt.register_uncanonicalize # @opt.register_uncanonicalize
@local_optimizer([sigmoid]) @node_rewriter([sigmoid])
def local_hard_sigmoid(fgraph, node): def local_hard_sigmoid(fgraph, node):
if isinstance(node.op, Elemwise) and node.op.scalar_op == scalar_sigmoid: if isinstance(node.op, Elemwise) and node.op.scalar_op == scalar_sigmoid:
out = hard_sigmoid(node.inputs[0]) out = hard_sigmoid(node.inputs[0])
......
...@@ -34,7 +34,7 @@ supposed to be canonical. ...@@ -34,7 +34,7 @@ supposed to be canonical.
import logging import logging
from aesara import scalar as aes from aesara import scalar as aes
from aesara.graph.opt import copy_stack_trace, local_optimizer from aesara.graph.opt import copy_stack_trace, node_rewriter
from aesara.tensor.basic import Alloc, alloc, constant from aesara.tensor.basic import Alloc, alloc, constant
from aesara.tensor.basic_opt import register_uncanonicalize from aesara.tensor.basic_opt import register_uncanonicalize
from aesara.tensor.elemwise import CAReduce, DimShuffle from aesara.tensor.elemwise import CAReduce, DimShuffle
...@@ -47,7 +47,7 @@ _logger = logging.getLogger("aesara.tensor.opt_uncanonicalize") ...@@ -47,7 +47,7 @@ _logger = logging.getLogger("aesara.tensor.opt_uncanonicalize")
@register_uncanonicalize @register_uncanonicalize
@local_optimizer([MaxAndArgmax]) @node_rewriter([MaxAndArgmax])
def local_max_and_argmax(fgraph, node): def local_max_and_argmax(fgraph, node):
""" """
If we don't use the argmax, change it to a max only. If we don't use the argmax, change it to a max only.
...@@ -66,7 +66,7 @@ def local_max_and_argmax(fgraph, node): ...@@ -66,7 +66,7 @@ def local_max_and_argmax(fgraph, node):
@register_uncanonicalize @register_uncanonicalize
@local_optimizer([neg]) @node_rewriter([neg])
def local_max_to_min(fgraph, node): def local_max_to_min(fgraph, node):
""" """
Change -(max(-x)) to min. Change -(max(-x)) to min.
...@@ -95,7 +95,7 @@ def local_max_to_min(fgraph, node): ...@@ -95,7 +95,7 @@ def local_max_to_min(fgraph, node):
@register_uncanonicalize @register_uncanonicalize
@local_optimizer([Alloc]) @node_rewriter([Alloc])
def local_alloc_dimshuffle(fgraph, node): def local_alloc_dimshuffle(fgraph, node):
""" """
If a dimshuffle is inside an alloc and only adds dimension to the If a dimshuffle is inside an alloc and only adds dimension to the
...@@ -118,7 +118,7 @@ def local_alloc_dimshuffle(fgraph, node): ...@@ -118,7 +118,7 @@ def local_alloc_dimshuffle(fgraph, node):
@register_uncanonicalize @register_uncanonicalize
@local_optimizer([Reshape]) @node_rewriter([Reshape])
def local_reshape_dimshuffle(fgraph, node): def local_reshape_dimshuffle(fgraph, node):
""" """
If a dimshuffle is inside a reshape and does not change the order If a dimshuffle is inside a reshape and does not change the order
...@@ -147,7 +147,7 @@ def local_reshape_dimshuffle(fgraph, node): ...@@ -147,7 +147,7 @@ def local_reshape_dimshuffle(fgraph, node):
@register_uncanonicalize @register_uncanonicalize
@local_optimizer([DimShuffle]) @node_rewriter([DimShuffle])
def local_dimshuffle_alloc(fgraph, node): def local_dimshuffle_alloc(fgraph, node):
""" """
If an alloc is inside a dimshuffle which only adds dimension to the left, If an alloc is inside a dimshuffle which only adds dimension to the left,
...@@ -175,7 +175,7 @@ def local_dimshuffle_alloc(fgraph, node): ...@@ -175,7 +175,7 @@ def local_dimshuffle_alloc(fgraph, node):
@register_uncanonicalize @register_uncanonicalize
@local_optimizer([DimShuffle]) @node_rewriter([DimShuffle])
def local_dimshuffle_subtensor(fgraph, node): def local_dimshuffle_subtensor(fgraph, node):
"""If a subtensor is inside a dimshuffle which only drop """If a subtensor is inside a dimshuffle which only drop
broadcastable dimensions, scrap the dimshuffle and index the broadcastable dimensions, scrap the dimshuffle and index the
......
from aesara.compile import optdb from aesara.compile import optdb
from aesara.configdefaults import config from aesara.configdefaults import config
from aesara.graph.op import compute_test_value from aesara.graph.op import compute_test_value
from aesara.graph.opt import in2out, local_optimizer from aesara.graph.opt import in2out, node_rewriter
from aesara.tensor.basic import constant, get_vector_length from aesara.tensor.basic import constant, get_vector_length
from aesara.tensor.elemwise import DimShuffle from aesara.tensor.elemwise import DimShuffle
from aesara.tensor.extra_ops import broadcast_to from aesara.tensor.extra_ops import broadcast_to
...@@ -39,7 +39,7 @@ def is_rv_used_in_graph(base_rv, node, fgraph): ...@@ -39,7 +39,7 @@ def is_rv_used_in_graph(base_rv, node, fgraph):
return not all(_node_check(n, i) for n, i in fgraph.clients.get(base_rv, ())) return not all(_node_check(n, i) for n, i in fgraph.clients.get(base_rv, ()))
@local_optimizer([RandomVariable], inplace=True) @node_rewriter([RandomVariable], inplace=True)
def random_make_inplace(fgraph, node): def random_make_inplace(fgraph, node):
op = node.op op = node.op
...@@ -61,7 +61,7 @@ optdb.register( ...@@ -61,7 +61,7 @@ optdb.register(
) )
@local_optimizer(tracks=None) @node_rewriter(tracks=None)
def local_rv_size_lift(fgraph, node): def local_rv_size_lift(fgraph, node):
"""Lift the ``size`` parameter in a ``RandomVariable``. """Lift the ``size`` parameter in a ``RandomVariable``.
...@@ -109,7 +109,7 @@ def local_rv_size_lift(fgraph, node): ...@@ -109,7 +109,7 @@ def local_rv_size_lift(fgraph, node):
return new_node.outputs return new_node.outputs
@local_optimizer([DimShuffle]) @node_rewriter([DimShuffle])
def local_dimshuffle_rv_lift(fgraph, node): def local_dimshuffle_rv_lift(fgraph, node):
"""Lift a ``DimShuffle`` through ``RandomVariable`` inputs. """Lift a ``DimShuffle`` through ``RandomVariable`` inputs.
...@@ -266,7 +266,7 @@ def local_dimshuffle_rv_lift(fgraph, node): ...@@ -266,7 +266,7 @@ def local_dimshuffle_rv_lift(fgraph, node):
return False return False
@local_optimizer([Subtensor, AdvancedSubtensor1, AdvancedSubtensor]) @node_rewriter([Subtensor, AdvancedSubtensor1, AdvancedSubtensor])
def local_subtensor_rv_lift(fgraph, node): def local_subtensor_rv_lift(fgraph, node):
"""Lift a ``*Subtensor`` through ``RandomVariable`` inputs. """Lift a ``*Subtensor`` through ``RandomVariable`` inputs.
......
...@@ -7,7 +7,7 @@ import aesara ...@@ -7,7 +7,7 @@ import aesara
import aesara.scalar.basic as aes import aesara.scalar.basic as aes
from aesara import compile from aesara import compile
from aesara.graph.basic import Constant, Variable from aesara.graph.basic import Constant, Variable
from aesara.graph.opt import TopoOptimizer, copy_stack_trace, in2out, local_optimizer from aesara.graph.opt import TopoOptimizer, copy_stack_trace, in2out, node_rewriter
from aesara.raise_op import Assert from aesara.raise_op import Assert
from aesara.tensor.basic import ( from aesara.tensor.basic import (
Alloc, Alloc,
...@@ -202,7 +202,7 @@ def get_advsubtensor_axis(indices): ...@@ -202,7 +202,7 @@ def get_advsubtensor_axis(indices):
@register_specialize @register_specialize
@local_optimizer([AdvancedSubtensor]) @node_rewriter([AdvancedSubtensor])
def local_replace_AdvancedSubtensor(fgraph, node): def local_replace_AdvancedSubtensor(fgraph, node):
r""" r"""
This rewrite converts expressions like ``X[..., y]`` into ``X.T[y].T``, for This rewrite converts expressions like ``X[..., y]`` into ``X.T[y].T``, for
...@@ -231,7 +231,7 @@ def local_replace_AdvancedSubtensor(fgraph, node): ...@@ -231,7 +231,7 @@ def local_replace_AdvancedSubtensor(fgraph, node):
@register_specialize @register_specialize
@local_optimizer([AdvancedIncSubtensor]) @node_rewriter([AdvancedIncSubtensor])
def local_AdvancedIncSubtensor_to_AdvancedIncSubtensor1(fgraph, node): def local_AdvancedIncSubtensor_to_AdvancedIncSubtensor1(fgraph, node):
r"""Replace `AdvancedIncSubtensor`\s with `AdvancedIncSubtensor1`\s. r"""Replace `AdvancedIncSubtensor`\s with `AdvancedIncSubtensor1`\s.
...@@ -268,7 +268,7 @@ def local_AdvancedIncSubtensor_to_AdvancedIncSubtensor1(fgraph, node): ...@@ -268,7 +268,7 @@ def local_AdvancedIncSubtensor_to_AdvancedIncSubtensor1(fgraph, node):
@register_canonicalize @register_canonicalize
@register_stabilize @register_stabilize
@register_specialize @register_specialize
@local_optimizer([Subtensor]) @node_rewriter([Subtensor])
def local_subtensor_of_dot(fgraph, node): def local_subtensor_of_dot(fgraph, node):
"""Rewrite ``at.dot(A, B)[idxs]`` into ``at.dot(A[idxs_a], B[idxs_b])``. """Rewrite ``at.dot(A, B)[idxs]`` into ``at.dot(A[idxs_a], B[idxs_b])``.
``idxs_a`` is the first ``A.ndim-1`` entries of ``idxs``, and ``idxs_b`` is ``idxs_a`` is the first ``A.ndim-1`` entries of ``idxs``, and ``idxs_b`` is
...@@ -326,7 +326,7 @@ def local_subtensor_of_dot(fgraph, node): ...@@ -326,7 +326,7 @@ def local_subtensor_of_dot(fgraph, node):
@register_useless @register_useless
@register_canonicalize @register_canonicalize
@register_specialize @register_specialize
@local_optimizer([Subtensor]) @node_rewriter([Subtensor])
def local_useless_slice(fgraph, node): def local_useless_slice(fgraph, node):
""" """
Remove Subtensor of the form X[0, :] -> X[0] Remove Subtensor of the form X[0, :] -> X[0]
...@@ -362,7 +362,7 @@ def local_useless_slice(fgraph, node): ...@@ -362,7 +362,7 @@ def local_useless_slice(fgraph, node):
# fast_compile to allow opt subtensor(cast{float32}(make_vector)) # fast_compile to allow opt subtensor(cast{float32}(make_vector))
@register_canonicalize("fast_compile") @register_canonicalize("fast_compile")
@local_optimizer([Subtensor]) @node_rewriter([Subtensor])
def local_subtensor_lift(fgraph, node): def local_subtensor_lift(fgraph, node):
""" """
unary(x)[idx] -> unary(x[idx])#any broadcast pattern. unary(x)[idx] -> unary(x[idx])#any broadcast pattern.
...@@ -466,7 +466,7 @@ def local_subtensor_lift(fgraph, node): ...@@ -466,7 +466,7 @@ def local_subtensor_lift(fgraph, node):
@register_canonicalize @register_canonicalize
@register_specialize @register_specialize
@local_optimizer([Subtensor]) @node_rewriter([Subtensor])
def local_subtensor_merge(fgraph, node): def local_subtensor_merge(fgraph, node):
""" """
Refactored optimization to deal with all cases of tensor merging. Refactored optimization to deal with all cases of tensor merging.
...@@ -537,7 +537,7 @@ def local_subtensor_merge(fgraph, node): ...@@ -537,7 +537,7 @@ def local_subtensor_merge(fgraph, node):
@register_specialize @register_specialize
@register_canonicalize @register_canonicalize
@local_optimizer([Subtensor]) @node_rewriter([Subtensor])
def local_subtensor_remove_broadcastable_index(fgraph, node): def local_subtensor_remove_broadcastable_index(fgraph, node):
""" """
Remove broadcastable dimension with index 0 or -1 Remove broadcastable dimension with index 0 or -1
...@@ -586,7 +586,7 @@ def local_subtensor_remove_broadcastable_index(fgraph, node): ...@@ -586,7 +586,7 @@ def local_subtensor_remove_broadcastable_index(fgraph, node):
@register_useless @register_useless
@register_canonicalize @register_canonicalize
@register_specialize @register_specialize
@local_optimizer([Subtensor]) @node_rewriter([Subtensor])
def local_subtensor_of_alloc(fgraph, node): def local_subtensor_of_alloc(fgraph, node):
""" """
...@@ -654,7 +654,7 @@ def local_subtensor_of_alloc(fgraph, node): ...@@ -654,7 +654,7 @@ def local_subtensor_of_alloc(fgraph, node):
@register_specialize @register_specialize
@register_canonicalize @register_canonicalize
@local_optimizer([Subtensor]) @node_rewriter([Subtensor])
def local_subtensor_inc_subtensor(fgraph, node): def local_subtensor_inc_subtensor(fgraph, node):
""" """
Subtensor(SetSubtensor(x, y, idx), idx) -> y Subtensor(SetSubtensor(x, y, idx), idx) -> y
...@@ -694,7 +694,7 @@ def local_subtensor_inc_subtensor(fgraph, node): ...@@ -694,7 +694,7 @@ def local_subtensor_inc_subtensor(fgraph, node):
@register_specialize @register_specialize
@register_canonicalize("fast_compile") @register_canonicalize("fast_compile")
@register_useless @register_useless
@local_optimizer([Subtensor, AdvancedSubtensor1]) @node_rewriter([Subtensor, AdvancedSubtensor1])
def local_subtensor_make_vector(fgraph, node): def local_subtensor_make_vector(fgraph, node):
"""Perform ``*Subtensor*`` operations on ``MakeVector`` outputs when the indices are constant. """Perform ``*Subtensor*`` operations on ``MakeVector`` outputs when the indices are constant.
...@@ -770,7 +770,7 @@ def local_subtensor_make_vector(fgraph, node): ...@@ -770,7 +770,7 @@ def local_subtensor_make_vector(fgraph, node):
@register_useless @register_useless
@register_canonicalize @register_canonicalize
@register_specialize @register_specialize
@local_optimizer([IncSubtensor]) @node_rewriter([IncSubtensor])
def local_useless_inc_subtensor(fgraph, node): def local_useless_inc_subtensor(fgraph, node):
r"""Remove redundant `IncSubtensor`\s. r"""Remove redundant `IncSubtensor`\s.
...@@ -834,7 +834,7 @@ def local_useless_inc_subtensor(fgraph, node): ...@@ -834,7 +834,7 @@ def local_useless_inc_subtensor(fgraph, node):
@register_canonicalize @register_canonicalize
@register_specialize @register_specialize
@local_optimizer([AdvancedIncSubtensor1]) @node_rewriter([AdvancedIncSubtensor1])
def local_set_to_inc_subtensor(fgraph, node): def local_set_to_inc_subtensor(fgraph, node):
r""" r"""
AdvancedIncSubtensor1(x, x[ilist]+other, ilist, set_instead_of_inc=True) -> AdvancedIncSubtensor1(x, x[ilist]+other, ilist, set_instead_of_inc=True) ->
...@@ -878,7 +878,7 @@ def local_set_to_inc_subtensor(fgraph, node): ...@@ -878,7 +878,7 @@ def local_set_to_inc_subtensor(fgraph, node):
@register_canonicalize @register_canonicalize
@register_specialize @register_specialize
@local_optimizer([Subtensor]) @node_rewriter([Subtensor])
def local_useless_subtensor(fgraph, node): def local_useless_subtensor(fgraph, node):
"""Remove `Subtensor` if it takes the full input.""" """Remove `Subtensor` if it takes the full input."""
# This optimization needs ShapeOpt and fgraph.shape_feature # This optimization needs ShapeOpt and fgraph.shape_feature
...@@ -960,7 +960,7 @@ def local_useless_subtensor(fgraph, node): ...@@ -960,7 +960,7 @@ def local_useless_subtensor(fgraph, node):
@register_canonicalize @register_canonicalize
@register_specialize @register_specialize
@local_optimizer([AdvancedSubtensor1]) @node_rewriter([AdvancedSubtensor1])
def local_useless_AdvancedSubtensor1(fgraph, node): def local_useless_AdvancedSubtensor1(fgraph, node):
"""Remove `AdvancedSubtensor1` if it takes the full input. """Remove `AdvancedSubtensor1` if it takes the full input.
...@@ -1116,7 +1116,7 @@ def merge_two_slices(fgraph, slice1, len1, slice2, len2): ...@@ -1116,7 +1116,7 @@ def merge_two_slices(fgraph, slice1, len1, slice2, len2):
@register_canonicalize @register_canonicalize
@local_optimizer([add]) @node_rewriter([add])
def local_IncSubtensor_serialize(fgraph, node): def local_IncSubtensor_serialize(fgraph, node):
""" """
When using Subtensor, gradient graphs can be ugly. When using Subtensor, gradient graphs can be ugly.
...@@ -1216,7 +1216,7 @@ compile.optdb.register( ...@@ -1216,7 +1216,7 @@ compile.optdb.register(
# gemm is the first one now, at priority 70 # gemm is the first one now, at priority 70
@local_optimizer([IncSubtensor], inplace=True) @node_rewriter([IncSubtensor], inplace=True)
def local_inplace_setsubtensor(fgraph, node): def local_inplace_setsubtensor(fgraph, node):
if isinstance(node.op, IncSubtensor) and not node.op.inplace: if isinstance(node.op, IncSubtensor) and not node.op.inplace:
dta = node.op.destroyhandler_tolerate_aliased dta = node.op.destroyhandler_tolerate_aliased
...@@ -1249,7 +1249,7 @@ compile.optdb.register( ...@@ -1249,7 +1249,7 @@ compile.optdb.register(
) )
@local_optimizer([AdvancedIncSubtensor1], inplace=True) @node_rewriter([AdvancedIncSubtensor1], inplace=True)
def local_inplace_AdvancedIncSubtensor1(fgraph, node): def local_inplace_AdvancedIncSubtensor1(fgraph, node):
if isinstance(node.op, AdvancedIncSubtensor1) and not node.op.inplace: if isinstance(node.op, AdvancedIncSubtensor1) and not node.op.inplace:
new_op = node.op.clone_inplace() new_op = node.op.clone_inplace()
...@@ -1270,7 +1270,7 @@ compile.optdb.register( ...@@ -1270,7 +1270,7 @@ compile.optdb.register(
) )
@local_optimizer([AdvancedIncSubtensor], inplace=True) @node_rewriter([AdvancedIncSubtensor], inplace=True)
def local_inplace_AdvancedIncSubtensor(fgraph, node): def local_inplace_AdvancedIncSubtensor(fgraph, node):
if isinstance(node.op, AdvancedIncSubtensor) and not node.op.inplace: if isinstance(node.op, AdvancedIncSubtensor) and not node.op.inplace:
new_op = type(node.op)( new_op = type(node.op)(
...@@ -1298,7 +1298,7 @@ compile.optdb.register( ...@@ -1298,7 +1298,7 @@ compile.optdb.register(
# Register old name # Register old name
@register_canonicalize("local_incsubtensor_of_allocs") @register_canonicalize("local_incsubtensor_of_allocs")
@register_stabilize("local_incsubtensor_of_allocs") @register_stabilize("local_incsubtensor_of_allocs")
@local_optimizer([IncSubtensor, AdvancedIncSubtensor, AdvancedIncSubtensor1]) @node_rewriter([IncSubtensor, AdvancedIncSubtensor, AdvancedIncSubtensor1])
def local_incsubtensor_of_zeros(fgraph, node): def local_incsubtensor_of_zeros(fgraph, node):
""" """
IncSubtensor(x, zeros, idx) -> x IncSubtensor(x, zeros, idx) -> x
...@@ -1323,7 +1323,7 @@ def local_incsubtensor_of_zeros(fgraph, node): ...@@ -1323,7 +1323,7 @@ def local_incsubtensor_of_zeros(fgraph, node):
@register_canonicalize @register_canonicalize
@register_specialize @register_specialize
@local_optimizer([IncSubtensor]) @node_rewriter([IncSubtensor])
def local_incsubtensor_of_zeros_to_setsubtensor(fgraph, node): def local_incsubtensor_of_zeros_to_setsubtensor(fgraph, node):
""" """
IncSubtensor(zeros, x, ...) -> SetSubtensor(zeros, x, ...) IncSubtensor(zeros, x, ...) -> SetSubtensor(zeros, x, ...)
...@@ -1344,7 +1344,7 @@ def local_incsubtensor_of_zeros_to_setsubtensor(fgraph, node): ...@@ -1344,7 +1344,7 @@ def local_incsubtensor_of_zeros_to_setsubtensor(fgraph, node):
@register_canonicalize("local_setsubtensor_of_allocs") @register_canonicalize("local_setsubtensor_of_allocs")
@register_stabilize("local_setsubtensor_of_allocs") @register_stabilize("local_setsubtensor_of_allocs")
@local_optimizer([IncSubtensor]) @node_rewriter([IncSubtensor])
def local_setsubtensor_of_constants(fgraph, node): def local_setsubtensor_of_constants(fgraph, node):
""" """
SetSubtensor(x, x[idx], idx) -> x SetSubtensor(x, x[idx], idx) -> x
...@@ -1379,7 +1379,7 @@ def local_setsubtensor_of_constants(fgraph, node): ...@@ -1379,7 +1379,7 @@ def local_setsubtensor_of_constants(fgraph, node):
@register_canonicalize @register_canonicalize
@register_specialize @register_specialize
@local_optimizer([AdvancedSubtensor1]) @node_rewriter([AdvancedSubtensor1])
def local_adv_sub1_adv_inc_sub1(fgraph, node): def local_adv_sub1_adv_inc_sub1(fgraph, node):
"""Optimize the possible AdvSub1(AdvSetSub1(...), ...). """Optimize the possible AdvSub1(AdvSetSub1(...), ...).
...@@ -1446,7 +1446,7 @@ def local_adv_sub1_adv_inc_sub1(fgraph, node): ...@@ -1446,7 +1446,7 @@ def local_adv_sub1_adv_inc_sub1(fgraph, node):
@register_stabilize @register_stabilize
@register_canonicalize @register_canonicalize
@register_useless @register_useless
@local_optimizer([IncSubtensor, AdvancedIncSubtensor, AdvancedIncSubtensor1]) @node_rewriter([IncSubtensor, AdvancedIncSubtensor, AdvancedIncSubtensor1])
def local_useless_inc_subtensor_alloc(fgraph, node): def local_useless_inc_subtensor_alloc(fgraph, node):
""" """
Replaces an [Advanced]IncSubtensor[1], whose increment is an `alloc` of Replaces an [Advanced]IncSubtensor[1], whose increment is an `alloc` of
...@@ -1552,7 +1552,7 @@ def local_useless_inc_subtensor_alloc(fgraph, node): ...@@ -1552,7 +1552,7 @@ def local_useless_inc_subtensor_alloc(fgraph, node):
@register_specialize @register_specialize
@register_canonicalize @register_canonicalize
@local_optimizer([Subtensor]) @node_rewriter([Subtensor])
def local_subtensor_shape_constant(fgraph, node): def local_subtensor_shape_constant(fgraph, node):
r"""Simplify constant `Subtensor`\s on `Shape`\s dimensions that are known. r"""Simplify constant `Subtensor`\s on `Shape`\s dimensions that are known.
...@@ -1606,7 +1606,7 @@ def local_subtensor_shape_constant(fgraph, node): ...@@ -1606,7 +1606,7 @@ def local_subtensor_shape_constant(fgraph, node):
@register_canonicalize @register_canonicalize
@local_optimizer([Subtensor]) @node_rewriter([Subtensor])
def local_subtensor_SpecifyShape_lift(fgraph, node): def local_subtensor_SpecifyShape_lift(fgraph, node):
"""Lift ``specify_shape(x, s)[i_1, ..., i_n]`` to ``specify_shape(x[i1, ... , i_n], s[n:])``.""" """Lift ``specify_shape(x, s)[i_1, ..., i_n]`` to ``specify_shape(x[i1, ... , i_n], s[n:])``."""
...@@ -1640,7 +1640,7 @@ def local_subtensor_SpecifyShape_lift(fgraph, node): ...@@ -1640,7 +1640,7 @@ def local_subtensor_SpecifyShape_lift(fgraph, node):
@register_specialize @register_specialize
@local_optimizer([Join]) @node_rewriter([Join])
def local_join_subtensors(fgraph, node): def local_join_subtensors(fgraph, node):
r"""Simplify contiguous :class:`Subtensor`\s inside a :class:`Join`. r"""Simplify contiguous :class:`Subtensor`\s inside a :class:`Join`.
......
from aesara.compile import optdb from aesara.compile import optdb
from aesara.graph.opt import TopoOptimizer, local_optimizer from aesara.graph.opt import TopoOptimizer, node_rewriter
from aesara.typed_list.basic import Append, Extend, Insert, Remove, Reverse from aesara.typed_list.basic import Append, Extend, Insert, Remove, Reverse
@local_optimizer([Append, Extend, Insert, Reverse, Remove], inplace=True) @node_rewriter([Append, Extend, Insert, Reverse, Remove], inplace=True)
def typed_list_inplace_opt(fgraph, node): def typed_list_inplace_opt(fgraph, node):
if ( if (
isinstance(node.op, (Append, Extend, Insert, Reverse, Remove)) isinstance(node.op, (Append, Extend, Insert, Reverse, Remove))
......
...@@ -67,15 +67,15 @@ Local optimization ...@@ -67,15 +67,15 @@ Local optimization
A local optimization is an object which defines the following methods: A local optimization is an object which defines the following methods:
.. class:: LocalOptimizer .. class:: NodeRewriter
.. method:: transform(fgraph, node) .. method:: transform(fgraph, node)
This method takes a :class:`FunctionGraph` and an :class:`Apply` node and This method takes a :class:`FunctionGraph` and an :class:`Apply` node and
returns either ``False`` to signify that no changes are to be done or a returns either ``False`` to signify that no changes are to be done or a
list of :class:`Variable`\s which matches the length of the node's ``outputs`` list of :class:`Variable`\s which matches the length of the node's ``outputs``
list. When the :class:`LocalOptimizer` is applied by a :class:`NavigatorOptimizer`, the outputs list. When the :class:`NodeRewriter` is applied by a :class:`NavigatorOptimizer`, the outputs
of the node passed as argument to the :class:`LocalOptimizer` will be replaced by of the node passed as argument to the :class:`NodeRewriter` will be replaced by
the list returned. the list returned.
...@@ -218,10 +218,10 @@ The local version of the above code would be the following: ...@@ -218,10 +218,10 @@ The local version of the above code would be the following:
.. testcode:: .. testcode::
from aesara.graph.opt import LocalOptimizer from aesara.graph.opt import NodeRewriter
class LocalSimplify(LocalOptimizer): class LocalSimplify(NodeRewriter):
def transform(self, fgraph, node): def transform(self, fgraph, node):
if node.op == true_div: if node.op == true_div:
x, y = node.inputs x, y = node.inputs
...@@ -234,7 +234,7 @@ The local version of the above code would be the following: ...@@ -234,7 +234,7 @@ The local version of the above code would be the following:
return False return False
def tracks(self): def tracks(self):
# This tells certain navigators to only apply this `LocalOptimizer` # This tells certain navigators to only apply this `NodeRewriter`
# on these kinds of `Op`s # on these kinds of `Op`s
return [true_div] return [true_div]
...@@ -242,7 +242,7 @@ The local version of the above code would be the following: ...@@ -242,7 +242,7 @@ The local version of the above code would be the following:
In this case, the transformation is defined in the In this case, the transformation is defined in the
:meth:`LocalOptimizer.transform` method, which is given an explicit :meth:`NodeRewriter.transform` method, which is given an explicit
:class:`Apply` node on which to work. The entire graph--as a ``fgraph``--is :class:`Apply` node on which to work. The entire graph--as a ``fgraph``--is
also provided, in case global information is needed. also provided, in case global information is needed.
...@@ -273,7 +273,7 @@ FunctionGraph(add(z, mul(x, true_div(z, x)))) ...@@ -273,7 +273,7 @@ FunctionGraph(add(z, mul(x, true_div(z, x))))
:class:`OpSub`, :class:`OpRemove`, :class:`PatternSub` :class:`OpSub`, :class:`OpRemove`, :class:`PatternSub`
++++++++++++++++++++++++++++++++++++++++++++++++++++++ ++++++++++++++++++++++++++++++++++++++++++++++++++++++
Aesara defines some shortcuts to make :class:`LocalOptimizer`\s: Aesara defines some shortcuts to make :class:`NodeRewriter`\s:
.. function:: OpSub(op1, op2) .. function:: OpSub(op1, op2)
...@@ -433,7 +433,7 @@ This means that a relation that--say--represents :math:`x + x = 2 x` can be ...@@ -433,7 +433,7 @@ This means that a relation that--say--represents :math:`x + x = 2 x` can be
utilized in both directions. utilized in both directions.
Currently, the local optimizer :class:`KanrenRelationSub` provides a means of Currently, the local optimizer :class:`KanrenRelationSub` provides a means of
turning :mod:`kanren` relations into :class:`LocalOptimizer`\s; however, turning :mod:`kanren` relations into :class:`NodeRewriter`\s; however,
:mod:`kanren` can always be used directly from within a custom :class:`Rewriter`, so :mod:`kanren` can always be used directly from within a custom :class:`Rewriter`, so
:class:`KanrenRelationSub` is not necessary. :class:`KanrenRelationSub` is not necessary.
...@@ -561,7 +561,7 @@ serve as a basis for filtering. ...@@ -561,7 +561,7 @@ serve as a basis for filtering.
The point of :obj:`optdb` is that you might want to apply many optimizations The point of :obj:`optdb` is that you might want to apply many optimizations
to a computation graph in many unique patterns. For example, you might to a computation graph in many unique patterns. For example, you might
want to do optimization X, then optimization Y, then optimization Z. And then want to do optimization X, then optimization Y, then optimization Z. And then
maybe optimization Y is an :class:`EquilibriumOptimizer` containing :class:`LocalOptimizer`\s A, B maybe optimization Y is an :class:`EquilibriumOptimizer` containing :class:`NodeRewriter`\s A, B
and C which are applied on every node of the graph until they all fail to change and C which are applied on every node of the graph until they all fail to change
it. If some optimizations act up, we want an easy way to turn them off. Ditto if it. If some optimizations act up, we want an easy way to turn them off. Ditto if
some optimizations are very CPU-intensive and we don't want to take the time to some optimizations are very CPU-intensive and we don't want to take the time to
...@@ -596,14 +596,14 @@ is returned. If the :class:`SequenceDB` contains :class:`OptimizationDatabase` ...@@ -596,14 +596,14 @@ is returned. If the :class:`SequenceDB` contains :class:`OptimizationDatabase`
instances, the :class:`OptimizationQuery` will be passed to them as well and the instances, the :class:`OptimizationQuery` will be passed to them as well and the
optimizers they return will be put in their places. optimizers they return will be put in their places.
An :class:`EquilibriumDB` contains :class:`LocalOptimizer` or :class:`OptimizationDatabase` objects. Each of them An :class:`EquilibriumDB` contains :class:`NodeRewriter` or :class:`OptimizationDatabase` objects. Each of them
has a name and an arbitrary number of tags. When an :class:`OptimizationQuery` is applied to has a name and an arbitrary number of tags. When an :class:`OptimizationQuery` is applied to
an :class:`EquilibriumDB`, all :class:`LocalOptimizer`\s that match the query are an :class:`EquilibriumDB`, all :class:`NodeRewriter`\s that match the query are
inserted into an :class:`EquilibriumOptimizer`, which is returned. If the inserted into an :class:`EquilibriumOptimizer`, which is returned. If the
:class:`SequenceDB` contains :class:`OptimizationDatabase` instances, the :class:`SequenceDB` contains :class:`OptimizationDatabase` instances, the
:class:`OptimizationQuery` will be passed to them as well and the :class:`OptimizationQuery` will be passed to them as well and the
:class:`LocalOptimizer`\s they return will be put in their places :class:`NodeRewriter`\s they return will be put in their places
(note that as of yet no :class:`OptimizationDatabase` can produce :class:`LocalOptimizer` objects, so this (note that as of yet no :class:`OptimizationDatabase` can produce :class:`NodeRewriter` objects, so this
is a moot point). is a moot point).
Aesara contains one principal :class:`OptimizationDatabase` object, :class:`optdb`, which Aesara contains one principal :class:`OptimizationDatabase` object, :class:`optdb`, which
...@@ -697,10 +697,10 @@ already-compiled functions will see no change. The 'order' parameter ...@@ -697,10 +697,10 @@ already-compiled functions will see no change. The 'order' parameter
Registering a :class:`LocalOptimizer` Registering a :class:`NodeRewriter`
------------------------------------- -----------------------------------
:class:`LocalOptimizer`\s may be registered in two ways: :class:`NodeRewriter`\s may be registered in two ways:
* Wrap them in a :class:`NavigatorOptimizer` and insert them like a global optimizer * Wrap them in a :class:`NavigatorOptimizer` and insert them like a global optimizer
(see previous section). (see previous section).
......
...@@ -18,7 +18,7 @@ from aesara.configdefaults import config ...@@ -18,7 +18,7 @@ from aesara.configdefaults import config
from aesara.graph.basic import Apply, Variable from aesara.graph.basic import Apply, Variable
from aesara.graph.features import BadOptimization from aesara.graph.features import BadOptimization
from aesara.graph.op import Op from aesara.graph.op import Op
from aesara.graph.opt import local_optimizer from aesara.graph.opt import node_rewriter
from aesara.graph.optdb import EquilibriumDB from aesara.graph.optdb import EquilibriumDB
from aesara.link.c.op import COp from aesara.link.c.op import COp
from aesara.tensor.math import add, dot, log from aesara.tensor.math import add, dot, log
...@@ -237,7 +237,7 @@ def test_badthunkoutput(): ...@@ -237,7 +237,7 @@ def test_badthunkoutput():
def test_badoptimization(): def test_badoptimization():
@local_optimizer([add]) @node_rewriter([add])
def insert_broken_add(fgraph, node): def insert_broken_add(fgraph, node):
if node.op == add: if node.op == add:
return [off_by_half(*node.inputs)] return [off_by_half(*node.inputs)]
...@@ -263,7 +263,7 @@ def test_badoptimization(): ...@@ -263,7 +263,7 @@ def test_badoptimization():
def test_badoptimization_opt_err(): def test_badoptimization_opt_err():
# This variant of test_badoptimization() replaces the working code # This variant of test_badoptimization() replaces the working code
# with a new apply node that will raise an error. # with a new apply node that will raise an error.
@local_optimizer([add]) @node_rewriter([add])
def insert_bigger_b_add(fgraph, node): def insert_bigger_b_add(fgraph, node):
if node.op == add: if node.op == add:
inputs = list(node.inputs) inputs = list(node.inputs)
...@@ -272,7 +272,7 @@ def test_badoptimization_opt_err(): ...@@ -272,7 +272,7 @@ def test_badoptimization_opt_err():
return [node.op(*inputs)] return [node.op(*inputs)]
return False return False
@local_optimizer([add]) @node_rewriter([add])
def insert_bad_dtype(fgraph, node): def insert_bad_dtype(fgraph, node):
if node.op == add: if node.op == add:
inputs = list(node.inputs) inputs = list(node.inputs)
...@@ -326,7 +326,7 @@ def test_stochasticoptimization(): ...@@ -326,7 +326,7 @@ def test_stochasticoptimization():
last_time_replaced = [False] last_time_replaced = [False]
@local_optimizer([add]) @node_rewriter([add])
def insert_broken_add_sometimes(fgraph, node): def insert_broken_add_sometimes(fgraph, node):
if node.op == add: if node.op == add:
last_time_replaced[0] = not last_time_replaced[0] last_time_replaced[0] = not last_time_replaced[0]
......
...@@ -15,10 +15,10 @@ from aesara.graph.opt import ( ...@@ -15,10 +15,10 @@ from aesara.graph.opt import (
PatternSub, PatternSub,
TopoOptimizer, TopoOptimizer,
in2out, in2out,
local_optimizer,
logging, logging,
node_rewriter,
pre_constant_merge, pre_constant_merge,
pre_greedy_local_optimizer, pre_greedy_node_rewriter,
) )
from aesara.raise_op import assert_op from aesara.raise_op import assert_op
from aesara.tensor.basic_opt import constant_folding from aesara.tensor.basic_opt import constant_folding
...@@ -547,7 +547,7 @@ def test_pre_constant_merge(): ...@@ -547,7 +547,7 @@ def test_pre_constant_merge():
assert res == [adv] assert res == [adv]
def test_pre_greedy_local_optimizer(): def test_pre_greedy_node_rewriter():
empty_fgraph = FunctionGraph([], []) empty_fgraph = FunctionGraph([], [])
...@@ -564,7 +564,7 @@ def test_pre_greedy_local_optimizer(): ...@@ -564,7 +564,7 @@ def test_pre_greedy_local_optimizer():
# This should fold `o1`, because it has only `Constant` arguments, and # This should fold `o1`, because it has only `Constant` arguments, and
# replace it with the `Constant` result # replace it with the `Constant` result
cst = pre_greedy_local_optimizer(empty_fgraph, [constant_folding], o2) cst = pre_greedy_node_rewriter(empty_fgraph, [constant_folding], o2)
assert cst.owner.inputs[0].owner is None assert cst.owner.inputs[0].owner is None
assert cst.owner.inputs[1] is c2 assert cst.owner.inputs[1] is c2
...@@ -577,14 +577,14 @@ def test_pre_greedy_local_optimizer(): ...@@ -577,14 +577,14 @@ def test_pre_greedy_local_optimizer():
fg = FunctionGraph([], [o1], clone=False) fg = FunctionGraph([], [o1], clone=False)
o2 = op1(o1, c2, x, o3, o1) o2 = op1(o1, c2, x, o3, o1)
cst = pre_greedy_local_optimizer(fg, [constant_folding], o2) cst = pre_greedy_node_rewriter(fg, [constant_folding], o2)
assert cst.owner.inputs[0] is o1 assert cst.owner.inputs[0] is o1
assert cst.owner.inputs[4] is cst.owner.inputs[0] assert cst.owner.inputs[4] is cst.owner.inputs[0]
# What exactly is this supposed to test? # What exactly is this supposed to test?
ms = MakeSlice()(1) ms = MakeSlice()(1)
cst = pre_greedy_local_optimizer(empty_fgraph, [constant_folding], ms) cst = pre_greedy_node_rewriter(empty_fgraph, [constant_folding], ms)
assert isinstance(cst, SliceConstant) assert isinstance(cst, SliceConstant)
...@@ -673,13 +673,13 @@ class TestLocalOptGroup: ...@@ -673,13 +673,13 @@ class TestLocalOptGroup:
fgraph = FunctionGraph([x, y], [o1], clone=False) fgraph = FunctionGraph([x, y], [o1], clone=False)
@local_optimizer(None) @node_rewriter(None)
def local_opt_1(fgraph, node): def local_opt_1(fgraph, node):
if node.inputs[0] == x: if node.inputs[0] == x:
res = op2(y, *node.inputs[1:]) res = op2(y, *node.inputs[1:])
return [res] return [res]
@local_optimizer(None) @node_rewriter(None)
def local_opt_2(fgraph, node): def local_opt_2(fgraph, node):
if node.inputs[0] == y: if node.inputs[0] == y:
res = op2(x, *node.inputs[1:]) res = op2(x, *node.inputs[1:])
...@@ -703,8 +703,8 @@ class TestLocalOptGroup: ...@@ -703,8 +703,8 @@ class TestLocalOptGroup:
) )
def test_local_optimizer_str(): def test_node_rewriter_str():
@local_optimizer([op1, MyOp]) @node_rewriter([op1, MyOp])
def local_opt_1(fgraph, node): def local_opt_1(fgraph, node):
pass pass
...@@ -715,17 +715,17 @@ def test_local_optimizer_str(): ...@@ -715,17 +715,17 @@ def test_local_optimizer_str():
assert "local_opt_1" in res assert "local_opt_1" in res
def test_local_optimizer(): def test_node_rewriter():
with pytest.raises(ValueError): with pytest.raises(ValueError):
@local_optimizer([]) @node_rewriter([])
def local_bad_1(fgraph, node): def local_bad_1(fgraph, node):
return node.outputs return node.outputs
with pytest.raises(TypeError): with pytest.raises(TypeError):
@local_optimizer([None]) @node_rewriter([None])
def local_bad_2(fgraph, node): def local_bad_2(fgraph, node):
return node.outputs return node.outputs
...@@ -748,7 +748,7 @@ def test_local_optimizer(): ...@@ -748,7 +748,7 @@ def test_local_optimizer():
hits = [0] hits = [0]
@local_optimizer([op1, MyNewOp]) @node_rewriter([op1, MyNewOp])
def local_opt_1(fgraph, node, hits=hits): def local_opt_1(fgraph, node, hits=hits):
hits[0] += 1 hits[0] += 1
return node.outputs return node.outputs
...@@ -766,24 +766,24 @@ def test_local_optimizer(): ...@@ -766,24 +766,24 @@ def test_local_optimizer():
assert hits[0] == 2 assert hits[0] == 2
def test_TrackingLocalOptimizer(): def test_TrackingNodeRewriter():
@local_optimizer(None) @node_rewriter(None)
def local_opt_1(fgraph, node): def local_opt_1(fgraph, node):
pass pass
@local_optimizer([op1]) @node_rewriter([op1])
def local_opt_2(fgraph, node): def local_opt_2(fgraph, node):
pass pass
@local_optimizer([Op]) @node_rewriter([Op])
def local_opt_3(fgraph, node): def local_opt_3(fgraph, node):
pass pass
@local_optimizer([MyOp]) @node_rewriter([MyOp])
def local_opt_4(fgraph, node): def local_opt_4(fgraph, node):
pass pass
@local_optimizer([MyOp]) @node_rewriter([MyOp])
def local_opt_5(fgraph, node): def local_opt_5(fgraph, node):
pass pass
......
...@@ -16,7 +16,7 @@ from aesara.configdefaults import config ...@@ -16,7 +16,7 @@ from aesara.configdefaults import config
from aesara.graph.basic import Apply, Constant, Variable from aesara.graph.basic import Apply, Constant, Variable
from aesara.graph.fg import FunctionGraph from aesara.graph.fg import FunctionGraph
from aesara.graph.op import Op from aesara.graph.op import Op
from aesara.graph.opt import check_stack_trace, local_optimizer, out2in from aesara.graph.opt import check_stack_trace, node_rewriter, out2in
from aesara.graph.opt_utils import optimize_graph from aesara.graph.opt_utils import optimize_graph
from aesara.graph.optdb import OptimizationQuery from aesara.graph.optdb import OptimizationQuery
from aesara.graph.type import Type from aesara.graph.type import Type
...@@ -1752,7 +1752,7 @@ class TestShapeOptimizer: ...@@ -1752,7 +1752,7 @@ class TestShapeOptimizer:
identity_shape = IdentityShape() identity_shape = IdentityShape()
@local_optimizer([IdentityNoShape]) @node_rewriter([IdentityNoShape])
def local_identity_noshape_to_identity_shape(fgraph, node): def local_identity_noshape_to_identity_shape(fgraph, node):
"""Optimization transforming the first Op into the second""" """Optimization transforming the first Op into the second"""
if isinstance(node.op, IdentityNoShape): if isinstance(node.op, IdentityNoShape):
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论