提交 0ce6eceb authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Refactor old global and local optimizers references and type hints

上级 550a6e98
...@@ -11,7 +11,7 @@ from aesara.graph.unify import eval_if_etuple ...@@ -11,7 +11,7 @@ from aesara.graph.unify import eval_if_etuple
class KanrenRelationSub(NodeRewriter): class KanrenRelationSub(NodeRewriter):
r"""A local optimizer that uses `kanren` to match and replace terms. r"""A rewriter that uses `kanren` to match and replace terms.
See `kanren <https://github.com/pythological/kanren>`__ for more information See `kanren <https://github.com/pythological/kanren>`__ for more information
miniKanren and the API for constructing `kanren` goals. miniKanren and the API for constructing `kanren` goals.
...@@ -56,7 +56,7 @@ class KanrenRelationSub(NodeRewriter): ...@@ -56,7 +56,7 @@ class KanrenRelationSub(NodeRewriter):
A function that takes an input graph and an output logic variable and A function that takes an input graph and an output logic variable and
returns a `kanren` goal. returns a `kanren` goal.
results_filter results_filter
A function that takes the direct output of `kanren.run(None, ...)` A function that takes the direct output of ``kanren.run(None, ...)``
and returns a single result. The default implementation returns and returns a single result. The default implementation returns
the first result. the first result.
node_filter node_filter
......
...@@ -17,7 +17,7 @@ from collections import UserList, defaultdict, deque ...@@ -17,7 +17,7 @@ from collections import UserList, defaultdict, deque
from collections.abc import Iterable from collections.abc import Iterable
from functools import _compose_mro, partial, reduce # type: ignore from functools import _compose_mro, partial, reduce # type: ignore
from itertools import chain from itertools import chain
from typing import Callable, Dict, List, Optional, Sequence, Tuple, Union from typing import Callable, Dict, List, Optional, Sequence, Tuple, Union, cast
from typing_extensions import Literal from typing_extensions import Literal
...@@ -156,15 +156,20 @@ class NodeRewriter(Rewriter): ...@@ -156,15 +156,20 @@ class NodeRewriter(Rewriter):
@abc.abstractmethod @abc.abstractmethod
def transform( def transform(
self, fgraph: FunctionGraph, node: Apply, *args, **kwargs self, fgraph: FunctionGraph, node: Apply, *args, **kwargs
) -> Union[bool, List[Variable], Dict[Variable, Variable]]: ) -> Union[
r"""Transform a subgraph whose output is `node`. bool,
Sequence[Variable],
Dict[Union[Variable, Literal["remove"]], Union[Variable, Sequence[Variable]]],
]:
r"""Rewrite the sub-graph given by `node`.
Subclasses should implement this function so that it returns one of the Subclasses should implement this function so that it returns one of the
following: following:
- ``False`` to indicate that this rewrite cannot be applied to `node` - ``False`` to indicate that this rewrite cannot be applied to `node`
- A list of `Variable`\s to use in place of the `node`'s current outputs - A list of `Variable`\s to use in place of the `node`'s current outputs
- A ``dict`` mapping old `Variable`\s to new `Variable`\s - A ``dict`` mapping old `Variable`\s to `Variable`\s, or the key
``"remove"`` mapping to a list of `Variable`\s to be removed.
Parameters Parameters
---------- ----------
...@@ -1850,10 +1855,15 @@ class NavigatorOptimizer(GraphRewriter): ...@@ -1850,10 +1855,15 @@ class NavigatorOptimizer(GraphRewriter):
if u is not None: if u is not None:
fgraph.remove_feature(u) fgraph.remove_feature(u)
def process_node(self, fgraph, node, lopt=None): def process_node(
r"""Apply `lopt` to `node`. self,
fgraph: FunctionGraph,
node: Apply,
node_rewriter: Optional[NodeRewriter] = None,
):
r"""Apply `node_rewriter` to `node`.
The :meth:`lopt.transform` method will return either ``False`` or a The :meth:`node_rewriter.transform` method will return either ``False`` or a
list of `Variable`\s that are intended to replace :attr:`node.outputs`. list of `Variable`\s that are intended to replace :attr:`node.outputs`.
If the `fgraph` accepts the replacement, then the optimization is If the `fgraph` accepts the replacement, then the optimization is
...@@ -1864,11 +1874,11 @@ class NavigatorOptimizer(GraphRewriter): ...@@ -1864,11 +1874,11 @@ class NavigatorOptimizer(GraphRewriter):
Parameters Parameters
---------- ----------
fgraph : fgraph
A `FunctionGraph`. A `FunctionGraph`.
node : node
An `Apply` instance in `fgraph` An `Apply` instance in `fgraph`
lopt : node_rewriter
A `NodeRewriter` instance that may have a better idea for A `NodeRewriter` instance that may have a better idea for
how to compute node's outputs. how to compute node's outputs.
...@@ -1878,13 +1888,15 @@ class NavigatorOptimizer(GraphRewriter): ...@@ -1878,13 +1888,15 @@ class NavigatorOptimizer(GraphRewriter):
``True`` iff the `node`'s outputs were replaced in the `fgraph`. ``True`` iff the `node`'s outputs were replaced in the `fgraph`.
""" """
lopt = lopt or self.node_rewriter node_rewriter = node_rewriter or self.node_rewriter
# TODO FIXME: This class's interface is broken
assert node_rewriter is not None
try: try:
replacements = lopt.transform(fgraph, node) replacements = node_rewriter.transform(fgraph, node)
except Exception as e: except Exception as e:
if self.failure_callback is not None: if self.failure_callback is not None:
self.failure_callback( self.failure_callback(
e, self, [(x, None) for x in node.outputs], lopt, node e, self, [(x, None) for x in node.outputs], node_rewriter, node
) )
return False return False
else: else:
...@@ -1892,25 +1904,27 @@ class NavigatorOptimizer(GraphRewriter): ...@@ -1892,25 +1904,27 @@ class NavigatorOptimizer(GraphRewriter):
if replacements is False or replacements is None: if replacements is False or replacements is None:
return False return False
old_vars = node.outputs old_vars = node.outputs
remove = [] remove: List[Variable] = []
if isinstance(replacements, dict): if isinstance(replacements, dict):
if "remove" in replacements: if "remove" in replacements:
remove = replacements.pop("remove") remove = list(cast(Sequence[Variable], replacements.pop("remove")))
old_vars = list(replacements.keys()) old_vars = list(cast(Sequence[Variable], replacements.keys()))
replacements = list(replacements.values()) replacements = list(cast(Sequence[Variable], replacements.values()))
elif not isinstance(replacements, (tuple, list)): elif not isinstance(replacements, (tuple, list)):
raise TypeError( raise TypeError(
f"Node rewriter {lopt} gave wrong type of replacement. " f"Node rewriter {node_rewriter} gave wrong type of replacement. "
f"Expected list or tuple; got {replacements}" f"Expected list or tuple; got {replacements}"
) )
if len(old_vars) != len(replacements): if len(old_vars) != len(replacements):
raise ValueError(f"Node rewriter {lopt} gave wrong number of replacements") raise ValueError(
f"Node rewriter {node_rewriter} gave wrong number of replacements"
)
# None in the replacement mean that this variable isn't used # None in the replacement mean that this variable isn't used
# and we want to remove it # and we want to remove it
for r, rnew in zip(old_vars, replacements): for r, rnew in zip(old_vars, replacements):
if rnew is None and len(fgraph.clients[r]) > 0: if rnew is None and len(fgraph.clients[r]) > 0:
raise ValueError( raise ValueError(
f"Node rewriter {lopt} tried to remove a variable" f"Node rewriter {node_rewriter} tried to remove a variable"
f" that is being used: {r}" f" that is being used: {r}"
) )
# If an output would be replaced by itself, no need to perform # If an output would be replaced by itself, no need to perform
...@@ -1924,7 +1938,9 @@ class NavigatorOptimizer(GraphRewriter): ...@@ -1924,7 +1938,9 @@ class NavigatorOptimizer(GraphRewriter):
if len(repl_pairs) == 0: if len(repl_pairs) == 0:
return False return False
try: try:
fgraph.replace_all_validate_remove(repl_pairs, reason=lopt, remove=remove) fgraph.replace_all_validate_remove( # type: ignore
repl_pairs, reason=node_rewriter, remove=remove
)
return True return True
except Exception as e: except Exception as e:
# This means the replacements were rejected by the fgraph. # This means the replacements were rejected by the fgraph.
...@@ -1932,7 +1948,7 @@ class NavigatorOptimizer(GraphRewriter): ...@@ -1932,7 +1948,7 @@ class NavigatorOptimizer(GraphRewriter):
# This is not supposed to happen. The default failure_callback # This is not supposed to happen. The default failure_callback
# will print a traceback as a warning. # will print a traceback as a warning.
if self.failure_callback is not None: if self.failure_callback is not None:
self.failure_callback(e, self, repl_pairs, lopt, node) self.failure_callback(e, self, repl_pairs, node_rewriter, node)
return False return False
else: else:
raise raise
...@@ -2027,7 +2043,7 @@ class TopoOptimizer(NavigatorOptimizer): ...@@ -2027,7 +2043,7 @@ class TopoOptimizer(NavigatorOptimizer):
io_t, io_t,
loop_t, loop_t,
callback_time, callback_time,
lopt, node_rewriter,
) = prof ) = prof
print( print(
...@@ -2046,16 +2062,16 @@ class TopoOptimizer(NavigatorOptimizer): ...@@ -2046,16 +2062,16 @@ class TopoOptimizer(NavigatorOptimizer):
print(blanc, " init io_toposort", io_t, file=stream) print(blanc, " init io_toposort", io_t, file=stream)
print(blanc, " loop time", loop_t, file=stream) print(blanc, " loop time", loop_t, file=stream)
print(blanc, " callback_time", callback_time, file=stream) print(blanc, " callback_time", callback_time, file=stream)
if isinstance(lopt, LocalOptGroup): if isinstance(node_rewriter, LocalOptGroup):
if lopt.profile: if node_rewriter.profile:
lopt.print_profile( node_rewriter.print_profile(
stream, stream,
( (
lopt.time_opts, node_rewriter.time_opts,
lopt.process_count, node_rewriter.process_count,
lopt.applied_true, node_rewriter.applied_true,
lopt.node_created, node_rewriter.node_created,
lopt.profile, node_rewriter.profile,
), ),
level=level + 1, level=level + 1,
) )
...@@ -2228,11 +2244,11 @@ class EquilibriumOptimizer(NavigatorOptimizer): ...@@ -2228,11 +2244,11 @@ class EquilibriumOptimizer(NavigatorOptimizer):
self.global_optimizers: List[GraphRewriter] = [] self.global_optimizers: List[GraphRewriter] = []
self.tracks_on_change_inputs = tracks_on_change_inputs self.tracks_on_change_inputs = tracks_on_change_inputs
self.local_tracker = LocalOptTracker() self.node_tracker = LocalOptTracker()
for opt in optimizers: for opt in optimizers:
if isinstance(opt, NodeRewriter): if isinstance(opt, NodeRewriter):
self.local_tracker.add_tracker(opt) self.node_tracker.add_tracker(opt)
else: else:
assert isinstance(opt, GraphRewriter) assert isinstance(opt, GraphRewriter)
self.global_optimizers.append(opt) self.global_optimizers.append(opt)
...@@ -2250,7 +2266,7 @@ class EquilibriumOptimizer(NavigatorOptimizer): ...@@ -2250,7 +2266,7 @@ class EquilibriumOptimizer(NavigatorOptimizer):
self.max_use_ratio = max_use_ratio self.max_use_ratio = max_use_ratio
def get_node_rewriters(self): def get_node_rewriters(self):
yield from self.local_tracker.get_rewriters() yield from self.node_tracker.get_rewriters()
def get_local_optimizers(self): def get_local_optimizers(self):
warnings.warn( warnings.warn(
...@@ -2357,11 +2373,8 @@ class EquilibriumOptimizer(NavigatorOptimizer): ...@@ -2357,11 +2373,8 @@ class EquilibriumOptimizer(NavigatorOptimizer):
global_opt_timing.append(float(time.time() - t0)) global_opt_timing.append(float(time.time() - t0))
# apply clean up as global opt can have done changes that
# request that
changed |= apply_cleanup(iter_cleanup_sub_profs) changed |= apply_cleanup(iter_cleanup_sub_profs)
# apply local optimizer
topo_t0 = time.time() topo_t0 = time.time()
q = deque(io_toposort(fgraph.inputs, start_from)) q = deque(io_toposort(fgraph.inputs, start_from))
io_toposort_timing.append(time.time() - topo_t0) io_toposort_timing.append(time.time() - topo_t0)
...@@ -2390,23 +2403,25 @@ class EquilibriumOptimizer(NavigatorOptimizer): ...@@ -2390,23 +2403,25 @@ class EquilibriumOptimizer(NavigatorOptimizer):
if node not in fgraph.apply_nodes: if node not in fgraph.apply_nodes:
continue continue
current_node = node current_node = node
for lopt in self.local_tracker.get_trackers(node.op): for node_rewriter in self.node_tracker.get_trackers(node.op):
nb = change_tracker.nb_imported nb = change_tracker.nb_imported
t_opt = time.time() t_opt = time.time()
lopt_change = self.process_node(fgraph, node, lopt) node_rewriter_change = self.process_node(
time_opts[lopt] += time.time() - t_opt fgraph, node, node_rewriter
if not lopt_change: )
time_opts[node_rewriter] += time.time() - t_opt
if not node_rewriter_change:
continue continue
process_count.setdefault(lopt, 0) process_count.setdefault(node_rewriter, 0)
process_count[lopt] += 1 process_count[node_rewriter] += 1
global_process_count[lopt] += 1 global_process_count[node_rewriter] += 1
changed = True changed = True
node_created[lopt] += change_tracker.nb_imported - nb node_created[node_rewriter] += change_tracker.nb_imported - nb
changed |= apply_cleanup(iter_cleanup_sub_profs) changed |= apply_cleanup(iter_cleanup_sub_profs)
if global_process_count[lopt] > max_use: if global_process_count[node_rewriter] > max_use:
max_use_abort = True max_use_abort = True
opt_name = getattr(lopt, "name", None) or getattr( opt_name = getattr(node_rewriter, "name", None) or getattr(
lopt, "__name__", "" node_rewriter, "__name__", ""
) )
if node not in fgraph.apply_nodes: if node not in fgraph.apply_nodes:
# go to next node # go to next node
...@@ -2494,8 +2509,10 @@ class EquilibriumOptimizer(NavigatorOptimizer): ...@@ -2494,8 +2509,10 @@ class EquilibriumOptimizer(NavigatorOptimizer):
f"{' ' * level}{self.__class__.__name__} {name} id={id(self)}", file=stream f"{' ' * level}{self.__class__.__name__} {name} id={id(self)}", file=stream
) )
if depth != 0: if depth != 0:
for lopt in self.get_node_rewriters(): for node_rewriter in self.get_node_rewriters():
lopt.print_summary(stream, level=(level + 2), depth=(depth - 1)) node_rewriter.print_summary(
stream, level=(level + 2), depth=(depth - 1)
)
@staticmethod @staticmethod
def print_profile(stream, prof, level=0): def print_profile(stream, prof, level=0):
...@@ -2529,27 +2546,27 @@ class EquilibriumOptimizer(NavigatorOptimizer): ...@@ -2529,27 +2546,27 @@ class EquilibriumOptimizer(NavigatorOptimizer):
) )
print(blanc, f" time io_toposort {sum(io_toposort_timing):.3f}s", file=stream) print(blanc, f" time io_toposort {sum(io_toposort_timing):.3f}s", file=stream)
s = sum(time_opts[o] for o in opt.get_node_rewriters()) s = sum(time_opts[o] for o in opt.get_node_rewriters())
print(blanc, f" time in local optimizers {s:.3f}s", file=stream) print(blanc, f" time in node rewriters {s:.3f}s", file=stream)
s = sum(time_opts[o] for o in opt.global_optimizers) s = sum(time_opts[o] for o in opt.global_optimizers)
print(blanc, f" time in global optimizers {s:.3f}s", file=stream) print(blanc, f" time in graph rewriters {s:.3f}s", file=stream)
s = sum(time_opts[o] for o in opt.final_optimizers) s = sum(time_opts[o] for o in opt.final_optimizers)
print(blanc, f" time in final optimizers {s:.3f}s", file=stream) print(blanc, f" time in final rewriters {s:.3f}s", file=stream)
s = sum(time_opts[o] for o in opt.cleanup_optimizers) s = sum(time_opts[o] for o in opt.cleanup_optimizers)
print(blanc, f" time in cleanup optimizers {s:.3f}s", file=stream) print(blanc, f" time in cleanup rewriters {s:.3f}s", file=stream)
for i in range(len(loop_timing)): for i in range(len(loop_timing)):
lopt = "" loop_times = ""
if loop_process_count[i]: if loop_process_count[i]:
d = list( d = list(
reversed(sorted(loop_process_count[i].items(), key=lambda a: a[1])) reversed(sorted(loop_process_count[i].items(), key=lambda a: a[1]))
) )
lopt = " ".join([str((str(k), v)) for k, v in d[:5]]) loop_times = " ".join([str((str(k), v)) for k, v in d[:5]])
if len(d) > 5: if len(d) > 5:
lopt += " ..." loop_times += " ..."
print( print(
blanc, blanc,
( (
f" {int(i):2d} - {loop_timing[i]:.3f}s {int(sum(loop_process_count[i].values()))} ({global_opt_timing[i]:.3f}s in global opts, " f" {int(i):2d} - {loop_timing[i]:.3f}s {int(sum(loop_process_count[i].values()))} ({global_opt_timing[i]:.3f}s in graph rewriters, "
f"{io_toposort_timing[i]:.3f}s io_toposort) - {int(nb_nodes[i])} nodes - {lopt}" f"{io_toposort_timing[i]:.3f}s io_toposort) - {int(nb_nodes[i])} nodes - {loop_times}"
), ),
file=stream, file=stream,
) )
...@@ -2784,8 +2801,10 @@ def check_chain(r, *chain): ...@@ -2784,8 +2801,10 @@ def check_chain(r, *chain):
return _check_chain(r, reduce(list.__iadd__, ([x, 0] for x in chain))) return _check_chain(r, reduce(list.__iadd__, ([x, 0] for x in chain)))
def pre_greedy_node_rewriter(fgraph, optimizations, out): def pre_greedy_node_rewriter(
"""Apply local optimizations to a graph. fgraph: FunctionGraph, optimizations: Sequence[NodeRewriter], out: Variable
) -> Variable:
"""Apply node rewriters throughout a graph in a greedy, pre-traversal way.
This function traverses the computation graph in the graph before the This function traverses the computation graph in the graph before the
variable `out` but that are not in the `fgraph`. It applies variable `out` but that are not in the `fgraph`. It applies
...@@ -2796,7 +2815,7 @@ def pre_greedy_node_rewriter(fgraph, optimizations, out): ...@@ -2796,7 +2815,7 @@ def pre_greedy_node_rewriter(fgraph, optimizations, out):
This changes the nodes in a graph in-place. This changes the nodes in a graph in-place.
Its main use is to apply locally constant folding when generating Its main use is to apply locally constant folding when generating
the graph of the indices of a subtensor. the graph of the indices of a `Subtensor`.
Changes should not be applied to nodes that are in an `fgraph`, Changes should not be applied to nodes that are in an `fgraph`,
so we use `fgraph` to prevent that. so we use `fgraph` to prevent that.
...@@ -2810,16 +2829,21 @@ def pre_greedy_node_rewriter(fgraph, optimizations, out): ...@@ -2810,16 +2829,21 @@ def pre_greedy_node_rewriter(fgraph, optimizations, out):
Parameters Parameters
---------- ----------
fgraph : FunctionGraph fgraph
The graph used to avoid/filter nodes. The graph used to avoid/filter nodes.
optimizations : list of NodeRewriter optimizations
The list of local optimizations to apply A sequence of rewrites to apply.
out : Variable out
A `Variable` specifying the graph to optimize. The graph to optimize.
""" """
def local_recursive_function(list_opt, out, optimized_vars, depth): def local_recursive_function(
list_opt: Sequence[NodeRewriter],
out: Variable,
optimized_vars: Dict[Variable, Variable],
depth: int,
) -> Tuple[List[Variable], Dict[Variable, Variable]]:
if not getattr(out, "owner", None): if not getattr(out, "owner", None):
return [out], optimized_vars return [out], optimized_vars
node = out.owner node = out.owner
...@@ -2852,6 +2876,7 @@ def pre_greedy_node_rewriter(fgraph, optimizations, out): ...@@ -2852,6 +2876,7 @@ def pre_greedy_node_rewriter(fgraph, optimizations, out):
for opt in list_opt: for opt in list_opt:
ret = opt.transform(fgraph, node) ret = opt.transform(fgraph, node)
if ret is not False and ret is not None: if ret is not False and ret is not None:
assert isinstance(ret, Sequence)
assert len(ret) == len(node.outputs), opt assert len(ret) == len(node.outputs), opt
for k, v in zip(node.outputs, ret): for k, v in zip(node.outputs, ret):
optimized_vars[k] = v optimized_vars[k] = v
...@@ -2864,7 +2889,7 @@ def pre_greedy_node_rewriter(fgraph, optimizations, out): ...@@ -2864,7 +2889,7 @@ def pre_greedy_node_rewriter(fgraph, optimizations, out):
return results, optimized_vars return results, optimized_vars
if out.owner: if out.owner:
out_index = out.owner.outputs.index(out) out_index: int = out.owner.outputs.index(out)
else: else:
out_index = 0 out_index = 0
......
...@@ -290,55 +290,40 @@ class OptimizationQuery: ...@@ -290,55 +290,40 @@ class OptimizationQuery:
class EquilibriumDB(OptimizationDatabase): class EquilibriumDB(OptimizationDatabase):
""" """A database of rewrites that should be applied until equilibrium is reached.
A set of potential optimizations which should be applied in an arbitrary
order until equilibrium is reached.
Canonicalize, Stabilize, and Specialize are all equilibrium optimizations. Canonicalize, Stabilize, and Specialize are all equilibrium optimizations.
Parameters
----------
ignore_newtrees
If False, we will apply local opt on new node introduced during local
optimization application. This could result in less fgraph iterations,
but this doesn't mean it will be faster globally.
tracks_on_change_inputs
If True, we will re-apply local opt on nodes whose inputs
changed during local optimization application. This could
result in less fgraph iterations, but this doesn't mean it
will be faster globally.
Notes Notes
----- -----
We can use `NodeRewriter` and `GraphRewriter` since `EquilibriumOptimizer` We can use `NodeRewriter` and `GraphRewriter` since `EquilibriumOptimizer`
supports both. supports both.
It is probably not a good idea to have ignore_newtrees=False and It is probably not a good idea to have both ``ignore_newtrees == False``
tracks_on_change_inputs=True and ``tracks_on_change_inputs == True``.
""" """
def __init__(self, ignore_newtrees=True, tracks_on_change_inputs=False): def __init__(
self, ignore_newtrees: bool = True, tracks_on_change_inputs: bool = False
):
""" """
Parameters Parameters
========== ----------
ignore_newtrees: ignore_newtrees
If False, we will apply local opt on new node introduced during local If ``False``, apply rewrites to new nodes introduced during
optimization application. This could result in less fgraph iterations, rewriting.
but this doesn't mean it will be faster globally.
tracks_on_change_inputs
tracks_on_change_inputs: If ``True``, re-apply rewrites on nodes with changed inputs.
If True, we will re-apply local opt on nodes whose inputs
changed during local optimization application. This could
result in less fgraph iterations, but this doesn't mean it
will be faster globally.
""" """
super().__init__() super().__init__()
self.ignore_newtrees = ignore_newtrees self.ignore_newtrees = ignore_newtrees
self.tracks_on_change_inputs = tracks_on_change_inputs self.tracks_on_change_inputs = tracks_on_change_inputs
self.__final__ = {} self.__final__: Dict[str, aesara_opt.Rewriter] = {}
self.__cleanup__ = {} self.__cleanup__: Dict[str, aesara_opt.Rewriter] = {}
def register(self, name, obj, *tags, final_opt=False, cleanup=False, **kwargs): def register(self, name, obj, *tags, final_opt=False, cleanup=False, **kwargs):
if final_opt and cleanup: if final_opt and cleanup:
......
...@@ -6,7 +6,7 @@ import time ...@@ -6,7 +6,7 @@ import time
import traceback import traceback
from collections import defaultdict from collections import defaultdict
from io import StringIO from io import StringIO
from typing import Optional from typing import Optional, Union
import numpy as np import numpy as np
...@@ -28,13 +28,15 @@ from aesara.graph.fg import FunctionGraph ...@@ -28,13 +28,15 @@ from aesara.graph.fg import FunctionGraph
from aesara.graph.op import compute_test_value, get_test_value from aesara.graph.op import compute_test_value, get_test_value
from aesara.graph.opt import ( from aesara.graph.opt import (
GraphRewriter, GraphRewriter,
NodeRewriter,
OpRemove, OpRemove,
Rewriter,
check_chain, check_chain,
copy_stack_trace, copy_stack_trace,
in2out, in2out,
node_rewriter, node_rewriter,
) )
from aesara.graph.optdb import SequenceDB from aesara.graph.optdb import OptimizationDatabase, SequenceDB
from aesara.graph.utils import ( from aesara.graph.utils import (
InconsistencyError, InconsistencyError,
MethodNotDefined, MethodNotDefined,
...@@ -193,21 +195,19 @@ class InplaceElemwiseOptimizer(GraphRewriter): ...@@ -193,21 +195,19 @@ class InplaceElemwiseOptimizer(GraphRewriter):
print(blanc, n, ndim[n], file=stream) print(blanc, n, ndim[n], file=stream)
def apply(self, fgraph): def apply(self, fgraph):
""" r"""
Usage: InplaceElemwiseOptimizer(op).optimize(fgraph)
Attempts to replace all Broadcast ops by versions of them Attempts to replace all `Elemwise`\s by versions of them that operate
that operate inplace. It operates greedily: for each Broadcast inplace. It operates greedily: for each `Elemwise` that is encountered,
Op that is encountered, for each output, tries each input to for each output, it tries each input to see if it can operate inplace
see if it can operate inplace on that input. If so, makes the on that input. If so, it makes the change and goes to the next output
change and go to the next output or Broadcast Op. or `Elemwise`.
Examples Examples
-------- --------
`x + y + z -> x += y += z` x + y + z -> x += y += z
(x + y) * (x * y) -> (x += y) *= (x * y) or (x + y) *= (x *= y)
`(x + y) * (x * y) -> (x += y) *= (x * y) or (x + y) *= (x *= y)`
""" """
# We should not validate too often as this takes too much time to # We should not validate too often as this takes too much time to
...@@ -225,7 +225,7 @@ class InplaceElemwiseOptimizer(GraphRewriter): ...@@ -225,7 +225,7 @@ class InplaceElemwiseOptimizer(GraphRewriter):
# Maybe Aesara should do online toposort as in # Maybe Aesara should do online toposort as in
# http://code.google.com/p/acyclic # http://code.google.com/p/acyclic
# #
# The next longest optimizer is the canonizer phase. # The next longest rewriter is the canonizer phase.
# Then I think it is the [io_?]toposort (need to validate) so check if # Then I think it is the [io_?]toposort (need to validate) so check if
# the solution is also applicable there. # the solution is also applicable there.
...@@ -429,8 +429,8 @@ class InplaceElemwiseOptimizer(GraphRewriter): ...@@ -429,8 +429,8 @@ class InplaceElemwiseOptimizer(GraphRewriter):
if check_each_change != 1 and not raised_warning: if check_each_change != 1 and not raised_warning:
print( print(
( (
"Some inplace optimization was not " "Some inplace rewriting was not "
"performed due to unexpected error:" "performed due to an unexpected error:"
), ),
file=sys.stderr, file=sys.stderr,
) )
...@@ -450,8 +450,8 @@ class InplaceElemwiseOptimizer(GraphRewriter): ...@@ -450,8 +450,8 @@ class InplaceElemwiseOptimizer(GraphRewriter):
if not raised_warning: if not raised_warning:
print( print(
( (
"Some inplace optimization was not " "Some inplace rewriting was not "
"performed due to unexpected error" "performed due to an unexpected error"
), ),
file=sys.stderr, file=sys.stderr,
) )
...@@ -478,91 +478,111 @@ compile.optdb.register( ...@@ -478,91 +478,111 @@ compile.optdb.register(
) )
def register_useless(lopt, *tags, **kwargs): def register_useless(
if isinstance(lopt, str): node_rewriter: Union[OptimizationDatabase, NodeRewriter, str], *tags, **kwargs
):
if isinstance(node_rewriter, str):
def register(inner_lopt): def register(inner_rewriter: Union[OptimizationDatabase, Rewriter]):
return register_useless(inner_lopt, lopt, *tags, **kwargs) return register_useless(inner_rewriter, node_rewriter, *tags, **kwargs)
return register return register
else: else:
name = kwargs.pop("name", None) or lopt.__name__ name = kwargs.pop("name", None) or node_rewriter.__name__
compile.mode.local_useless.register( compile.mode.local_useless.register(
name, lopt, "fast_run", *tags, position="last", **kwargs name, node_rewriter, "fast_run", *tags, position="last", **kwargs
) )
return lopt return node_rewriter
def register_canonicalize(lopt, *tags, **kwargs): def register_canonicalize(
if isinstance(lopt, str): node_rewriter: Union[OptimizationDatabase, NodeRewriter, str], *tags: str, **kwargs
):
if isinstance(node_rewriter, str):
def register(inner_lopt): def register(inner_rewriter: Union[OptimizationDatabase, Rewriter]):
return register_canonicalize(inner_lopt, lopt, *tags, **kwargs) return register_canonicalize(inner_rewriter, node_rewriter, *tags, **kwargs)
return register return register
else: else:
name = kwargs.pop("name", None) or lopt.__name__ name = kwargs.pop("name", None) or node_rewriter.__name__
compile.optdb["canonicalize"].register( compile.optdb["canonicalize"].register(
name, lopt, "fast_run", "fast_compile", *tags, **kwargs name, node_rewriter, "fast_run", "fast_compile", *tags, **kwargs
) )
return lopt return node_rewriter
def register_stabilize(lopt, *tags, **kwargs): def register_stabilize(
if isinstance(lopt, str): node_rewriter: Union[OptimizationDatabase, NodeRewriter, str], *tags: str, **kwargs
):
if isinstance(node_rewriter, str):
def register(inner_lopt): def register(inner_rewriter: Union[OptimizationDatabase, Rewriter]):
return register_stabilize(inner_lopt, lopt, *tags, **kwargs) return register_stabilize(inner_rewriter, node_rewriter, *tags, **kwargs)
return register return register
else: else:
name = kwargs.pop("name", None) or lopt.__name__ name = kwargs.pop("name", None) or node_rewriter.__name__
compile.optdb["stabilize"].register(name, lopt, "fast_run", *tags, **kwargs) compile.optdb["stabilize"].register(
return lopt name, node_rewriter, "fast_run", *tags, **kwargs
)
return node_rewriter
def register_specialize(lopt, *tags, **kwargs): def register_specialize(
if isinstance(lopt, str): node_rewriter: Union[OptimizationDatabase, NodeRewriter, str], *tags: str, **kwargs
):
if isinstance(node_rewriter, str):
def register(inner_lopt): def register(inner_rewriter: Union[OptimizationDatabase, Rewriter]):
return register_specialize(inner_lopt, lopt, *tags, **kwargs) return register_specialize(inner_rewriter, node_rewriter, *tags, **kwargs)
return register return register
else: else:
name = kwargs.pop("name", None) or lopt.__name__ name = kwargs.pop("name", None) or node_rewriter.__name__
compile.optdb["specialize"].register(name, lopt, "fast_run", *tags, **kwargs) compile.optdb["specialize"].register(
return lopt name, node_rewriter, "fast_run", *tags, **kwargs
)
return node_rewriter
def register_uncanonicalize(lopt, *tags, **kwargs): def register_uncanonicalize(
if isinstance(lopt, str): node_rewriter: Union[OptimizationDatabase, NodeRewriter, str], *tags: str, **kwargs
):
if isinstance(node_rewriter, str):
def register(inner_lopt): def register(inner_rewriter: Union[OptimizationDatabase, Rewriter]):
return register_uncanonicalize(inner_lopt, lopt, *tags, **kwargs) return register_uncanonicalize(
inner_rewriter, node_rewriter, *tags, **kwargs
)
return register return register
else: else:
name = (kwargs and kwargs.pop("name", None)) or lopt.__name__ name = (kwargs and kwargs.pop("name", None)) or node_rewriter.__name__
compile.optdb["uncanonicalize"].register( compile.optdb["uncanonicalize"].register(
name, lopt, "fast_run", *tags, **kwargs name, node_rewriter, "fast_run", *tags, **kwargs
) )
return lopt return node_rewriter
def register_specialize_device(lopt, *tags, **kwargs): def register_specialize_device(
if isinstance(lopt, str): node_rewriter: Union[OptimizationDatabase, Rewriter, str], *tags: str, **kwargs
):
if isinstance(node_rewriter, str):
def register(inner_lopt): def register(inner_rewriter: Union[OptimizationDatabase, Rewriter]):
return register_specialize_device(inner_lopt, lopt, *tags, **kwargs) return register_specialize_device(
inner_rewriter, node_rewriter, *tags, **kwargs
)
return register return register
else: else:
name = (kwargs and kwargs.pop("name", None)) or lopt.__name__ name = (kwargs and kwargs.pop("name", None)) or node_rewriter.__name__
compile.optdb["specialize_device"].register( compile.optdb["specialize_device"].register(
name, lopt, "fast_run", *tags, **kwargs name, node_rewriter, "fast_run", *tags, **kwargs
) )
return lopt return node_rewriter
def apply_local_dimshuffle_lift(fgraph, var): def apply_local_dimshuffle_lift(fgraph, var):
...@@ -762,19 +782,17 @@ pprint.assign(MakeVector, MakeVectorPrinter()) ...@@ -762,19 +782,17 @@ pprint.assign(MakeVector, MakeVectorPrinter())
class ShapeFeature(Feature): class ShapeFeature(Feature):
"""Graph optimizer for removing all calls to shape(). r"""A `Feature` that tracks shape information in a graph.
This optimizer replaces all Shapes and Subtensors of Shapes with This `Feature` aids in the replacement of all `Shape`\s and `Subtensor`\s of `Shape`\s with
Shape_i and MakeVector Ops. `Shape_i` and `MakeVector` `Op`\s.
This optimizer has several goals: This `Feature` and its associated rewrites have several goals:
1. to 'lift' Shapes to as close to the inputs as possible.
1. to "lift" `Shape`\s to as close to the inputs as possible,
2. to infer the shape of every node in the graph in terms of the 2. to infer the shape of every node in the graph in terms of the
input shapes. input shapes, and
3. remove fill `Op`\s (e.g. `Second`) from the graph.
3. remove all fills ``(at.second, at.fill)`` from the graph
Lifting shapes as close to the inputs as possible is important for Lifting shapes as close to the inputs as possible is important for
canonicalization because it is very bad form to have to compute canonicalization because it is very bad form to have to compute
...@@ -782,7 +800,7 @@ class ShapeFeature(Feature): ...@@ -782,7 +800,7 @@ class ShapeFeature(Feature):
of time to compute such outputs. But it is important to get rid of time to compute such outputs. But it is important to get rid
of these outputs as early as possible in the compilation process of these outputs as early as possible in the compilation process
because the extra computations make it appear as if many internal because the extra computations make it appear as if many internal
graph nodes have multiple clients. Many optimizations refuse to graph nodes have multiple clients. Many rewrites refuse to
work on nodes with multiple clients. work on nodes with multiple clients.
Lifting is done by using an `<Op>.infer_shape` function if one is Lifting is done by using an `<Op>.infer_shape` function if one is
...@@ -802,7 +820,7 @@ class ShapeFeature(Feature): ...@@ -802,7 +820,7 @@ class ShapeFeature(Feature):
input_shapes=((x0,x1), (y0,y1)) and return ((x0, y1),). input_shapes=((x0,x1), (y0,y1)) and return ((x0, y1),).
Inferring the shape of internal nodes in the graph is important Inferring the shape of internal nodes in the graph is important
for doing size-driven optimizations. If we know how big various for doing size-driven rewrites. If we know how big various
intermediate results will be, we can estimate the cost of many Ops intermediate results will be, we can estimate the cost of many Ops
accurately, and generate c-code that is specific [e.g. unrolled] accurately, and generate c-code that is specific [e.g. unrolled]
to particular sizes. to particular sizes.
...@@ -818,14 +836,12 @@ class ShapeFeature(Feature): ...@@ -818,14 +836,12 @@ class ShapeFeature(Feature):
shape, either via a .tag or some similar hacking; and 2) to shape, either via a .tag or some similar hacking; and 2) to
add an optional In() argument to promise that inputs will add an optional In() argument to promise that inputs will
have a certain shape (or even to have certain shapes in have a certain shape (or even to have certain shapes in
certain dimensions). We can't automatically infer the shape of certain dimensions).
shared variables as they can change of shape during the
execution by default. (NOT IMPLEMENTED YET, BUT IS IN TRAC)
**Using Shape information in Optimizations** We can't automatically infer the shape of shared variables as they can
change of shape during the execution by default.
To use this shape information in OPTIMIZATIONS, use the To use this shape information in rewrites, use the
``shape_of`` dictionary. ``shape_of`` dictionary.
For example: For example:
...@@ -888,10 +904,10 @@ class ShapeFeature(Feature): ...@@ -888,10 +904,10 @@ class ShapeFeature(Feature):
return o_shapes return o_shapes
def get_shape(self, var, idx): def get_shape(self, var, idx):
"""Optimization can call this to get the current shape_i """Rewrites can call this to get a `Shape_i`.
It is better to call this then use directly shape_of[var][idx] It is better to call this then use directly ``shape_of[var][idx]``
as this method should update shape_of if needed. as this method should update `shape_of` if needed.
TODO: Up to now, we don't update it in all cases. Update in all cases. TODO: Up to now, we don't update it in all cases. Update in all cases.
""" """
...@@ -977,11 +993,9 @@ class ShapeFeature(Feature): ...@@ -977,11 +993,9 @@ class ShapeFeature(Feature):
error reporting. error reporting.
""" """
# unpack the s_i that the Op returned
assert s_i is not None assert s_i is not None
if s_i == 1: if s_i == 1:
# don't make the optimizer merge a zillion ones together
# by always returning the same object to represent 1
return self.lscalar_one return self.lscalar_one
if isinstance(s_i, float) and int(s_i) == s_i: if isinstance(s_i, float) and int(s_i) == s_i:
s_i = int(s_i) s_i = int(s_i)
...@@ -1080,10 +1094,9 @@ class ShapeFeature(Feature): ...@@ -1080,10 +1094,9 @@ class ShapeFeature(Feature):
else: else:
shape_vars.append(self.unpack(s[i], r)) shape_vars.append(self.unpack(s[i], r))
assert all( assert all(
not hasattr(r.type, "broadcastable") or not r.type.broadcastable[i] or not hasattr(r.type, "broadcastable")
# The two following comparison are a speed optimization or not r.type.broadcastable[i]
# But we never timed this speed optimization! or self.lscalar_one.equals(shape_vars[i])
self.lscalar_one.equals(shape_vars[i])
or self.lscalar_one.equals(extract_constant(shape_vars[i])) or self.lscalar_one.equals(extract_constant(shape_vars[i]))
for i in range(r.type.ndim) for i in range(r.type.ndim)
) )
...@@ -1118,9 +1131,9 @@ class ShapeFeature(Feature): ...@@ -1118,9 +1131,9 @@ class ShapeFeature(Feature):
and other_r.owner.inputs == r.owner.inputs and other_r.owner.inputs == r.owner.inputs
and other_r.owner.op == r.owner.op and other_r.owner.op == r.owner.op
): ):
# We are doing a merge. So the 2 shapes graph will be the # We are doing a merge, so the two shape graphs will be the
# same. This is only a speed optimization to call # same. This is only done so that we call `ancestors` less
# ancestors() less frequently. # frequently.
return return
# Merge other_shape with r_shape, giving the priority to other_shape # Merge other_shape with r_shape, giving the priority to other_shape
...@@ -1168,10 +1181,7 @@ class ShapeFeature(Feature): ...@@ -1168,10 +1181,7 @@ class ShapeFeature(Feature):
or not r.type.broadcastable[i] or not r.type.broadcastable[i]
and not other_r.type.broadcastable[i] and not other_r.type.broadcastable[i]
) )
or or self.lscalar_one.equals(merged_shape[i])
# The two following comparison are a speed optimization
# But we never timed this speed optimization!
self.lscalar_one.equals(merged_shape[i])
or self.lscalar_one.equals( or self.lscalar_one.equals(
extract_constant(merged_shape[i], only_process_constants=True) extract_constant(merged_shape[i], only_process_constants=True)
) )
...@@ -1194,10 +1204,9 @@ class ShapeFeature(Feature): ...@@ -1194,10 +1204,9 @@ class ShapeFeature(Feature):
else: else:
new_shape.append(s_j) new_shape.append(s_j)
assert all( assert all(
not hasattr(r.type, "broadcastable") or not r.type.broadcastable[idx] or not hasattr(r.type, "broadcastable")
# The two following comparison are a speed optimization or not r.type.broadcastable[idx]
# But we never timed this speed optimization! or self.lscalar_one.equals(new_shape[idx])
self.lscalar_one.equals(new_shape[idx])
or self.lscalar_one.equals(extract_constant(new_shape[idx])) or self.lscalar_one.equals(extract_constant(new_shape[idx]))
for idx in range(r.type.ndim) for idx in range(r.type.ndim)
) )
...@@ -1273,7 +1282,7 @@ class ShapeFeature(Feature): ...@@ -1273,7 +1282,7 @@ class ShapeFeature(Feature):
) )
# Ensure shapes are in 'int64'. This is to make sure the assert # Ensure shapes are in 'int64'. This is to make sure the assert
# found in the `local_useless_subtensor` optimization does not fail. # found in the `local_useless_subtensor` rewrite does not fail.
for sh_idx, sh in enumerate(o_shapes): for sh_idx, sh in enumerate(o_shapes):
if sh is None: if sh is None:
continue continue
...@@ -1444,7 +1453,7 @@ class ShapeFeature(Feature): ...@@ -1444,7 +1453,7 @@ class ShapeFeature(Feature):
class ShapeOptimizer(GraphRewriter): class ShapeOptimizer(GraphRewriter):
"""Optimizer that adds `ShapeFeature` as a feature.""" """Rewriter that adds `ShapeFeature` as a feature."""
def add_requirements(self, fgraph): def add_requirements(self, fgraph):
fgraph.attach_feature(ShapeFeature()) fgraph.attach_feature(ShapeFeature())
...@@ -1454,7 +1463,7 @@ class ShapeOptimizer(GraphRewriter): ...@@ -1454,7 +1463,7 @@ class ShapeOptimizer(GraphRewriter):
class UnShapeOptimizer(GraphRewriter): class UnShapeOptimizer(GraphRewriter):
"""Optimizer that removes `ShapeFeature` as a feature.""" """Rewriter that removes `ShapeFeature` as a feature."""
def apply(self, fgraph): def apply(self, fgraph):
for feature in fgraph._features: for feature in fgraph._features:
...@@ -1528,7 +1537,7 @@ def local_elemwise_alloc(fgraph, node): ...@@ -1528,7 +1537,7 @@ def local_elemwise_alloc(fgraph, node):
for idx, i in enumerate(node.inputs): for idx, i in enumerate(node.inputs):
if i.type.broadcastable == node.outputs[0].type.broadcastable: if i.type.broadcastable == node.outputs[0].type.broadcastable:
# Prefer an input that is not an `Alloc` nor a `DimShuffle` of an # Prefer an input that is not an `Alloc` nor a `DimShuffle` of an
# `Alloc`, so that all `Alloc`s can be optimized. # `Alloc`, so that all `Alloc`s can be rewritten.
if idx not in alloc_idxs: if idx not in alloc_idxs:
ref_var_idx = idx ref_var_idx = idx
break break
...@@ -1585,7 +1594,7 @@ def local_elemwise_alloc(fgraph, node): ...@@ -1585,7 +1594,7 @@ def local_elemwise_alloc(fgraph, node):
new_inputs[idx] = new_alloc new_inputs[idx] = new_alloc
# If this assert is triggered, it means we are recreating an equivalent graph # If this assert is triggered, it means we are recreating an equivalent graph
# which would result in a cyclical merge optimization. # which would result in cyclical merge rewrites.
if all(new is old for new, old in zip(new_inputs, node.inputs)): if all(new is old for new, old in zip(new_inputs, node.inputs)):
return return
...@@ -1702,9 +1711,9 @@ compile.optdb.register( ...@@ -1702,9 +1711,9 @@ compile.optdb.register(
def local_useless_fill(fgraph, node): def local_useless_fill(fgraph, node):
"""fill(s,v) -> v """fill(s,v) -> v
This optimization is only needed in FAST_COMPILE to make the code This rewrite is only needed in FAST_COMPILE mode to make the code
more readable. Normally, it is done by the local_fill_to_alloc more readable. Normally, it is done by the `local_fill_to_alloc`
opt. rewrite.
""" """
r, v = node.inputs r, v = node.inputs
...@@ -1789,9 +1798,9 @@ def local_alloc_sink_dimshuffle(fgraph, node): ...@@ -1789,9 +1798,9 @@ def local_alloc_sink_dimshuffle(fgraph, node):
def local_alloc_empty_to_zeros(fgraph, node): def local_alloc_empty_to_zeros(fgraph, node):
"""This convert AllocEmpty to Alloc of 0. """This convert AllocEmpty to Alloc of 0.
This help investigate NaN with NanGuardMode. Not registered by This helps one investigate NaNs in `NanGuardMode`. Not registered by
default. To activate it, use the Aesara flag default. To activate it, use the setting
optimizer_including=alloc_empty_to_zeros. ``optimizer_including == alloc_empty_to_zeros``.
""" """
if isinstance(node.op, AllocEmpty): if isinstance(node.op, AllocEmpty):
return [zeros(node.inputs, dtype=node.outputs[0].dtype)] return [zeros(node.inputs, dtype=node.outputs[0].dtype)]
...@@ -1811,7 +1820,6 @@ compile.optdb.register( ...@@ -1811,7 +1820,6 @@ compile.optdb.register(
@node_rewriter([Shape]) @node_rewriter([Shape])
def local_shape_to_shape_i(fgraph, node): def local_shape_to_shape_i(fgraph, node):
if isinstance(node.op, Shape): if isinstance(node.op, Shape):
# This optimization needs ShapeOpt and fgraph.shape_feature
if not hasattr(fgraph, "shape_feature"): if not hasattr(fgraph, "shape_feature"):
return return
shape_feature = fgraph.shape_feature shape_feature = fgraph.shape_feature
...@@ -1850,16 +1858,18 @@ def local_track_shape_i(fgraph, node): ...@@ -1850,16 +1858,18 @@ def local_track_shape_i(fgraph, node):
@node_rewriter([Elemwise]) @node_rewriter([Elemwise])
def local_useless_elemwise(fgraph, node): def local_useless_elemwise(fgraph, node):
""" """
eq(x, x) -> 1 eq(x, x) -> 1
neq(x, x) -> 0 neq(x, x) -> 0
mul(x) -> x mul(x) -> x
add(x) -> x add(x) -> x
identity(x) -> x identity(x) -> x
and(x, 1) -> x (if x.dtype == 'bool') and(x, 1) -> x (if x.dtype == 'bool')
and(x, 0) -> zeros_like(x) and(x, 0) -> zeros_like(x)
or(x, 0) -> x or(x, 0) -> x
or(x, 1) -> ones_like(x) (if x.dtype == 'bool') or(x, 1) -> ones_like(x) (if x.dtype == 'bool')
xor(x, x) -> zeros_like(x) xor(x, x) -> zeros_like(x)
TODO: This implementation is painfully redundant.
""" """
if isinstance(node.op, Elemwise): if isinstance(node.op, Elemwise):
...@@ -1905,7 +1915,7 @@ def local_useless_elemwise(fgraph, node): ...@@ -1905,7 +1915,7 @@ def local_useless_elemwise(fgraph, node):
return [zeros_like(node.inputs[1], dtype=dtype, opt=True)] return [zeros_like(node.inputs[1], dtype=dtype, opt=True)]
elif node.outputs[0].dtype == "bool": elif node.outputs[0].dtype == "bool":
# If the output is not Boolean, it is the bitwise AND, # If the output is not Boolean, it is the bitwise AND,
# and this optimization would be wrong # and this rewrite would be wrong
return [node.inputs[1].astype(node.outputs[0].dtype)] return [node.inputs[1].astype(node.outputs[0].dtype)]
if isinstance(node.inputs[1], TensorConstant): if isinstance(node.inputs[1], TensorConstant):
...@@ -1917,7 +1927,7 @@ def local_useless_elemwise(fgraph, node): ...@@ -1917,7 +1927,7 @@ def local_useless_elemwise(fgraph, node):
return [zeros_like(node.inputs[0], dtype=dtype, opt=True)] return [zeros_like(node.inputs[0], dtype=dtype, opt=True)]
elif node.outputs[0].dtype == "bool": elif node.outputs[0].dtype == "bool":
# If the output is not Boolean, it is the bitwise AND, # If the output is not Boolean, it is the bitwise AND,
# and this optimization would be wrong # and this rewrite would be wrong
return [node.inputs[0].astype(node.outputs[0].dtype)] return [node.inputs[0].astype(node.outputs[0].dtype)]
elif isinstance(node.op.scalar_op, aes.OR) and len(node.inputs) == 2: elif isinstance(node.op.scalar_op, aes.OR) and len(node.inputs) == 2:
...@@ -1931,7 +1941,7 @@ def local_useless_elemwise(fgraph, node): ...@@ -1931,7 +1941,7 @@ def local_useless_elemwise(fgraph, node):
return [node.inputs[1].astype(node.outputs[0].dtype)] return [node.inputs[1].astype(node.outputs[0].dtype)]
elif node.outputs[0].dtype == "bool": elif node.outputs[0].dtype == "bool":
# If the output is not Boolean, it is the bitwise OR, # If the output is not Boolean, it is the bitwise OR,
# and this optimization would be wrong # and this rewrite would be wrong
return [ones_like(node.inputs[1], dtype=dtype, opt=True)] return [ones_like(node.inputs[1], dtype=dtype, opt=True)]
if isinstance(node.inputs[1], TensorConstant): if isinstance(node.inputs[1], TensorConstant):
...@@ -1943,7 +1953,7 @@ def local_useless_elemwise(fgraph, node): ...@@ -1943,7 +1953,7 @@ def local_useless_elemwise(fgraph, node):
return [node.inputs[0].astype(node.outputs[0].dtype)] return [node.inputs[0].astype(node.outputs[0].dtype)]
elif node.outputs[0].dtype == "bool": elif node.outputs[0].dtype == "bool":
# If the output is not Boolean, it is the bitwise OR, # If the output is not Boolean, it is the bitwise OR,
# and this optimization would be wrong # and this rewrite would be wrong
return [ones_like(node.inputs[0], dtype=dtype, opt=True)] return [ones_like(node.inputs[0], dtype=dtype, opt=True)]
elif isinstance(node.op.scalar_op, aes.XOR) and len(node.inputs) == 2: elif isinstance(node.op.scalar_op, aes.XOR) and len(node.inputs) == 2:
...@@ -2081,12 +2091,11 @@ def local_remove_useless_assert(fgraph, node): ...@@ -2081,12 +2091,11 @@ def local_remove_useless_assert(fgraph, node):
@node_rewriter([Assert]) @node_rewriter([Assert])
def local_remove_all_assert(fgraph, node): def local_remove_all_assert(fgraph, node):
"""An optimization disabled by default that removes all asserts from r"""A rewrite that removes all `Assert`\s from a graph.
the graph.
Notes Notes
----- -----
See the :ref:`unsafe` section to know how to enable it. See the :ref:`unsafe` section.
""" """
if not isinstance(node.op, Assert): if not isinstance(node.op, Assert):
...@@ -2346,7 +2355,7 @@ def local_join_make_vector(fgraph, node): ...@@ -2346,7 +2355,7 @@ def local_join_make_vector(fgraph, node):
Join(0, make_vector1, make_vector2, ...) => Join(0, make_vector12, ...) Join(0, make_vector1, make_vector2, ...) => Join(0, make_vector12, ...)
This in combination with the `local_join_1` optimization can make `Join`\s This, in combination with the `local_join_1` rewrite, can make `Join`\s
completely disappear. completely disappear.
""" """
if not isinstance(node.op, Join) or node.outputs[0].ndim != 1: if not isinstance(node.op, Join) or node.outputs[0].ndim != 1:
...@@ -2388,16 +2397,16 @@ def local_join_make_vector(fgraph, node): ...@@ -2388,16 +2397,16 @@ def local_join_make_vector(fgraph, node):
@node_rewriter([Elemwise]) @node_rewriter([Elemwise])
def local_useless_switch(fgraph, node): def local_useless_switch(fgraph, node):
""" """
This optimization makes the following changes in the graph: This rewrite makes the following changes in a graph:
``at.switch(cond, left, right)`` -> at.switch(cond, left, right) ->
``if cond is constant and cond == 0``: right if cond is constant and cond == 0: right
``if cond is constant and cond != 0``: left if cond is constant and cond != 0: left
``if left is right`` -> ``left`` if left is right -> left
and and
``at.switch(le(shape_i{id}(X), 0), 0, shape_i{id}(X))`` -> ``shape_i{id}(X)`` at.switch(le(shape_i{id}(X), 0), 0, shape_i{id}(X)) -> shape_i{id}(X)
""" """
if isinstance(node.op, Elemwise) and isinstance(node.op.scalar_op, aes.Switch): if isinstance(node.op, Elemwise) and isinstance(node.op.scalar_op, aes.Switch):
...@@ -2545,7 +2554,7 @@ def local_reshape_chain(op): ...@@ -2545,7 +2554,7 @@ def local_reshape_chain(op):
# replaced the shape by one for which this cannot be guessed. # replaced the shape by one for which this cannot be guessed.
# We should try to figure out why we lost the information about this # We should try to figure out why we lost the information about this
# constant value... but in the meantime, better not apply this # constant value... but in the meantime, better not apply this
# optimization. # rewrite.
if rval.broadcastable == node.outputs[0].broadcastable: if rval.broadcastable == node.outputs[0].broadcastable:
return [rval] return [rval]
else: else:
...@@ -2709,10 +2718,12 @@ def local_reshape_to_dimshuffle(fgraph, node): ...@@ -2709,10 +2718,12 @@ def local_reshape_to_dimshuffle(fgraph, node):
@node_rewriter([Reshape]) @node_rewriter([Reshape])
def local_reshape_lift(fgraph, node): def local_reshape_lift(fgraph, node):
""" """
Reshape(UnaryElemwise(x)) -> UnaryElemwise(Reshape(x)) Reshape(UnaryElemwise(x)) -> UnaryElemwise(Reshape(x))
This optimization is needed by optimization Notes
log1msigm_to_softplus to get applied when there is a reshape. -----
This rewrite is needed by `log1msigm_to_softplus` in order to get applied
when there is a reshape.
""" """
if ( if (
...@@ -2840,10 +2851,9 @@ def local_elemwise_fusion_op(op_class, max_input_fct=lambda node: 32, maker=None ...@@ -2840,10 +2851,9 @@ def local_elemwise_fusion_op(op_class, max_input_fct=lambda node: 32, maker=None
The number of dimensions is validated at call time by Aesara itself. The number of dimensions is validated at call time by Aesara itself.
""" """
# META TODO: PUT THESE THINGS IN TRAC, NOT TODO NOTES!!
# TODO: use broadcast flag? # TODO: use broadcast flag?
# TODO: don't do this optimization as a localOptimizer. # TODO: don't do this rewrite as a `NodeRewriter`.
# Analyze the graph in terms of elemwise subgraphs, and then # Analyze the graph in terms of elemwise subgraphs, and then
# replace each subgraph with a Composite version. # replace each subgraph with a Composite version.
...@@ -2851,8 +2861,7 @@ def local_elemwise_fusion_op(op_class, max_input_fct=lambda node: 32, maker=None ...@@ -2851,8 +2861,7 @@ def local_elemwise_fusion_op(op_class, max_input_fct=lambda node: 32, maker=None
# fit within the parameter space of 256 bytes # fit within the parameter space of 256 bytes
# #
# TODO: Merge with multiple output to merge when an inputs # TODO: Merge with multiple output to merge when an inputs
# have multiple clients. This can't be done with a local # have multiple clients. This can't be done with a `NodeRewriter`
# optimiser.
# TODO: Related: Support composites with multiple outputs # TODO: Related: Support composites with multiple outputs
...@@ -2963,7 +2972,7 @@ def local_elemwise_fusion_op(op_class, max_input_fct=lambda node: 32, maker=None ...@@ -2963,7 +2972,7 @@ def local_elemwise_fusion_op(op_class, max_input_fct=lambda node: 32, maker=None
except (NotImplementedError, MethodNotDefined): except (NotImplementedError, MethodNotDefined):
_logger.warning( _logger.warning(
( (
"Optimization Warning: " "Rewrite warning: "
f"The Op {i.owner.op.scalar_op} does not provide a C implementation." f"The Op {i.owner.op.scalar_op} does not provide a C implementation."
" As well as being potentially slow, this also disables " " As well as being potentially slow, this also disables "
"loop fusion." "loop fusion."
...@@ -3015,10 +3024,9 @@ def local_elemwise_fusion_op(op_class, max_input_fct=lambda node: 32, maker=None ...@@ -3015,10 +3024,9 @@ def local_elemwise_fusion_op(op_class, max_input_fct=lambda node: 32, maker=None
return False return False
if new_nb_input != len(inputs) or len(s_inputs) != len(inputs): if new_nb_input != len(inputs) or len(s_inputs) != len(inputs):
# TODO FIXME: This shouldn't be a generic `Exception`
raise Exception( raise Exception(
"""Something has gone wrong with the elemwise "Something has gone wrong with the elemwise fusion rewrite; skipping."
fusion optimization. We skip this optimization. You can ignore this message,
your code will run correctly, but may be slower."""
) )
s_new_out = node.op.scalar_op(*s_g, return_list=True) s_new_out = node.op.scalar_op(*s_g, return_list=True)
...@@ -3034,7 +3042,7 @@ your code will run correctly, but may be slower.""" ...@@ -3034,7 +3042,7 @@ your code will run correctly, but may be slower."""
name = str(s_new_out[0].owner.op) name = str(s_new_out[0].owner.op)
_logger.warning( _logger.warning(
( (
"Optimization Warning: " "Rewrite warning: "
f"The Op {name} does not provide a C implementation." f"The Op {name} does not provide a C implementation."
" As well as being potentially slow, this also disables " " As well as being potentially slow, this also disables "
"loop fusion." "loop fusion."
...@@ -3086,15 +3094,15 @@ local_elemwise_fusion = local_elemwise_fusion_op(Elemwise, elemwise_max_input_fc ...@@ -3086,15 +3094,15 @@ local_elemwise_fusion = local_elemwise_fusion_op(Elemwise, elemwise_max_input_fc
class FusionOptimizer(GraphRewriter): class FusionOptimizer(GraphRewriter):
"""Graph optimizer that simply runs local fusion operations. """Graph rewriter that simply runs node fusion operations.
TODO: This is basically a `EquilibriumOptimizer`; we should just use that. TODO: This is basically an `EquilibriumOptimizer`; we should just use that.
""" """
def __init__(self, node_rewriter): def __init__(self, node_rewriter):
super().__init__() super().__init__()
self.optimizer = node_rewriter self.node_rewriter = node_rewriter
def add_requirements(self, fgraph): def add_requirements(self, fgraph):
fgraph.attach_feature(ReplaceValidate()) fgraph.attach_feature(ReplaceValidate())
...@@ -3118,7 +3126,7 @@ class FusionOptimizer(GraphRewriter): ...@@ -3118,7 +3126,7 @@ class FusionOptimizer(GraphRewriter):
for node in nodelist: for node in nodelist:
# Don't try to fuse node that have already been fused. # Don't try to fuse node that have already been fused.
if node in fgraph.apply_nodes: if node in fgraph.apply_nodes:
new_outputs = self.optimizer(fgraph, node) new_outputs = self.node_rewriter(fgraph, node)
if new_outputs: if new_outputs:
assert len(new_outputs) == len(node.outputs) assert len(new_outputs) == len(node.outputs)
try: try:
...@@ -3174,7 +3182,7 @@ class FusionOptimizer(GraphRewriter): ...@@ -3174,7 +3182,7 @@ class FusionOptimizer(GraphRewriter):
if config.tensor__local_elemwise_fusion: if config.tensor__local_elemwise_fusion:
_logger.debug("Enabling Elemwise fusion optimizations in fast_run") _logger.debug("Enabling Elemwise fusion rewriters in fast_run")
# Must be after gpu(48.5) and before AddDestroyHandler(49.5) # Must be after gpu(48.5) and before AddDestroyHandler(49.5)
fuse_seqopt = SequenceDB() fuse_seqopt = SequenceDB()
fuse_seqopt.register( fuse_seqopt.register(
...@@ -3194,7 +3202,7 @@ if config.tensor__local_elemwise_fusion: ...@@ -3194,7 +3202,7 @@ if config.tensor__local_elemwise_fusion:
position=49, position=49,
) )
else: else:
_logger.debug("not enabling optimization fusion elemwise in fast_run") _logger.debug("Not enabling Elemwise fusion rewriters in fast_run")
compile.optdb.register( compile.optdb.register(
"elemwise_fusion", "elemwise_fusion",
FusionOptimizer(local_elemwise_fusion), FusionOptimizer(local_elemwise_fusion),
...@@ -3239,10 +3247,14 @@ def local_view_op(fgraph, node): ...@@ -3239,10 +3247,14 @@ def local_view_op(fgraph, node):
@register_specialize @register_specialize
@node_rewriter([Alloc]) @node_rewriter([Alloc])
def local_merge_alloc(fgraph, node): def local_merge_alloc(fgraph, node):
# This opt takes care of several cases: """
# Alloc(Alloc(m, x, 1, 1, 1), x, y, z, w) -> Alloc(m, x, y, z, w) This rewriter takes care of the following cases:
# Alloc(Alloc(m, y, 1, 1), x, y, z, w) -> Alloc(m, x, y, z, w)
# Alloc(Alloc(m, y1, 1, 1), x, y2, z, w) -> Alloc(m, x, assert(y1, y1==y2), z, w) Alloc(Alloc(m, x, 1, 1, 1), x, y, z, w) -> Alloc(m, x, y, z, w)
Alloc(Alloc(m, y, 1, 1), x, y, z, w) -> Alloc(m, x, y, z, w)
Alloc(Alloc(m, y1, 1, 1), x, y2, z, w) -> Alloc(m, x, assert(y1, y1==y2), z, w)
"""
if not isinstance(node.op, Alloc): if not isinstance(node.op, Alloc):
return False return False
if not node.inputs[0].owner or not isinstance(node.inputs[0].owner.op, Alloc): if not node.inputs[0].owner or not isinstance(node.inputs[0].owner.op, Alloc):
...@@ -3276,11 +3288,7 @@ def local_merge_alloc(fgraph, node): ...@@ -3276,11 +3288,7 @@ def local_merge_alloc(fgraph, node):
@register_useless("fast_compile") @register_useless("fast_compile")
@node_rewriter([TopKOp]) @node_rewriter([TopKOp])
def local_useless_topk(fgraph, node): def local_useless_topk(fgraph, node):
""" """Remove unused `TopKOp` outputs."""
TopKOp generates two outputs by default
This opt removes the useless ones
"""
op = node.op op = node.op
if not isinstance(op, TopKOp): if not isinstance(op, TopKOp):
return return
......
...@@ -1849,16 +1849,6 @@ crossentropy_categorical_1hot = CrossentropyCategorical1Hot() ...@@ -1849,16 +1849,6 @@ crossentropy_categorical_1hot = CrossentropyCategorical1Hot()
@register_specialize("fast_compile") @register_specialize("fast_compile")
@optimizer @optimizer
def crossentropy_to_crossentropy_with_softmax_with_bias(fgraph): def crossentropy_to_crossentropy_with_softmax_with_bias(fgraph):
"""
This is a stabilization optimization.
Notes
-----
Not a local optimization because we are replacing outputs
from several nodes at once.
"""
def search_make_one_sub(): def search_make_one_sub():
for node in fgraph.toposort(): for node in fgraph.toposort():
if node.op == crossentropy_categorical_1hot: if node.op == crossentropy_categorical_1hot:
...@@ -1887,18 +1877,13 @@ def crossentropy_to_crossentropy_with_softmax_with_bias(fgraph): ...@@ -1887,18 +1877,13 @@ def crossentropy_to_crossentropy_with_softmax_with_bias(fgraph):
@optimizer @optimizer
def crossentropy_to_crossentropy_with_softmax(fgraph): def crossentropy_to_crossentropy_with_softmax(fgraph):
""" """
This is a stabilization optimization that is more general than This is a stabilization rewrite that is more general than
crossentropy_to_crossentropy_with_softmax_with_bias. `crossentropy_to_crossentropy_with_softmax_with_bias`.
It must be executed after local_softmax_with_bias optimization in
specialize.
TODO : This is a stabilization optimization! How to make this more cleanly?
Notes Notes
----- -----
Not a local optimization because we are replacing outputs from several It must be executed after `local_softmax_with_bias` during the
nodes at once. specialization passes.
""" """
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论