Commit 0ce6eceb authored by Brandon T. Willard, committed by Brandon T. Willard

Refactor old global and local optimizer references and type hints

Parent 550a6e98
......@@ -11,7 +11,7 @@ from aesara.graph.unify import eval_if_etuple
class KanrenRelationSub(NodeRewriter):
r"""A local optimizer that uses `kanren` to match and replace terms.
r"""A rewriter that uses `kanren` to match and replace terms.
See `kanren <https://github.com/pythological/kanren>`__ for more information
about miniKanren and the API for constructing `kanren` goals.
......@@ -56,7 +56,7 @@ class KanrenRelationSub(NodeRewriter):
A function that takes an input graph and an output logic variable and
returns a `kanren` goal.
results_filter
A function that takes the direct output of `kanren.run(None, ...)`
A function that takes the direct output of ``kanren.run(None, ...)``
and returns a single result. The default implementation returns
the first result.
node_filter
......
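For orientation, a minimal sketch of constructing such a rewriter, not part of this commit; the identity relation is hypothetical and only illustrates the calling convention, and the `aesara.graph.kanren` module path is assumed from context:

from kanren import eq

from aesara.graph.kanren import KanrenRelationSub


def identity_relation(in_expr, out_expr):
    # A trivial goal relating the matched term to a replacement; real
    # relations would encode actual rewrite rules (e.g. via facts).
    return eq(in_expr, out_expr)


# ``KanrenRelationSub`` wraps the relation as a ``NodeRewriter``.
identity_sub = KanrenRelationSub(identity_relation)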
......@@ -17,7 +17,7 @@ from collections import UserList, defaultdict, deque
from collections.abc import Iterable
from functools import _compose_mro, partial, reduce # type: ignore
from itertools import chain
from typing import Callable, Dict, List, Optional, Sequence, Tuple, Union
from typing import Callable, Dict, List, Optional, Sequence, Tuple, Union, cast
from typing_extensions import Literal
......@@ -156,15 +156,20 @@ class NodeRewriter(Rewriter):
@abc.abstractmethod
def transform(
self, fgraph: FunctionGraph, node: Apply, *args, **kwargs
) -> Union[bool, List[Variable], Dict[Variable, Variable]]:
r"""Transform a subgraph whose output is `node`.
) -> Union[
bool,
Sequence[Variable],
Dict[Union[Variable, Literal["remove"]], Union[Variable, Sequence[Variable]]],
]:
r"""Rewrite the sub-graph given by `node`.
Subclasses should implement this function so that it returns one of the
following:
- ``False`` to indicate that this rewrite cannot be applied to `node`
- A list of `Variable`\s to use in place of the `node`'s current outputs
- A ``dict`` mapping old `Variable`\s to new `Variable`\s
- A ``dict`` mapping old `Variable`\s to `Variable`\s, or the key
``"remove"`` mapping to a list of `Variable`\s to be removed.
Parameters
----------
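To make the new return contract concrete, here is a minimal, hedged sketch of a `transform` implementation; the class is hypothetical and the clone-based replacement only illustrates the shape of the return values:

from aesara.graph.opt import NodeRewriter


class ExampleNodeRewriter(NodeRewriter):
    """Hypothetical rewriter illustrating the documented return forms."""

    def transform(self, fgraph, node):
        if len(node.outputs) != 1:
            # ``False`` means this rewrite does not apply to ``node``.
            return False

        new_out = node.outputs[0].clone()

        # A sequence positionally replaces ``node.outputs``:
        return [new_out]

        # Equivalently, a dict maps old variables to replacements, and
        # the special ``"remove"`` key lists variables to delete:
        #     return {node.outputs[0]: new_out, "remove": []}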
......@@ -1850,10 +1855,15 @@ class NavigatorOptimizer(GraphRewriter):
if u is not None:
fgraph.remove_feature(u)
def process_node(self, fgraph, node, lopt=None):
r"""Apply `lopt` to `node`.
def process_node(
self,
fgraph: FunctionGraph,
node: Apply,
node_rewriter: Optional[NodeRewriter] = None,
):
r"""Apply `node_rewriter` to `node`.
The :meth:`lopt.transform` method will return either ``False`` or a
The :meth:`node_rewriter.transform` method will return either ``False`` or a
list of `Variable`\s that are intended to replace :attr:`node.outputs`.
If the `fgraph` accepts the replacement, then the optimization is
......@@ -1864,11 +1874,11 @@ class NavigatorOptimizer(GraphRewriter):
Parameters
----------
fgraph :
fgraph
A `FunctionGraph`.
node :
node
An `Apply` instance in `fgraph`
lopt :
node_rewriter
A `NodeRewriter` instance that may have a better idea for
how to compute node's outputs.
......@@ -1878,13 +1888,15 @@ class NavigatorOptimizer(GraphRewriter):
``True`` iff the `node`'s outputs were replaced in the `fgraph`.
"""
lopt = lopt or self.node_rewriter
node_rewriter = node_rewriter or self.node_rewriter
# TODO FIXME: This class's interface is broken
assert node_rewriter is not None
try:
replacements = lopt.transform(fgraph, node)
replacements = node_rewriter.transform(fgraph, node)
except Exception as e:
if self.failure_callback is not None:
self.failure_callback(
e, self, [(x, None) for x in node.outputs], lopt, node
e, self, [(x, None) for x in node.outputs], node_rewriter, node
)
return False
else:
......@@ -1892,25 +1904,27 @@ class NavigatorOptimizer(GraphRewriter):
if replacements is False or replacements is None:
return False
old_vars = node.outputs
remove = []
remove: List[Variable] = []
if isinstance(replacements, dict):
if "remove" in replacements:
remove = replacements.pop("remove")
old_vars = list(replacements.keys())
replacements = list(replacements.values())
remove = list(cast(Sequence[Variable], replacements.pop("remove")))
old_vars = list(cast(Sequence[Variable], replacements.keys()))
replacements = list(cast(Sequence[Variable], replacements.values()))
elif not isinstance(replacements, (tuple, list)):
raise TypeError(
f"Node rewriter {lopt} gave wrong type of replacement. "
f"Node rewriter {node_rewriter} gave wrong type of replacement. "
f"Expected list or tuple; got {replacements}"
)
if len(old_vars) != len(replacements):
raise ValueError(f"Node rewriter {lopt} gave wrong number of replacements")
raise ValueError(
f"Node rewriter {node_rewriter} gave wrong number of replacements"
)
# ``None`` in the replacements means that the variable isn't used
# and we want to remove it
for r, rnew in zip(old_vars, replacements):
if rnew is None and len(fgraph.clients[r]) > 0:
raise ValueError(
f"Node rewriter {lopt} tried to remove a variable"
f"Node rewriter {node_rewriter} tried to remove a variable"
f" that is being used: {r}"
)
# If an output would be replaced by itself, no need to perform
......@@ -1924,7 +1938,9 @@ class NavigatorOptimizer(GraphRewriter):
if len(repl_pairs) == 0:
return False
try:
fgraph.replace_all_validate_remove(repl_pairs, reason=lopt, remove=remove)
fgraph.replace_all_validate_remove( # type: ignore
repl_pairs, reason=node_rewriter, remove=remove
)
return True
except Exception as e:
# This means the replacements were rejected by the fgraph.
......@@ -1932,7 +1948,7 @@ class NavigatorOptimizer(GraphRewriter):
# This is not supposed to happen. The default failure_callback
# will print a traceback as a warning.
if self.failure_callback is not None:
self.failure_callback(e, self, repl_pairs, lopt, node)
self.failure_callback(e, self, repl_pairs, node_rewriter, node)
return False
else:
raise
......@@ -2027,7 +2043,7 @@ class TopoOptimizer(NavigatorOptimizer):
io_t,
loop_t,
callback_time,
lopt,
node_rewriter,
) = prof
print(
......@@ -2046,16 +2062,16 @@ class TopoOptimizer(NavigatorOptimizer):
print(blanc, " init io_toposort", io_t, file=stream)
print(blanc, " loop time", loop_t, file=stream)
print(blanc, " callback_time", callback_time, file=stream)
if isinstance(lopt, LocalOptGroup):
if lopt.profile:
lopt.print_profile(
if isinstance(node_rewriter, LocalOptGroup):
if node_rewriter.profile:
node_rewriter.print_profile(
stream,
(
lopt.time_opts,
lopt.process_count,
lopt.applied_true,
lopt.node_created,
lopt.profile,
node_rewriter.time_opts,
node_rewriter.process_count,
node_rewriter.applied_true,
node_rewriter.node_created,
node_rewriter.profile,
),
level=level + 1,
)
......@@ -2228,11 +2244,11 @@ class EquilibriumOptimizer(NavigatorOptimizer):
self.global_optimizers: List[GraphRewriter] = []
self.tracks_on_change_inputs = tracks_on_change_inputs
self.local_tracker = LocalOptTracker()
self.node_tracker = LocalOptTracker()
for opt in optimizers:
if isinstance(opt, NodeRewriter):
self.local_tracker.add_tracker(opt)
self.node_tracker.add_tracker(opt)
else:
assert isinstance(opt, GraphRewriter)
self.global_optimizers.append(opt)
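As the branches above show, `EquilibriumOptimizer` accepts a mix of `NodeRewriter`s and `GraphRewriter`s. A hedged sketch under that assumption, reusing rewriters that exist elsewhere in this diff:

import aesara.tensor as at
from aesara.graph.fg import FunctionGraph
from aesara.graph.opt import EquilibriumOptimizer, MergeOptimizer
from aesara.tensor.basic_opt import local_useless_fill

x = at.vector("x")
fgraph = FunctionGraph([x], [at.fill(x, x)], clone=False)

# Node rewriters and graph rewriters run in a single equilibrium pass;
# ``max_use_ratio`` bounds how often each rewrite may fire relative to
# the graph size.
eq_rewriter = EquilibriumOptimizer(
    [local_useless_fill, MergeOptimizer()],
    max_use_ratio=10,
)
eq_rewriter.optimize(fgraph)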
......@@ -2250,7 +2266,7 @@ class EquilibriumOptimizer(NavigatorOptimizer):
self.max_use_ratio = max_use_ratio
def get_node_rewriters(self):
yield from self.local_tracker.get_rewriters()
yield from self.node_tracker.get_rewriters()
def get_local_optimizers(self):
warnings.warn(
......@@ -2357,11 +2373,8 @@ class EquilibriumOptimizer(NavigatorOptimizer):
global_opt_timing.append(float(time.time() - t0))
# Apply cleanup rewrites, since graph rewrites may have made
# changes that require them
changed |= apply_cleanup(iter_cleanup_sub_profs)
# Apply node rewriters
topo_t0 = time.time()
q = deque(io_toposort(fgraph.inputs, start_from))
io_toposort_timing.append(time.time() - topo_t0)
......@@ -2390,23 +2403,25 @@ class EquilibriumOptimizer(NavigatorOptimizer):
if node not in fgraph.apply_nodes:
continue
current_node = node
for lopt in self.local_tracker.get_trackers(node.op):
for node_rewriter in self.node_tracker.get_trackers(node.op):
nb = change_tracker.nb_imported
t_opt = time.time()
lopt_change = self.process_node(fgraph, node, lopt)
time_opts[lopt] += time.time() - t_opt
if not lopt_change:
node_rewriter_change = self.process_node(
fgraph, node, node_rewriter
)
time_opts[node_rewriter] += time.time() - t_opt
if not node_rewriter_change:
continue
process_count.setdefault(lopt, 0)
process_count[lopt] += 1
global_process_count[lopt] += 1
process_count.setdefault(node_rewriter, 0)
process_count[node_rewriter] += 1
global_process_count[node_rewriter] += 1
changed = True
node_created[lopt] += change_tracker.nb_imported - nb
node_created[node_rewriter] += change_tracker.nb_imported - nb
changed |= apply_cleanup(iter_cleanup_sub_profs)
if global_process_count[lopt] > max_use:
if global_process_count[node_rewriter] > max_use:
max_use_abort = True
opt_name = getattr(lopt, "name", None) or getattr(
lopt, "__name__", ""
opt_name = getattr(node_rewriter, "name", None) or getattr(
node_rewriter, "__name__", ""
)
if node not in fgraph.apply_nodes:
# go to next node
......@@ -2494,8 +2509,10 @@ class EquilibriumOptimizer(NavigatorOptimizer):
f"{' ' * level}{self.__class__.__name__} {name} id={id(self)}", file=stream
)
if depth != 0:
for lopt in self.get_node_rewriters():
lopt.print_summary(stream, level=(level + 2), depth=(depth - 1))
for node_rewriter in self.get_node_rewriters():
node_rewriter.print_summary(
stream, level=(level + 2), depth=(depth - 1)
)
@staticmethod
def print_profile(stream, prof, level=0):
......@@ -2529,27 +2546,27 @@ class EquilibriumOptimizer(NavigatorOptimizer):
)
print(blanc, f" time io_toposort {sum(io_toposort_timing):.3f}s", file=stream)
s = sum(time_opts[o] for o in opt.get_node_rewriters())
print(blanc, f" time in local optimizers {s:.3f}s", file=stream)
print(blanc, f" time in node rewriters {s:.3f}s", file=stream)
s = sum(time_opts[o] for o in opt.global_optimizers)
print(blanc, f" time in global optimizers {s:.3f}s", file=stream)
print(blanc, f" time in graph rewriters {s:.3f}s", file=stream)
s = sum(time_opts[o] for o in opt.final_optimizers)
print(blanc, f" time in final optimizers {s:.3f}s", file=stream)
print(blanc, f" time in final rewriters {s:.3f}s", file=stream)
s = sum(time_opts[o] for o in opt.cleanup_optimizers)
print(blanc, f" time in cleanup optimizers {s:.3f}s", file=stream)
print(blanc, f" time in cleanup rewriters {s:.3f}s", file=stream)
for i in range(len(loop_timing)):
lopt = ""
loop_times = ""
if loop_process_count[i]:
d = list(
reversed(sorted(loop_process_count[i].items(), key=lambda a: a[1]))
)
lopt = " ".join([str((str(k), v)) for k, v in d[:5]])
loop_times = " ".join([str((str(k), v)) for k, v in d[:5]])
if len(d) > 5:
lopt += " ..."
loop_times += " ..."
print(
blanc,
(
f" {int(i):2d} - {loop_timing[i]:.3f}s {int(sum(loop_process_count[i].values()))} ({global_opt_timing[i]:.3f}s in global opts, "
f"{io_toposort_timing[i]:.3f}s io_toposort) - {int(nb_nodes[i])} nodes - {lopt}"
f" {int(i):2d} - {loop_timing[i]:.3f}s {int(sum(loop_process_count[i].values()))} ({global_opt_timing[i]:.3f}s in graph rewriters, "
f"{io_toposort_timing[i]:.3f}s io_toposort) - {int(nb_nodes[i])} nodes - {loop_times}"
),
file=stream,
)
......@@ -2784,8 +2801,10 @@ def check_chain(r, *chain):
return _check_chain(r, reduce(list.__iadd__, ([x, 0] for x in chain)))
def pre_greedy_node_rewriter(fgraph, optimizations, out):
"""Apply local optimizations to a graph.
def pre_greedy_node_rewriter(
fgraph: FunctionGraph, optimizations: Sequence[NodeRewriter], out: Variable
) -> Variable:
"""Apply node rewriters throughout a graph in a greedy, pre-traversal way.
This function traverses the parts of the computation graph of the
variable `out` that are not in `fgraph`. It applies
......@@ -2796,7 +2815,7 @@ def pre_greedy_node_rewriter(fgraph, optimizations, out):
This changes the nodes in a graph in-place.
Its main use is to apply locally constant folding when generating
the graph of the indices of a subtensor.
the graph of the indices of a `Subtensor`.
Changes should not be applied to nodes that are in an `fgraph`,
so we use `fgraph` to prevent that.
......@@ -2810,16 +2829,21 @@ def pre_greedy_node_rewriter(fgraph, optimizations, out):
Parameters
----------
fgraph : FunctionGraph
fgraph
The graph used to avoid/filter nodes.
optimizations : list of NodeRewriter
The list of local optimizations to apply
out : Variable
A `Variable` specifying the graph to optimize.
optimizations
A sequence of rewrites to apply.
out
The graph to optimize.
"""
def local_recursive_function(list_opt, out, optimized_vars, depth):
def local_recursive_function(
list_opt: Sequence[NodeRewriter],
out: Variable,
optimized_vars: Dict[Variable, Variable],
depth: int,
) -> Tuple[List[Variable], Dict[Variable, Variable]]:
if not getattr(out, "owner", None):
return [out], optimized_vars
node = out.owner
......@@ -2852,6 +2876,7 @@ def pre_greedy_node_rewriter(fgraph, optimizations, out):
for opt in list_opt:
ret = opt.transform(fgraph, node)
if ret is not False and ret is not None:
assert isinstance(ret, Sequence)
assert len(ret) == len(node.outputs), opt
for k, v in zip(node.outputs, ret):
optimized_vars[k] = v
......@@ -2864,7 +2889,7 @@ def pre_greedy_node_rewriter(fgraph, optimizations, out):
return results, optimized_vars
if out.owner:
out_index = out.owner.outputs.index(out)
out_index: int = out.owner.outputs.index(out)
else:
out_index = 0
......
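A hedged usage sketch of the new signature; the `constant_folding` import path is an assumption based on where this diff's other rewriters live:

import aesara.tensor as at
from aesara.graph.fg import FunctionGraph
from aesara.graph.opt import pre_greedy_node_rewriter
from aesara.tensor.basic_opt import constant_folding

x = at.vector("x")
fgraph = FunctionGraph([x], [2 * x], clone=False)

# ``out`` extends the graph beyond ``fgraph``; only the nodes outside
# of ``fgraph`` are candidates for in-place rewriting.
out = fgraph.outputs[0] + (at.constant(1) + at.constant(2))
folded_out = pre_greedy_node_rewriter(fgraph, [constant_folding], out)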
......@@ -290,55 +290,40 @@ class OptimizationQuery:
class EquilibriumDB(OptimizationDatabase):
"""
A set of potential optimizations which should be applied in an arbitrary
order until equilibrium is reached.
"""A database of rewrites that should be applied until equilibrium is reached.
Canonicalize, Stabilize, and Specialize are all equilibrium optimizations.
Parameters
----------
ignore_newtrees
If False, we will apply local opt on new node introduced during local
optimization application. This could result in less fgraph iterations,
but this doesn't mean it will be faster globally.
tracks_on_change_inputs
If True, we will re-apply local opt on nodes whose inputs
changed during local optimization application. This could
result in less fgraph iterations, but this doesn't mean it
will be faster globally.
Notes
-----
We can use `NodeRewriter` and `GraphRewriter` since `EquilibriumOptimizer`
supports both.
It is probably not a good idea to have ignore_newtrees=False and
tracks_on_change_inputs=True
It is probably not a good idea to have both ``ignore_newtrees == False``
and ``tracks_on_change_inputs == True``.
"""
def __init__(self, ignore_newtrees=True, tracks_on_change_inputs=False):
def __init__(
self, ignore_newtrees: bool = True, tracks_on_change_inputs: bool = False
):
"""
Parameters
==========
ignore_newtrees:
If False, we will apply local opt on new node introduced during local
optimization application. This could result in less fgraph iterations,
but this doesn't mean it will be faster globally.
tracks_on_change_inputs:
If True, we will re-apply local opt on nodes whose inputs
changed during local optimization application. This could
result in less fgraph iterations, but this doesn't mean it
will be faster globally.
----------
ignore_newtrees
If ``False``, apply rewrites to new nodes introduced during
rewriting.
tracks_on_change_inputs
If ``True``, re-apply rewrites on nodes with changed inputs.
"""
super().__init__()
self.ignore_newtrees = ignore_newtrees
self.tracks_on_change_inputs = tracks_on_change_inputs
self.__final__ = {}
self.__cleanup__ = {}
self.__final__: Dict[str, aesara_opt.Rewriter] = {}
self.__cleanup__: Dict[str, aesara_opt.Rewriter] = {}
def register(self, name, obj, *tags, final_opt=False, cleanup=False, **kwargs):
if final_opt and cleanup:
......
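A short, hedged sketch of registering into an `EquilibriumDB` and querying it; `local_useless_fill` is reused purely as an example rewriter:

from aesara.graph.optdb import EquilibriumDB
from aesara.tensor.basic_opt import local_useless_fill

edb = EquilibriumDB()
edb.register("useless_fill_example", local_useless_fill, "fast_run")

# Querying by tag builds an ``EquilibriumOptimizer`` over the matches.
eq_rewriter = edb.query("+fast_run")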
......@@ -6,7 +6,7 @@ import time
import traceback
from collections import defaultdict
from io import StringIO
from typing import Optional
from typing import Optional, Union
import numpy as np
......@@ -28,13 +28,15 @@ from aesara.graph.fg import FunctionGraph
from aesara.graph.op import compute_test_value, get_test_value
from aesara.graph.opt import (
GraphRewriter,
NodeRewriter,
OpRemove,
Rewriter,
check_chain,
copy_stack_trace,
in2out,
node_rewriter,
)
from aesara.graph.optdb import SequenceDB
from aesara.graph.optdb import OptimizationDatabase, SequenceDB
from aesara.graph.utils import (
InconsistencyError,
MethodNotDefined,
......@@ -193,21 +195,19 @@ class InplaceElemwiseOptimizer(GraphRewriter):
print(blanc, n, ndim[n], file=stream)
def apply(self, fgraph):
"""
Usage: InplaceElemwiseOptimizer(op).optimize(fgraph)
r"""
Attempts to replace all Broadcast ops by versions of them
that operate inplace. It operates greedily: for each Broadcast
Op that is encountered, for each output, tries each input to
see if it can operate inplace on that input. If so, makes the
change and go to the next output or Broadcast Op.
Attempts to replace all `Elemwise`\s by versions of them that operate
inplace. It operates greedily: for each `Elemwise` that is encountered,
for each output, it tries each input to see if it can operate inplace
on that input. If so, it makes the change and goes to the next output
or `Elemwise`.
Examples
--------
`x + y + z -> x += y += z`
`(x + y) * (x * y) -> (x += y) *= (x * y) or (x + y) *= (x *= y)`
x + y + z -> x += y += z
(x + y) * (x * y) -> (x += y) *= (x * y) or (x + y) *= (x *= y)
"""
# We should not validate too often as this takes too much time to
......@@ -225,7 +225,7 @@ class InplaceElemwiseOptimizer(GraphRewriter):
# Maybe Aesara should do online toposort as in
# http://code.google.com/p/acyclic
#
# The next longest optimizer is the canonizer phase.
# The next longest rewriter is the canonizer phase.
# Then I think it is the [io_?]toposort (need to validate) so check if
# the solution is also applicable there.
......@@ -429,8 +429,8 @@ class InplaceElemwiseOptimizer(GraphRewriter):
if check_each_change != 1 and not raised_warning:
print(
(
"Some inplace optimization was not "
"performed due to unexpected error:"
"Some inplace rewriting was not "
"performed due to an unexpected error:"
),
file=sys.stderr,
)
......@@ -450,8 +450,8 @@ class InplaceElemwiseOptimizer(GraphRewriter):
if not raised_warning:
print(
(
"Some inplace optimization was not "
"performed due to unexpected error"
"Some inplace rewriting was not "
"performed due to an unexpected error"
),
file=sys.stderr,
)
......@@ -478,91 +478,111 @@ compile.optdb.register(
)
def register_useless(lopt, *tags, **kwargs):
if isinstance(lopt, str):
def register_useless(
node_rewriter: Union[OptimizationDatabase, NodeRewriter, str], *tags, **kwargs
):
if isinstance(node_rewriter, str):
def register(inner_lopt):
return register_useless(inner_lopt, lopt, *tags, **kwargs)
def register(inner_rewriter: Union[OptimizationDatabase, Rewriter]):
return register_useless(inner_rewriter, node_rewriter, *tags, **kwargs)
return register
else:
name = kwargs.pop("name", None) or lopt.__name__
name = kwargs.pop("name", None) or node_rewriter.__name__
compile.mode.local_useless.register(
name, lopt, "fast_run", *tags, position="last", **kwargs
name, node_rewriter, "fast_run", *tags, position="last", **kwargs
)
return lopt
return node_rewriter
def register_canonicalize(lopt, *tags, **kwargs):
if isinstance(lopt, str):
def register_canonicalize(
node_rewriter: Union[OptimizationDatabase, NodeRewriter, str], *tags: str, **kwargs
):
if isinstance(node_rewriter, str):
def register(inner_lopt):
return register_canonicalize(inner_lopt, lopt, *tags, **kwargs)
def register(inner_rewriter: Union[OptimizationDatabase, Rewriter]):
return register_canonicalize(inner_rewriter, node_rewriter, *tags, **kwargs)
return register
else:
name = kwargs.pop("name", None) or lopt.__name__
name = kwargs.pop("name", None) or node_rewriter.__name__
compile.optdb["canonicalize"].register(
name, lopt, "fast_run", "fast_compile", *tags, **kwargs
name, node_rewriter, "fast_run", "fast_compile", *tags, **kwargs
)
return lopt
return node_rewriter
def register_stabilize(lopt, *tags, **kwargs):
if isinstance(lopt, str):
def register_stabilize(
node_rewriter: Union[OptimizationDatabase, NodeRewriter, str], *tags: str, **kwargs
):
if isinstance(node_rewriter, str):
def register(inner_lopt):
return register_stabilize(inner_lopt, lopt, *tags, **kwargs)
def register(inner_rewriter: Union[OptimizationDatabase, Rewriter]):
return register_stabilize(inner_rewriter, node_rewriter, *tags, **kwargs)
return register
else:
name = kwargs.pop("name", None) or lopt.__name__
compile.optdb["stabilize"].register(name, lopt, "fast_run", *tags, **kwargs)
return lopt
name = kwargs.pop("name", None) or node_rewriter.__name__
compile.optdb["stabilize"].register(
name, node_rewriter, "fast_run", *tags, **kwargs
)
return node_rewriter
def register_specialize(lopt, *tags, **kwargs):
if isinstance(lopt, str):
def register_specialize(
node_rewriter: Union[OptimizationDatabase, NodeRewriter, str], *tags: str, **kwargs
):
if isinstance(node_rewriter, str):
def register(inner_lopt):
return register_specialize(inner_lopt, lopt, *tags, **kwargs)
def register(inner_rewriter: Union[OptimizationDatabase, Rewriter]):
return register_specialize(inner_rewriter, node_rewriter, *tags, **kwargs)
return register
else:
name = kwargs.pop("name", None) or lopt.__name__
compile.optdb["specialize"].register(name, lopt, "fast_run", *tags, **kwargs)
return lopt
name = kwargs.pop("name", None) or node_rewriter.__name__
compile.optdb["specialize"].register(
name, node_rewriter, "fast_run", *tags, **kwargs
)
return node_rewriter
def register_uncanonicalize(lopt, *tags, **kwargs):
if isinstance(lopt, str):
def register_uncanonicalize(
node_rewriter: Union[OptimizationDatabase, NodeRewriter, str], *tags: str, **kwargs
):
if isinstance(node_rewriter, str):
def register(inner_lopt):
return register_uncanonicalize(inner_lopt, lopt, *tags, **kwargs)
def register(inner_rewriter: Union[OptimizationDatabase, Rewriter]):
return register_uncanonicalize(
inner_rewriter, node_rewriter, *tags, **kwargs
)
return register
else:
name = (kwargs and kwargs.pop("name", None)) or lopt.__name__
name = (kwargs and kwargs.pop("name", None)) or node_rewriter.__name__
compile.optdb["uncanonicalize"].register(
name, lopt, "fast_run", *tags, **kwargs
name, node_rewriter, "fast_run", *tags, **kwargs
)
return lopt
return node_rewriter
def register_specialize_device(lopt, *tags, **kwargs):
if isinstance(lopt, str):
def register_specialize_device(
node_rewriter: Union[OptimizationDatabase, Rewriter, str], *tags: str, **kwargs
):
if isinstance(node_rewriter, str):
def register(inner_lopt):
return register_specialize_device(inner_lopt, lopt, *tags, **kwargs)
def register(inner_rewriter: Union[OptimizationDatabase, Rewriter]):
return register_specialize_device(
inner_rewriter, node_rewriter, *tags, **kwargs
)
return register
else:
name = (kwargs and kwargs.pop("name", None)) or lopt.__name__
name = (kwargs and kwargs.pop("name", None)) or node_rewriter.__name__
compile.optdb["specialize_device"].register(
name, lopt, "fast_run", *tags, **kwargs
name, node_rewriter, "fast_run", *tags, **kwargs
)
return lopt
return node_rewriter
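All of these helpers follow the same pattern: called with a string, they act as a tag-adding decorator; called with a rewriter, they register it directly. A hedged sketch, assuming `register_canonicalize` is importable from this module (`aesara.tensor.basic_opt`):

from aesara.graph.opt import node_rewriter
from aesara.tensor.basic_opt import register_canonicalize
from aesara.tensor.elemwise import Elemwise


# With a string first, ``register_canonicalize`` returns a decorator
# that registers the rewriter under the extra tag.
@register_canonicalize("example_tag")
@node_rewriter([Elemwise])
def local_example_rewrite(fgraph, node):
    # A deliberately inert rewrite: ``False`` means "does not apply".
    return False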
def apply_local_dimshuffle_lift(fgraph, var):
......@@ -762,19 +782,17 @@ pprint.assign(MakeVector, MakeVectorPrinter())
class ShapeFeature(Feature):
"""Graph optimizer for removing all calls to shape().
r"""A `Feature` that tracks shape information in a graph.
This optimizer replaces all Shapes and Subtensors of Shapes with
Shape_i and MakeVector Ops.
This `Feature` aids in the replacement of all `Shape`\s and `Subtensor`\s of `Shape`\s with
`Shape_i` and `MakeVector` `Op`\s.
This optimizer has several goals:
1. to 'lift' Shapes to as close to the inputs as possible.
This `Feature` and its associated rewrites have several goals:
1. to "lift" `Shape`\s to as close to the inputs as possible,
2. to infer the shape of every node in the graph in terms of the
input shapes.
3. remove all fills ``(at.second, at.fill)`` from the graph
input shapes, and
3. remove fill `Op`\s (e.g. `Second`) from the graph.
Lifting shapes as close to the inputs as possible is important for
canonicalization because it is very bad form to have to compute
......@@ -782,7 +800,7 @@ class ShapeFeature(Feature):
of time to compute such outputs. But it is important to get rid
of these outputs as early as possible in the compilation process
because the extra computations make it appear as if many internal
graph nodes have multiple clients. Many optimizations refuse to
graph nodes have multiple clients. Many rewrites refuse to
work on nodes with multiple clients.
Lifting is done by using an `<Op>.infer_shape` function if one is
......@@ -802,7 +820,7 @@ class ShapeFeature(Feature):
input_shapes=((x0,x1), (y0,y1)) and return ((x0, y1),).
Inferring the shape of internal nodes in the graph is important
for doing size-driven optimizations. If we know how big various
for doing size-driven rewrites. If we know how big various
intermediate results will be, we can estimate the cost of many Ops
accurately, and generate c-code that is specific [e.g. unrolled]
to particular sizes.
......@@ -818,14 +836,12 @@ class ShapeFeature(Feature):
shape, either via a .tag or some similar hacking; and 2) to
add an optional In() argument to promise that inputs will
have a certain shape (or even to have certain shapes in
certain dimensions). We can't automatically infer the shape of
shared variables as they can change of shape during the
execution by default. (NOT IMPLEMENTED YET, BUT IS IN TRAC)
certain dimensions).
**Using Shape information in Optimizations**
We can't automatically infer the shapes of shared variables, as they can
change shape during execution by default.
To use this shape information in OPTIMIZATIONS, use the
To use this shape information in rewrites, use the
``shape_of`` dictionary.
For example:
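The example itself is elided by this hunk; a hedged sketch of the intended ``shape_of`` usage, with the module path assumed:

import aesara.tensor as at
from aesara.graph.fg import FunctionGraph
from aesara.tensor.basic_opt import ShapeFeature

x = at.matrix("x")
fgraph = FunctionGraph([x], [x.T], clone=False)
fgraph.attach_feature(ShapeFeature())

# ``shape_of`` maps each variable to a tuple of symbolic shape
# entries, one per dimension.
n_rows, n_cols = fgraph.shape_feature.shape_of[x]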
......@@ -888,10 +904,10 @@ class ShapeFeature(Feature):
return o_shapes
def get_shape(self, var, idx):
"""Optimization can call this to get the current shape_i
"""Rewrites can call this to get a `Shape_i`.
It is better to call this then use directly shape_of[var][idx]
as this method should update shape_of if needed.
It is better to call this than to use ``shape_of[var][idx]`` directly,
as this method updates `shape_of` if needed.
TODO: Up to now, we don't update it in all cases. Update in all cases.
"""
......@@ -977,11 +993,9 @@ class ShapeFeature(Feature):
error reporting.
"""
# unpack the s_i that the Op returned
assert s_i is not None
if s_i == 1:
# don't make the optimizer merge a zillion ones together
# by always returning the same object to represent 1
return self.lscalar_one
if isinstance(s_i, float) and int(s_i) == s_i:
s_i = int(s_i)
......@@ -1080,10 +1094,9 @@ class ShapeFeature(Feature):
else:
shape_vars.append(self.unpack(s[i], r))
assert all(
not hasattr(r.type, "broadcastable") or not r.type.broadcastable[i] or
# The two following comparison are a speed optimization
# But we never timed this speed optimization!
self.lscalar_one.equals(shape_vars[i])
not hasattr(r.type, "broadcastable")
or not r.type.broadcastable[i]
or self.lscalar_one.equals(shape_vars[i])
or self.lscalar_one.equals(extract_constant(shape_vars[i]))
for i in range(r.type.ndim)
)
......@@ -1118,9 +1131,9 @@ class ShapeFeature(Feature):
and other_r.owner.inputs == r.owner.inputs
and other_r.owner.op == r.owner.op
):
# We are doing a merge. So the 2 shapes graph will be the
# same. This is only a speed optimization to call
# ancestors() less frequently.
# We are doing a merge, so the two shape graphs will be the
# same. This is only done so that we call `ancestors` less
# frequently.
return
# Merge other_shape with r_shape, giving the priority to other_shape
......@@ -1168,10 +1181,7 @@ class ShapeFeature(Feature):
or not r.type.broadcastable[i]
and not other_r.type.broadcastable[i]
)
or
# The two following comparison are a speed optimization
# But we never timed this speed optimization!
self.lscalar_one.equals(merged_shape[i])
or self.lscalar_one.equals(merged_shape[i])
or self.lscalar_one.equals(
extract_constant(merged_shape[i], only_process_constants=True)
)
......@@ -1194,10 +1204,9 @@ class ShapeFeature(Feature):
else:
new_shape.append(s_j)
assert all(
not hasattr(r.type, "broadcastable") or not r.type.broadcastable[idx] or
# The two following comparison are a speed optimization
# But we never timed this speed optimization!
self.lscalar_one.equals(new_shape[idx])
not hasattr(r.type, "broadcastable")
or not r.type.broadcastable[idx]
or self.lscalar_one.equals(new_shape[idx])
or self.lscalar_one.equals(extract_constant(new_shape[idx]))
for idx in range(r.type.ndim)
)
......@@ -1273,7 +1282,7 @@ class ShapeFeature(Feature):
)
# Ensure shapes are in 'int64'. This is to make sure the assert
# found in the `local_useless_subtensor` optimization does not fail.
# found in the `local_useless_subtensor` rewrite does not fail.
for sh_idx, sh in enumerate(o_shapes):
if sh is None:
continue
......@@ -1444,7 +1453,7 @@ class ShapeFeature(Feature):
class ShapeOptimizer(GraphRewriter):
"""Optimizer that adds `ShapeFeature` as a feature."""
"""Rewriter that adds `ShapeFeature` as a feature."""
def add_requirements(self, fgraph):
fgraph.attach_feature(ShapeFeature())
......@@ -1454,7 +1463,7 @@ class ShapeOptimizer(GraphRewriter):
class UnShapeOptimizer(GraphRewriter):
"""Optimizer that removes `ShapeFeature` as a feature."""
"""Rewriter that removes `ShapeFeature` as a feature."""
def apply(self, fgraph):
for feature in fgraph._features:
......@@ -1528,7 +1537,7 @@ def local_elemwise_alloc(fgraph, node):
for idx, i in enumerate(node.inputs):
if i.type.broadcastable == node.outputs[0].type.broadcastable:
# Prefer an input that is not an `Alloc` nor a `DimShuffle` of an
# `Alloc`, so that all `Alloc`s can be optimized.
# `Alloc`, so that all `Alloc`s can be rewritten.
if idx not in alloc_idxs:
ref_var_idx = idx
break
......@@ -1585,7 +1594,7 @@ def local_elemwise_alloc(fgraph, node):
new_inputs[idx] = new_alloc
# If this assert is triggered, it means we are recreating an equivalent graph
# which would result in a cyclical merge optimization.
# which would result in cyclical merge rewrites.
if all(new is old for new, old in zip(new_inputs, node.inputs)):
return
......@@ -1702,9 +1711,9 @@ compile.optdb.register(
def local_useless_fill(fgraph, node):
"""fill(s,v) -> v
This optimization is only needed in FAST_COMPILE to make the code
more readable. Normally, it is done by the local_fill_to_alloc
opt.
This rewrite is only needed in FAST_COMPILE mode to make the code
more readable. Normally, it is done by the `local_fill_to_alloc`
rewrite.
"""
r, v = node.inputs
......@@ -1789,9 +1798,9 @@ def local_alloc_sink_dimshuffle(fgraph, node):
def local_alloc_empty_to_zeros(fgraph, node):
"""This convert AllocEmpty to Alloc of 0.
This help investigate NaN with NanGuardMode. Not registered by
default. To activate it, use the Aesara flag
optimizer_including=alloc_empty_to_zeros.
This helps one investigate NaNs in `NanGuardMode`. Not registered by
default. To activate it, use the Aesara flag
``optimizer_including=alloc_empty_to_zeros``.
"""
if isinstance(node.op, AllocEmpty):
return [zeros(node.inputs, dtype=node.outputs[0].dtype)]
......@@ -1811,7 +1820,6 @@ compile.optdb.register(
@node_rewriter([Shape])
def local_shape_to_shape_i(fgraph, node):
if isinstance(node.op, Shape):
# This optimization needs ShapeOpt and fgraph.shape_feature
if not hasattr(fgraph, "shape_feature"):
return
shape_feature = fgraph.shape_feature
......@@ -1850,16 +1858,18 @@ def local_track_shape_i(fgraph, node):
@node_rewriter([Elemwise])
def local_useless_elemwise(fgraph, node):
"""
eq(x, x) -> 1
neq(x, x) -> 0
mul(x) -> x
add(x) -> x
identity(x) -> x
and(x, 1) -> x (if x.dtype == 'bool')
and(x, 0) -> zeros_like(x)
or(x, 0) -> x
or(x, 1) -> ones_like(x) (if x.dtype == 'bool')
xor(x, x) -> zeros_like(x)
eq(x, x) -> 1
neq(x, x) -> 0
mul(x) -> x
add(x) -> x
identity(x) -> x
and(x, 1) -> x (if x.dtype == 'bool')
and(x, 0) -> zeros_like(x)
or(x, 0) -> x
or(x, 1) -> ones_like(x) (if x.dtype == 'bool')
xor(x, x) -> zeros_like(x)
TODO: This implementation is painfully redundant.
"""
if isinstance(node.op, Elemwise):
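A hedged demonstration of the first identity, ``eq(x, x) -> 1``, compiled in a default mode where this rewrite is registered:

import aesara
import aesara.tensor as at

x = at.vector("x")
f = aesara.function([x], at.eq(x, x))
# After rewriting, the compiled graph fills with ones instead of
# performing an elementwise comparison.
aesara.dprint(f)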
......@@ -1905,7 +1915,7 @@ def local_useless_elemwise(fgraph, node):
return [zeros_like(node.inputs[1], dtype=dtype, opt=True)]
elif node.outputs[0].dtype == "bool":
# If the output is not Boolean, it is the bitwise AND,
# and this optimization would be wrong
# and this rewrite would be wrong
return [node.inputs[1].astype(node.outputs[0].dtype)]
if isinstance(node.inputs[1], TensorConstant):
......@@ -1917,7 +1927,7 @@ def local_useless_elemwise(fgraph, node):
return [zeros_like(node.inputs[0], dtype=dtype, opt=True)]
elif node.outputs[0].dtype == "bool":
# If the output is not Boolean, it is the bitwise AND,
# and this optimization would be wrong
# and this rewrite would be wrong
return [node.inputs[0].astype(node.outputs[0].dtype)]
elif isinstance(node.op.scalar_op, aes.OR) and len(node.inputs) == 2:
......@@ -1931,7 +1941,7 @@ def local_useless_elemwise(fgraph, node):
return [node.inputs[1].astype(node.outputs[0].dtype)]
elif node.outputs[0].dtype == "bool":
# If the output is not Boolean, it is the bitwise OR,
# and this optimization would be wrong
# and this rewrite would be wrong
return [ones_like(node.inputs[1], dtype=dtype, opt=True)]
if isinstance(node.inputs[1], TensorConstant):
......@@ -1943,7 +1953,7 @@ def local_useless_elemwise(fgraph, node):
return [node.inputs[0].astype(node.outputs[0].dtype)]
elif node.outputs[0].dtype == "bool":
# If the output is not Boolean, it is the bitwise OR,
# and this optimization would be wrong
# and this rewrite would be wrong
return [ones_like(node.inputs[0], dtype=dtype, opt=True)]
elif isinstance(node.op.scalar_op, aes.XOR) and len(node.inputs) == 2:
......@@ -2081,12 +2091,11 @@ def local_remove_useless_assert(fgraph, node):
@node_rewriter([Assert])
def local_remove_all_assert(fgraph, node):
"""An optimization disabled by default that removes all asserts from
the graph.
r"""A rewrite that removes all `Assert`\s from a graph.
Notes
-----
See the :ref:`unsafe` section to know how to enable it.
See the :ref:`unsafe` section.
"""
if not isinstance(node.op, Assert):
......@@ -2346,7 +2355,7 @@ def local_join_make_vector(fgraph, node):
Join(0, make_vector1, make_vector2, ...) => Join(0, make_vector12, ...)
This in combination with the `local_join_1` optimization can make `Join`\s
This, in combination with the `local_join_1` rewrite, can make `Join`\s
completely disappear.
"""
if not isinstance(node.op, Join) or node.outputs[0].ndim != 1:
......@@ -2388,16 +2397,16 @@ def local_join_make_vector(fgraph, node):
@node_rewriter([Elemwise])
def local_useless_switch(fgraph, node):
"""
This optimization makes the following changes in the graph:
This rewrite makes the following changes in a graph:
``at.switch(cond, left, right)`` ->
``if cond is constant and cond == 0``: right
``if cond is constant and cond != 0``: left
``if left is right`` -> ``left``
at.switch(cond, left, right) ->
if cond is constant and cond == 0: right
if cond is constant and cond != 0: left
if left is right -> left
and
``at.switch(le(shape_i{id}(X), 0), 0, shape_i{id}(X))`` -> ``shape_i{id}(X)``
at.switch(le(shape_i{id}(X), 0), 0, shape_i{id}(X)) -> shape_i{id}(X)
"""
if isinstance(node.op, Elemwise) and isinstance(node.op.scalar_op, aes.Switch):
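Similarly, a hedged demonstration of the constant-condition case:

import aesara
import aesara.tensor as at

x = at.vector("x")
y = at.vector("y")
# With a constant zero condition, rewriting reduces the switch to its
# "right" branch, so the compiled function simply returns ``y``.
f = aesara.function([x, y], at.switch(at.constant(0), x, y))
aesara.dprint(f)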
......@@ -2545,7 +2554,7 @@ def local_reshape_chain(op):
# replaced the shape by one for which this cannot be guessed.
# We should try to figure out why we lost the information about this
# constant value... but in the meantime, better not apply this
# optimization.
# rewrite.
if rval.broadcastable == node.outputs[0].broadcastable:
return [rval]
else:
......@@ -2709,10 +2718,12 @@ def local_reshape_to_dimshuffle(fgraph, node):
@node_rewriter([Reshape])
def local_reshape_lift(fgraph, node):
"""
Reshape(UnaryElemwise(x)) -> UnaryElemwise(Reshape(x))
Reshape(UnaryElemwise(x)) -> UnaryElemwise(Reshape(x))
This optimization is needed by optimization
log1msigm_to_softplus to get applied when there is a reshape.
Notes
-----
This rewrite is needed by `log1msigm_to_softplus` in order to get applied
when there is a reshape.
"""
if (
......@@ -2840,10 +2851,9 @@ def local_elemwise_fusion_op(op_class, max_input_fct=lambda node: 32, maker=None
The number of dimensions is validated at call time by Aesara itself.
"""
# META TODO: PUT THESE THINGS IN TRAC, NOT TODO NOTES!!
# TODO: use broadcast flag?
# TODO: don't do this optimization as a localOptimizer.
# TODO: don't do this rewrite as a `NodeRewriter`.
# Analyze the graph in terms of elemwise subgraphs, and then
# replace each subgraph with a Composite version.
......@@ -2851,8 +2861,7 @@ def local_elemwise_fusion_op(op_class, max_input_fct=lambda node: 32, maker=None
# fit within the parameter space of 256 bytes
#
# TODO: Merge with multiple output to merge when an inputs
# have multiple clients. This can't be done with a local
# optimiser.
# have multiple clients. This can't be done with a `NodeRewriter`
# TODO: Related: Support composites with multiple outputs
......@@ -2963,7 +2972,7 @@ def local_elemwise_fusion_op(op_class, max_input_fct=lambda node: 32, maker=None
except (NotImplementedError, MethodNotDefined):
_logger.warning(
(
"Optimization Warning: "
"Rewrite warning: "
f"The Op {i.owner.op.scalar_op} does not provide a C implementation."
" As well as being potentially slow, this also disables "
"loop fusion."
......@@ -3015,10 +3024,9 @@ def local_elemwise_fusion_op(op_class, max_input_fct=lambda node: 32, maker=None
return False
if new_nb_input != len(inputs) or len(s_inputs) != len(inputs):
# TODO FIXME: This shouldn't be a generic `Exception`
raise Exception(
"""Something has gone wrong with the elemwise
fusion optimization. We skip this optimization. You can ignore this message,
your code will run correctly, but may be slower."""
"Something has gone wrong with the elemwise fusion rewrite; skipping."
)
s_new_out = node.op.scalar_op(*s_g, return_list=True)
......@@ -3034,7 +3042,7 @@ your code will run correctly, but may be slower."""
name = str(s_new_out[0].owner.op)
_logger.warning(
(
"Optimization Warning: "
"Rewrite warning: "
f"The Op {name} does not provide a C implementation."
" As well as being potentially slow, this also disables "
"loop fusion."
......@@ -3086,15 +3094,15 @@ local_elemwise_fusion = local_elemwise_fusion_op(Elemwise, elemwise_max_input_fc
class FusionOptimizer(GraphRewriter):
"""Graph optimizer that simply runs local fusion operations.
"""Graph rewriter that simply runs node fusion operations.
TODO: This is basically a `EquilibriumOptimizer`; we should just use that.
TODO: This is basically an `EquilibriumOptimizer`; we should just use that.
"""
def __init__(self, node_rewriter):
super().__init__()
self.optimizer = node_rewriter
self.node_rewriter = node_rewriter
def add_requirements(self, fgraph):
fgraph.attach_feature(ReplaceValidate())
......@@ -3118,7 +3126,7 @@ class FusionOptimizer(GraphRewriter):
for node in nodelist:
# Don't try to fuse node that have already been fused.
if node in fgraph.apply_nodes:
new_outputs = self.optimizer(fgraph, node)
new_outputs = self.node_rewriter(fgraph, node)
if new_outputs:
assert len(new_outputs) == len(node.outputs)
try:
......@@ -3174,7 +3182,7 @@ class FusionOptimizer(GraphRewriter):
if config.tensor__local_elemwise_fusion:
_logger.debug("Enabling Elemwise fusion optimizations in fast_run")
_logger.debug("Enabling Elemwise fusion rewriters in fast_run")
# Must be after gpu(48.5) and before AddDestroyHandler(49.5)
fuse_seqopt = SequenceDB()
fuse_seqopt.register(
......@@ -3194,7 +3202,7 @@ if config.tensor__local_elemwise_fusion:
position=49,
)
else:
_logger.debug("not enabling optimization fusion elemwise in fast_run")
_logger.debug("Not enabling Elemwise fusion rewriters in fast_run")
compile.optdb.register(
"elemwise_fusion",
FusionOptimizer(local_elemwise_fusion),
......@@ -3239,10 +3247,14 @@ def local_view_op(fgraph, node):
@register_specialize
@node_rewriter([Alloc])
def local_merge_alloc(fgraph, node):
# This opt takes care of several cases:
# Alloc(Alloc(m, x, 1, 1, 1), x, y, z, w) -> Alloc(m, x, y, z, w)
# Alloc(Alloc(m, y, 1, 1), x, y, z, w) -> Alloc(m, x, y, z, w)
# Alloc(Alloc(m, y1, 1, 1), x, y2, z, w) -> Alloc(m, x, assert(y1, y1==y2), z, w)
"""
This rewriter takes care of the following cases:
Alloc(Alloc(m, x, 1, 1, 1), x, y, z, w) -> Alloc(m, x, y, z, w)
Alloc(Alloc(m, y, 1, 1), x, y, z, w) -> Alloc(m, x, y, z, w)
Alloc(Alloc(m, y1, 1, 1), x, y2, z, w) -> Alloc(m, x, assert(y1, y1==y2), z, w)
"""
if not isinstance(node.op, Alloc):
return False
if not node.inputs[0].owner or not isinstance(node.inputs[0].owner.op, Alloc):
......@@ -3276,11 +3288,7 @@ def local_merge_alloc(fgraph, node):
@register_useless("fast_compile")
@node_rewriter([TopKOp])
def local_useless_topk(fgraph, node):
"""
TopKOp generates two outputs by default
This opt removes the useless ones
"""
"""Remove unused `TopKOp` outputs."""
op = node.op
if not isinstance(op, TopKOp):
return
......
......@@ -1849,16 +1849,6 @@ crossentropy_categorical_1hot = CrossentropyCategorical1Hot()
@register_specialize("fast_compile")
@optimizer
def crossentropy_to_crossentropy_with_softmax_with_bias(fgraph):
"""
This is a stabilization optimization.
Notes
-----
Not a local optimization because we are replacing outputs
from several nodes at once.
"""
def search_make_one_sub():
for node in fgraph.toposort():
if node.op == crossentropy_categorical_1hot:
......@@ -1887,18 +1877,13 @@ def crossentropy_to_crossentropy_with_softmax_with_bias(fgraph):
@optimizer
def crossentropy_to_crossentropy_with_softmax(fgraph):
"""
This is a stabilization optimization that is more general than
crossentropy_to_crossentropy_with_softmax_with_bias.
It must be executed after local_softmax_with_bias optimization in
specialize.
TODO : This is a stabilization optimization! How to make this more cleanly?
This is a stabilization rewrite that is more general than
`crossentropy_to_crossentropy_with_softmax_with_bias`.
Notes
-----
Not a local optimization because we are replacing outputs from several
nodes at once.
It must be executed after `local_softmax_with_bias` during the
specialization passes.
"""
......