提交 ac213377 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Rename EquilibriumOptimizer to EquilibriumGraphRewriter

上级 e6635af8
...@@ -212,10 +212,10 @@ optdb.register( ...@@ -212,10 +212,10 @@ optdb.register(
"canonicalize_db", "canonicalize_db",
position=1, position=1,
) )
# Register in the canonizer Equilibrium as a clean up opt the merge opt. # Register in the canonizer Equilibrium as a clean-up rewrite the merge rewrite.
# Without this, as the equilibrium have ignore_newtrees=False, we # Without this, as the equilibrium have ignore_newtrees=False, we
# won't merge all nodes if it is set as a global optimizer with # won't merge all nodes if it is set as a global rewriter with
# final_opt=True. # final_rewriter=True.
# We need a new instance of MergeOptimizer to don't have its name # We need a new instance of MergeOptimizer to don't have its name
# changed by other usage of it. # changed by other usage of it.
......
...@@ -1107,7 +1107,7 @@ def add_optimizer_configvars(): ...@@ -1107,7 +1107,7 @@ def add_optimizer_configvars():
config.add( config.add(
"optdb__max_use_ratio", "optdb__max_use_ratio",
"A ratio that prevent infinite loop in EquilibriumOptimizer.", "A ratio that prevent infinite loop in EquilibriumGraphRewriter.",
FloatParam(8), FloatParam(8),
in_c_key=False, in_c_key=False,
) )
......
差异被折叠。
...@@ -31,19 +31,18 @@ class OptimizationDatabase: ...@@ -31,19 +31,18 @@ class OptimizationDatabase:
def register( def register(
self, self,
name: str, name: str,
optimizer: Union["OptimizationDatabase", OptimizersType], rewriter: Union["OptimizationDatabase", OptimizersType],
*tags: str, *tags: str,
use_db_name_as_tag=True, use_db_name_as_tag=True,
**kwargs,
): ):
"""Register a new optimizer to the database. """Register a new rewriter to the database.
Parameters Parameters
---------- ----------
name: name:
Name of the optimizer. Name of the rewriter.
opt: rewriter:
The optimizer to register. The rewriter to register.
tags: tags:
Tag name that allow to select the optimizer. Tag name that allow to select the optimizer.
use_db_name_as_tag: use_db_name_as_tag:
...@@ -58,14 +57,14 @@ class OptimizationDatabase: ...@@ -58,14 +57,14 @@ class OptimizationDatabase:
""" """
if not isinstance( if not isinstance(
optimizer, rewriter,
( (
OptimizationDatabase, OptimizationDatabase,
aesara_opt.GraphRewriter, aesara_opt.GraphRewriter,
aesara_opt.NodeRewriter, aesara_opt.NodeRewriter,
), ),
): ):
raise TypeError(f"{optimizer} is not a valid optimizer type.") raise TypeError(f"{rewriter} is not a valid optimizer type.")
if name in self.__db__: if name in self.__db__:
raise ValueError(f"The tag '{name}' is already present in the database.") raise ValueError(f"The tag '{name}' is already present in the database.")
...@@ -74,18 +73,18 @@ class OptimizationDatabase: ...@@ -74,18 +73,18 @@ class OptimizationDatabase:
if self.name is not None: if self.name is not None:
tags = tags + (self.name,) tags = tags + (self.name,)
optimizer.name = name rewriter.name = name
# This restriction is there because in many place we suppose that # This restriction is there because in many place we suppose that
# something in the OptimizationDatabase is there only once. # something in the OptimizationDatabase is there only once.
if optimizer.name in self.__db__: if rewriter.name in self.__db__:
raise ValueError( raise ValueError(
f"Tried to register {optimizer.name} again under the new name {name}. " f"Tried to register {rewriter.name} again under the new name {name}. "
"The same optimization cannot be registered multiple times in" "The same optimization cannot be registered multiple times in"
" an ``OptimizationDatabase``; use ProxyDB instead." " an ``OptimizationDatabase``; use ProxyDB instead."
) )
self.__db__[name] = OrderedSet([optimizer]) self.__db__[name] = OrderedSet([rewriter])
self._names.add(name) self._names.add(name)
self.__db__[optimizer.__class__.__name__].add(optimizer) self.__db__[rewriter.__class__.__name__].add(rewriter)
self.add_tags(name, *tags) self.add_tags(name, *tags)
def add_tags(self, name, *tags): def add_tags(self, name, *tags):
...@@ -292,11 +291,11 @@ class OptimizationQuery: ...@@ -292,11 +291,11 @@ class OptimizationQuery:
class EquilibriumDB(OptimizationDatabase): class EquilibriumDB(OptimizationDatabase):
"""A database of rewrites that should be applied until equilibrium is reached. """A database of rewrites that should be applied until equilibrium is reached.
Canonicalize, Stabilize, and Specialize are all equilibrium optimizations. Canonicalize, Stabilize, and Specialize are all equilibrium rewriters.
Notes Notes
----- -----
We can use `NodeRewriter` and `GraphRewriter` since `EquilibriumOptimizer` We can use `NodeRewriter` and `GraphRewriter` since `EquilibriumGraphRewriter`
supports both. supports both.
It is probably not a good idea to have both ``ignore_newtrees == False`` It is probably not a good idea to have both ``ignore_newtrees == False``
...@@ -322,33 +321,47 @@ class EquilibriumDB(OptimizationDatabase): ...@@ -322,33 +321,47 @@ class EquilibriumDB(OptimizationDatabase):
super().__init__() super().__init__()
self.ignore_newtrees = ignore_newtrees self.ignore_newtrees = ignore_newtrees
self.tracks_on_change_inputs = tracks_on_change_inputs self.tracks_on_change_inputs = tracks_on_change_inputs
self.__final__: Dict[str, aesara_opt.Rewriter] = {} self.__final__: Dict[str, bool] = {}
self.__cleanup__: Dict[str, aesara_opt.Rewriter] = {} self.__cleanup__: Dict[str, bool] = {}
def register(self, name, obj, *tags, final_opt=False, cleanup=False, **kwargs): def register(
if final_opt and cleanup: self,
raise ValueError("`final_opt` and `cleanup` cannot both be true.") name: str,
super().register(name, obj, *tags, **kwargs) rewriter: Union["OptimizationDatabase", OptimizersType],
self.__final__[name] = final_opt *tags: str,
final_rewriter: bool = False,
cleanup: bool = False,
**kwargs,
):
if final_rewriter and cleanup:
raise ValueError("`final_rewriter` and `cleanup` cannot both be true.")
super().register(name, rewriter, *tags, **kwargs)
self.__final__[name] = final_rewriter
self.__cleanup__[name] = cleanup self.__cleanup__[name] = cleanup
def query(self, *tags, **kwtags): def query(self, *tags, **kwtags):
_opts = super().query(*tags, **kwtags) _rewriters = super().query(*tags, **kwtags)
final_opts = [o for o in _opts if self.__final__.get(o.name, False)] final_rewriters = [o for o in _rewriters if self.__final__.get(o.name, False)]
cleanup_opts = [o for o in _opts if self.__cleanup__.get(o.name, False)] cleanup_rewriters = [
opts = [o for o in _opts if o not in final_opts and o not in cleanup_opts] o for o in _rewriters if self.__cleanup__.get(o.name, False)
if len(final_opts) == 0: ]
final_opts = None rewriters = [
if len(cleanup_opts) == 0: o
cleanup_opts = None for o in _rewriters
return aesara_opt.EquilibriumOptimizer( if o not in final_rewriters and o not in cleanup_rewriters
opts, ]
if len(final_rewriters) == 0:
final_rewriters = None
if len(cleanup_rewriters) == 0:
cleanup_rewriters = None
return aesara_opt.EquilibriumGraphRewriter(
rewriters,
max_use_ratio=config.optdb__max_use_ratio, max_use_ratio=config.optdb__max_use_ratio,
ignore_newtrees=self.ignore_newtrees, ignore_newtrees=self.ignore_newtrees,
tracks_on_change_inputs=self.tracks_on_change_inputs, tracks_on_change_inputs=self.tracks_on_change_inputs,
failure_callback=aesara_opt.NodeProcessingGraphRewriter.warn_inplace, failure_callback=aesara_opt.NodeProcessingGraphRewriter.warn_inplace,
final_optimizers=final_opts, final_rewriters=final_rewriters,
cleanup_optimizers=cleanup_opts, cleanup_rewriters=cleanup_rewriters,
) )
...@@ -372,8 +385,10 @@ class SequenceDB(OptimizationDatabase): ...@@ -372,8 +385,10 @@ class SequenceDB(OptimizationDatabase):
self.failure_callback = failure_callback self.failure_callback = failure_callback
def register(self, name, obj, *tags, **kwargs): def register(self, name, obj, *tags, **kwargs):
super().register(name, obj, *tags, **kwargs)
position = kwargs.pop("position", "last") position = kwargs.pop("position", "last")
super().register(name, obj, *tags, **kwargs)
if position == "last": if position == "last":
if len(self.__position__) == 0: if len(self.__position__) == 0:
self.__position__[name] = 0 self.__position__[name] = 0
......
...@@ -2373,7 +2373,7 @@ optdb.register( ...@@ -2373,7 +2373,7 @@ optdb.register(
position=75, position=75,
) )
scan_eqopt1.register("all_pushout_opt", scan_seqopt1, "fast_run", "scan", position=1) scan_eqopt1.register("all_pushout_opt", scan_seqopt1, "fast_run", "scan")
scan_seqopt1.register( scan_seqopt1.register(
...@@ -2419,7 +2419,7 @@ scan_seqopt1.register( ...@@ -2419,7 +2419,7 @@ scan_seqopt1.register(
scan_seqopt1.register( scan_seqopt1.register(
"scan_pushout_add", "scan_pushout_add",
# TODO: Perhaps this should be an `EquilibriumOptimizer`? # TODO: Perhaps this should be an `EquilibriumGraphRewriter`?
in2out(push_out_add_scan, ignore_newtrees=False), in2out(push_out_add_scan, ignore_newtrees=False),
"fast_run", "fast_run",
"more_mem", "more_mem",
...@@ -2434,7 +2434,6 @@ scan_eqopt2.register( ...@@ -2434,7 +2434,6 @@ scan_eqopt2.register(
in2out(basic_opt.constant_folding, ignore_newtrees=True), in2out(basic_opt.constant_folding, ignore_newtrees=True),
"fast_run", "fast_run",
"scan", "scan",
position=1,
) )
...@@ -2444,14 +2443,13 @@ scan_eqopt2.register( ...@@ -2444,14 +2443,13 @@ scan_eqopt2.register(
"remove_constants_and_unused_inputs_scan", "remove_constants_and_unused_inputs_scan",
"fast_run", "fast_run",
"scan", "scan",
position=2,
) )
# after const merge but before stabilize so that we can have identity # after const merge but before stabilize so that we can have identity
# for equivalent nodes but we still have the chance to hoist stuff out # for equivalent nodes but we still have the chance to hoist stuff out
# of the scan later. # of the scan later.
scan_eqopt2.register("scan_merge", ScanMerge(), "fast_run", "scan", position=4) scan_eqopt2.register("scan_merge", ScanMerge(), "fast_run", "scan")
# After Merge optimization # After Merge optimization
scan_eqopt2.register( scan_eqopt2.register(
...@@ -2460,7 +2458,6 @@ scan_eqopt2.register( ...@@ -2460,7 +2458,6 @@ scan_eqopt2.register(
"remove_constants_and_unused_inputs_scan", "remove_constants_and_unused_inputs_scan",
"fast_run", "fast_run",
"scan", "scan",
position=5,
) )
scan_eqopt2.register( scan_eqopt2.register(
...@@ -2468,7 +2465,6 @@ scan_eqopt2.register( ...@@ -2468,7 +2465,6 @@ scan_eqopt2.register(
in2out(scan_merge_inouts, ignore_newtrees=True), in2out(scan_merge_inouts, ignore_newtrees=True),
"fast_run", "fast_run",
"scan", "scan",
position=6,
) )
# After everything else # After everything else
...@@ -2478,5 +2474,4 @@ scan_eqopt2.register( ...@@ -2478,5 +2474,4 @@ scan_eqopt2.register(
"remove_constants_and_unused_inputs_scan", "remove_constants_and_unused_inputs_scan",
"fast_run", "fast_run",
"scan", "scan",
position=8,
) )
...@@ -2802,10 +2802,10 @@ def constant_folding(fgraph, node): ...@@ -2802,10 +2802,10 @@ def constant_folding(fgraph, node):
topo_constant_folding = in2out( topo_constant_folding = in2out(
constant_folding, ignore_newtrees=True, name="topo_constant_folding" constant_folding, ignore_newtrees=True, name="topo_constant_folding"
) )
register_canonicalize(topo_constant_folding, "fast_compile", final_opt=True) register_canonicalize(topo_constant_folding, "fast_compile", final_rewriter=True)
register_uncanonicalize(topo_constant_folding, "fast_compile", final_opt=True) register_uncanonicalize(topo_constant_folding, "fast_compile", final_rewriter=True)
register_stabilize(topo_constant_folding, "fast_compile", final_opt=True) register_stabilize(topo_constant_folding, "fast_compile", final_rewriter=True)
register_specialize(topo_constant_folding, "fast_compile", final_opt=True) register_specialize(topo_constant_folding, "fast_compile", final_rewriter=True)
def local_elemwise_fusion_op(op_class, max_input_fct=lambda node: 32, maker=None): def local_elemwise_fusion_op(op_class, max_input_fct=lambda node: 32, maker=None):
...@@ -3096,7 +3096,7 @@ local_elemwise_fusion = local_elemwise_fusion_op(Elemwise, elemwise_max_input_fc ...@@ -3096,7 +3096,7 @@ local_elemwise_fusion = local_elemwise_fusion_op(Elemwise, elemwise_max_input_fc
class FusionOptimizer(GraphRewriter): class FusionOptimizer(GraphRewriter):
"""Graph rewriter that simply runs node fusion operations. """Graph rewriter that simply runs node fusion operations.
TODO: This is basically an `EquilibriumOptimizer`; we should just use that. TODO: This is basically an `EquilibriumGraphRewriter`; we should just use that.
""" """
......
...@@ -146,7 +146,7 @@ from aesara.graph.basic import Apply, view_roots ...@@ -146,7 +146,7 @@ from aesara.graph.basic import Apply, view_roots
from aesara.graph.features import ReplacementDidNotRemoveError, ReplaceValidate from aesara.graph.features import ReplacementDidNotRemoveError, ReplaceValidate
from aesara.graph.op import Op from aesara.graph.op import Op
from aesara.graph.opt import ( from aesara.graph.opt import (
EquilibriumOptimizer, EquilibriumGraphRewriter,
GraphRewriter, GraphRewriter,
copy_stack_trace, copy_stack_trace,
in2out, in2out,
...@@ -1906,7 +1906,7 @@ blas_optdb.register( ...@@ -1906,7 +1906,7 @@ blas_optdb.register(
blas_optdb.register("gemm_optimizer", GemmOptimizer(), "fast_run", position=10) blas_optdb.register("gemm_optimizer", GemmOptimizer(), "fast_run", position=10)
blas_optdb.register( blas_optdb.register(
"local_gemm_to_gemv", "local_gemm_to_gemv",
EquilibriumOptimizer( EquilibriumGraphRewriter(
[ [
local_gemm_to_gemv, local_gemm_to_gemv,
local_gemm_to_ger, local_gemm_to_ger,
......
...@@ -444,7 +444,7 @@ The following is an example that distributes dot products across additions. ...@@ -444,7 +444,7 @@ The following is an example that distributes dot products across additions.
import aesara import aesara
import aesara.tensor as at import aesara.tensor as at
from aesara.graph.kanren import KanrenRelationSub from aesara.graph.kanren import KanrenRelationSub
from aesara.graph.opt import EquilibriumOptimizer from aesara.graph.opt import EquilibriumGraphRewriter
from aesara.graph.opt_utils import optimize_graph from aesara.graph.opt_utils import optimize_graph
from aesara.tensor.math import _dot from aesara.tensor.math import _dot
from etuples import etuple from etuples import etuple
...@@ -484,7 +484,7 @@ The following is an example that distributes dot products across additions. ...@@ -484,7 +484,7 @@ The following is an example that distributes dot products across additions.
) )
dot_distribute_opt = EquilibriumOptimizer([KanrenRelationSub(dot_distributeo)], max_use_ratio=10) dot_distribute_opt = EquilibriumGraphRewriter([KanrenRelationSub(dot_distributeo)], max_use_ratio=10)
Below, we apply `dot_distribute_opt` to a few example graphs. First we create simple test graph: Below, we apply `dot_distribute_opt` to a few example graphs. First we create simple test graph:
...@@ -531,7 +531,7 @@ relational properties. ...@@ -531,7 +531,7 @@ relational properties.
To do that, we will create another :class:`Rewriter` that simply reverses the arguments To do that, we will create another :class:`Rewriter` that simply reverses the arguments
to the relation :func:`dot_distributeo` and apply it to the distributed result in ``res``: to the relation :func:`dot_distributeo` and apply it to the distributed result in ``res``:
>>> dot_gather_opt = EquilibriumOptimizer([KanrenRelationSub(lambda x, y: dot_distributeo(y, x))], max_use_ratio=10) >>> dot_gather_opt = EquilibriumGraphRewriter([KanrenRelationSub(lambda x, y: dot_distributeo(y, x))], max_use_ratio=10)
>>> rev_res = optimize_graph(res, include=[], custom_opt=dot_gather_opt, clone=False) >>> rev_res = optimize_graph(res, include=[], custom_opt=dot_gather_opt, clone=False)
>>> print(aesara.pprint(rev_res)) >>> print(aesara.pprint(rev_res))
(A @ (x + (y + (B @ (z + w))))) (A @ (x + (y + (B @ (z + w)))))
...@@ -561,7 +561,7 @@ serve as a basis for filtering. ...@@ -561,7 +561,7 @@ serve as a basis for filtering.
The point of :obj:`optdb` is that you might want to apply many optimizations The point of :obj:`optdb` is that you might want to apply many optimizations
to a computation graph in many unique patterns. For example, you might to a computation graph in many unique patterns. For example, you might
want to do optimization X, then optimization Y, then optimization Z. And then want to do optimization X, then optimization Y, then optimization Z. And then
maybe optimization Y is an :class:`EquilibriumOptimizer` containing :class:`NodeRewriter`\s A, B maybe optimization Y is an :class:`EquilibriumGraphRewriter` containing :class:`NodeRewriter`\s A, B
and C which are applied on every node of the graph until they all fail to change and C which are applied on every node of the graph until they all fail to change
it. If some optimizations act up, we want an easy way to turn them off. Ditto if it. If some optimizations act up, we want an easy way to turn them off. Ditto if
some optimizations are very CPU-intensive and we don't want to take the time to some optimizations are very CPU-intensive and we don't want to take the time to
...@@ -599,7 +599,7 @@ optimizers they return will be put in their places. ...@@ -599,7 +599,7 @@ optimizers they return will be put in their places.
An :class:`EquilibriumDB` contains :class:`NodeRewriter` or :class:`OptimizationDatabase` objects. Each of them An :class:`EquilibriumDB` contains :class:`NodeRewriter` or :class:`OptimizationDatabase` objects. Each of them
has a name and an arbitrary number of tags. When a :class:`OptimizationQuery` is applied to has a name and an arbitrary number of tags. When a :class:`OptimizationQuery` is applied to
an :class:`EquilibriumDB`, all :class:`NodeRewriter`\s that match the query are an :class:`EquilibriumDB`, all :class:`NodeRewriter`\s that match the query are
inserted into an :class:`EquilibriumOptimizer`, which is returned. If the inserted into an :class:`EquilibriumGraphRewriter`, which is returned. If the
:class:`SequenceDB` contains :class:`OptimizationDatabase` instances, the :class:`SequenceDB` contains :class:`OptimizationDatabase` instances, the
:class:`OptimizationQuery` will be passed to them as well and the :class:`OptimizationQuery` will be passed to them as well and the
:class:`NodeRewriter`\s they return will be put in their places :class:`NodeRewriter`\s they return will be put in their places
...@@ -859,8 +859,8 @@ This will output something like this: ...@@ -859,8 +859,8 @@ This will output something like this:
0.028s for fgraph.validate() 0.028s for fgraph.validate()
0.131s for callback 0.131s for callback
time - (name, class, index) - validate time time - (name, class, index) - validate time
0.751816s - ('canonicalize', 'EquilibriumOptimizer', 4) - 0.004s 0.751816s - ('canonicalize', 'EquilibriumGraphRewriter', 4) - 0.004s
EquilibriumOptimizer canonicalize EquilibriumGraphRewriter canonicalize
time 0.751s for 14 passes time 0.751s for 14 passes
nb nodes (start, end, max) 108 81 117 nb nodes (start, end, max) 108 81 117
time io_toposort 0.029s time io_toposort 0.029s
...@@ -974,8 +974,8 @@ This will output something like this: ...@@ -974,8 +974,8 @@ This will output something like this:
init io_toposort 0.00171804428101 init io_toposort 0.00171804428101
loop time 0.000502109527588 loop time 0.000502109527588
callback_time 0.0 callback_time 0.0
0.002257s - ('local_gemm_to_gemv', 'EquilibriumOptimizer', 3) - 0.000s 0.002257s - ('local_gemm_to_gemv', 'EquilibriumGraphRewriter', 3) - 0.000s
EquilibriumOptimizer local_gemm_to_gemv EquilibriumGraphRewriter local_gemm_to_gemv
time 0.002s for 1 passes time 0.002s for 1 passes
nb nodes (start, end, max) 80 80 80 nb nodes (start, end, max) 80 80 80
time io_toposort 0.001s time io_toposort 0.001s
...@@ -994,8 +994,8 @@ This will output something like this: ...@@ -994,8 +994,8 @@ This will output something like this:
init io_toposort 0.00138401985168 init io_toposort 0.00138401985168
loop time 0.000202178955078 loop time 0.000202178955078
callback_time 0.0 callback_time 0.0
0.031740s - ('specialize', 'EquilibriumOptimizer', 9) - 0.000s 0.031740s - ('specialize', 'EquilibriumGraphRewriter', 9) - 0.000s
EquilibriumOptimizer specialize EquilibriumGraphRewriter specialize
time 0.031s for 2 passes time 0.031s for 2 passes
nb nodes (start, end, max) 80 78 80 nb nodes (start, end, max) 80 78 80
time io_toposort 0.003s time io_toposort 0.003s
...@@ -1080,8 +1080,8 @@ To understand this profile here is some explanation of how optimizations work: ...@@ -1080,8 +1080,8 @@ To understand this profile here is some explanation of how optimizations work:
.. code-block:: none .. code-block:: none
0.751816s - ('canonicalize', 'EquilibriumOptimizer', 4) - 0.004s 0.751816s - ('canonicalize', 'EquilibriumGraphRewriter', 4) - 0.004s
EquilibriumOptimizer canonicalize EquilibriumGraphRewriter canonicalize
time 0.751s for 14 passes time 0.751s for 14 passes
nb nodes (start, end, max) 108 81 117 nb nodes (start, end, max) 108 81 117
time io_toposort 0.029s time io_toposort 0.029s
...@@ -1146,15 +1146,15 @@ To understand this profile here is some explanation of how optimizations work: ...@@ -1146,15 +1146,15 @@ To understand this profile here is some explanation of how optimizations work:
0.000s - local_subtensor_of_dot 0.000s - local_subtensor_of_dot
0.000s - local_subtensor_merge 0.000s - local_subtensor_merge
* ``0.751816s - ('canonicalize', 'EquilibriumOptimizer', 4) - 0.004s`` * ``0.751816s - ('canonicalize', 'EquilibriumGraphRewriter', 4) - 0.004s``
This line is from :class:`SequentialGraphRewriter`, and indicates information related This line is from :class:`SequentialGraphRewriter`, and indicates information related
to a sub-optimizer. It means that this sub-optimizer took to a sub-optimizer. It means that this sub-optimizer took
a total of .7s. Its name is ``'canonicalize'``. It is an a total of .7s. Its name is ``'canonicalize'``. It is an
:class:`EquilibriumOptimizer`. It was executed at index 4 by the :class:`EquilibriumGraphRewriter`. It was executed at index 4 by the
:class:`SequentialGraphRewriter`. It spent 0.004s in the *validate* phase. :class:`SequentialGraphRewriter`. It spent 0.004s in the *validate* phase.
* All other lines are from the profiler of the :class:`EquilibriumOptimizer`. * All other lines are from the profiler of the :class:`EquilibriumGraphRewriter`.
* An :class:`EquilibriumOptimizer` does multiple passes on the Apply nodes from * An :class:`EquilibriumGraphRewriter` does multiple passes on the Apply nodes from
the graph, trying to apply local and global optimizations. the graph, trying to apply local and global optimizations.
Conceptually, it tries to execute all global optimizations, Conceptually, it tries to execute all global optimizations,
and to apply all local optimizations on all and to apply all local optimizations on all
......
...@@ -13,7 +13,7 @@ from aesara.graph.basic import Apply ...@@ -13,7 +13,7 @@ from aesara.graph.basic import Apply
from aesara.graph.fg import FunctionGraph from aesara.graph.fg import FunctionGraph
from aesara.graph.kanren import KanrenRelationSub from aesara.graph.kanren import KanrenRelationSub
from aesara.graph.op import Op from aesara.graph.op import Op
from aesara.graph.opt import EquilibriumOptimizer from aesara.graph.opt import EquilibriumGraphRewriter
from aesara.graph.opt_utils import optimize_graph from aesara.graph.opt_utils import optimize_graph
from aesara.graph.unify import eval_if_etuple from aesara.graph.unify import eval_if_etuple
from aesara.tensor.math import Dot, _dot from aesara.tensor.math import Dot, _dot
...@@ -151,7 +151,7 @@ def test_KanrenRelationSub_dot(): ...@@ -151,7 +151,7 @@ def test_KanrenRelationSub_dot():
), ),
) )
distribute_opt = EquilibriumOptimizer( distribute_opt = EquilibriumGraphRewriter(
[KanrenRelationSub(distributes)], max_use_ratio=10 [KanrenRelationSub(distributes)], max_use_ratio=10
) )
......
...@@ -6,7 +6,7 @@ from aesara.graph.features import Feature ...@@ -6,7 +6,7 @@ from aesara.graph.features import Feature
from aesara.graph.fg import FunctionGraph from aesara.graph.fg import FunctionGraph
from aesara.graph.op import Op from aesara.graph.op import Op
from aesara.graph.opt import ( from aesara.graph.opt import (
EquilibriumOptimizer, EquilibriumGraphRewriter,
MergeOptimizer, MergeOptimizer,
OpKeyGraphRewriter, OpKeyGraphRewriter,
OpToRewriterTracker, OpToRewriterTracker,
...@@ -446,7 +446,7 @@ class TestEquilibrium: ...@@ -446,7 +446,7 @@ class TestEquilibrium:
e = op3(op4(x, y)) e = op3(op4(x, y))
g = FunctionGraph([x, y, z], [e]) g = FunctionGraph([x, y, z], [e])
# print g # print g
opt = EquilibriumOptimizer( opt = EquilibriumGraphRewriter(
[ [
PatternNodeRewriter((op1, "x", "y"), (op2, "x", "y")), PatternNodeRewriter((op1, "x", "y"), (op2, "x", "y")),
PatternNodeRewriter((op4, "x", "y"), (op1, "x", "y")), PatternNodeRewriter((op4, "x", "y"), (op1, "x", "y")),
...@@ -463,7 +463,7 @@ class TestEquilibrium: ...@@ -463,7 +463,7 @@ class TestEquilibrium:
e = op1(op1(op3(x, y))) e = op1(op1(op3(x, y)))
g = FunctionGraph([x, y, z], [e]) g = FunctionGraph([x, y, z], [e])
# print g # print g
opt = EquilibriumOptimizer( opt = EquilibriumGraphRewriter(
[ [
PatternNodeRewriter((op1, (op2, "x", "y")), (op4, "x", "y")), PatternNodeRewriter((op1, (op2, "x", "y")), (op4, "x", "y")),
PatternNodeRewriter((op3, "x", "y"), (op4, "x", "y")), PatternNodeRewriter((op3, "x", "y"), (op4, "x", "y")),
...@@ -488,7 +488,7 @@ class TestEquilibrium: ...@@ -488,7 +488,7 @@ class TestEquilibrium:
oldlevel = _logger.level oldlevel = _logger.level
_logger.setLevel(logging.CRITICAL) _logger.setLevel(logging.CRITICAL)
try: try:
opt = EquilibriumOptimizer( opt = EquilibriumGraphRewriter(
[ [
PatternNodeRewriter((op1, "x", "y"), (op2, "x", "y")), PatternNodeRewriter((op1, "x", "y"), (op2, "x", "y")),
PatternNodeRewriter((op4, "x", "y"), (op1, "x", "y")), PatternNodeRewriter((op4, "x", "y"), (op1, "x", "y")),
...@@ -600,7 +600,7 @@ def test_patternsub_values_eq_approx(out_pattern, tracks): ...@@ -600,7 +600,7 @@ def test_patternsub_values_eq_approx(out_pattern, tracks):
e = op1(x) e = op1(x)
fg = FunctionGraph([x], [e], clone=False) fg = FunctionGraph([x], [e], clone=False)
opt = EquilibriumOptimizer( opt = EquilibriumGraphRewriter(
[ [
PatternNodeRewriter( PatternNodeRewriter(
(op1, "x"), (op1, "x"),
...@@ -633,7 +633,7 @@ def test_patternsub_invalid_dtype(out_pattern): ...@@ -633,7 +633,7 @@ def test_patternsub_invalid_dtype(out_pattern):
e = op_cast_type2(x) e = op_cast_type2(x)
fg = FunctionGraph([x], [e]) fg = FunctionGraph([x], [e])
opt = EquilibriumOptimizer( opt = EquilibriumGraphRewriter(
[ [
PatternNodeRewriter( PatternNodeRewriter(
(op_cast_type2, "x"), (op_cast_type2, "x"),
......
...@@ -45,8 +45,8 @@ class TestDB: ...@@ -45,8 +45,8 @@ class TestDB:
def test_EquilibriumDB(self): def test_EquilibriumDB(self):
eq_db = EquilibriumDB() eq_db = EquilibriumDB()
with pytest.raises(ValueError, match=r"`final_opt` and.*"): with pytest.raises(ValueError, match=r"`final_rewriter` and.*"):
eq_db.register("d", TestOpt(), final_opt=True, cleanup=True) eq_db.register("d", TestOpt(), final_rewriter=True, cleanup=True)
def test_SequenceDB(self): def test_SequenceDB(self):
seq_db = SequenceDB(failure_callback=None) seq_db = SequenceDB(failure_callback=None)
......
...@@ -7,7 +7,7 @@ from aesara.compile.function import function ...@@ -7,7 +7,7 @@ from aesara.compile.function import function
from aesara.compile.mode import Mode from aesara.compile.mode import Mode
from aesara.graph.basic import Constant from aesara.graph.basic import Constant
from aesara.graph.fg import FunctionGraph from aesara.graph.fg import FunctionGraph
from aesara.graph.opt import EquilibriumOptimizer from aesara.graph.opt import EquilibriumGraphRewriter
from aesara.graph.optdb import OptimizationQuery from aesara.graph.optdb import OptimizationQuery
from aesara.tensor.elemwise import DimShuffle from aesara.tensor.elemwise import DimShuffle
from aesara.tensor.random.basic import ( from aesara.tensor.random.basic import (
...@@ -50,7 +50,7 @@ def apply_local_opt_to_rv(opt, op_fn, dist_op, dist_params, size, rng, name=None ...@@ -50,7 +50,7 @@ def apply_local_opt_to_rv(opt, op_fn, dist_op, dist_params, size, rng, name=None
p for p in dist_params_at + size_at if not isinstance(p, (slice, Constant)) p for p in dist_params_at + size_at if not isinstance(p, (slice, Constant))
] ]
mode = Mode("py", EquilibriumOptimizer([opt], max_use_ratio=100)) mode = Mode("py", EquilibriumGraphRewriter([opt], max_use_ratio=100))
f_opt = function( f_opt = function(
f_inputs, f_inputs,
...@@ -519,7 +519,7 @@ def test_Subtensor_lift_restrictions(): ...@@ -519,7 +519,7 @@ def test_Subtensor_lift_restrictions():
z = x - y z = x - y
fg = FunctionGraph([rng], [z], clone=False) fg = FunctionGraph([rng], [z], clone=False)
_ = EquilibriumOptimizer([local_subtensor_rv_lift], max_use_ratio=100).apply(fg) _ = EquilibriumGraphRewriter([local_subtensor_rv_lift], max_use_ratio=100).apply(fg)
subtensor_node = fg.outputs[0].owner.inputs[1].owner.inputs[0].owner subtensor_node = fg.outputs[0].owner.inputs[1].owner.inputs[0].owner
assert subtensor_node == y.owner assert subtensor_node == y.owner
...@@ -531,7 +531,7 @@ def test_Subtensor_lift_restrictions(): ...@@ -531,7 +531,7 @@ def test_Subtensor_lift_restrictions():
# We add `x` as an output to make sure that `is_rv_used_in_graph` handles # We add `x` as an output to make sure that `is_rv_used_in_graph` handles
# `"output"` "nodes" correctly. # `"output"` "nodes" correctly.
fg = FunctionGraph([rng], [z, x], clone=False) fg = FunctionGraph([rng], [z, x], clone=False)
EquilibriumOptimizer([local_subtensor_rv_lift], max_use_ratio=100).apply(fg) EquilibriumGraphRewriter([local_subtensor_rv_lift], max_use_ratio=100).apply(fg)
assert fg.outputs[0] == z assert fg.outputs[0] == z
assert fg.outputs[1] == x assert fg.outputs[1] == x
...@@ -539,7 +539,7 @@ def test_Subtensor_lift_restrictions(): ...@@ -539,7 +539,7 @@ def test_Subtensor_lift_restrictions():
# The non-`Subtensor` client doesn't depend on the RNG state, so we can # The non-`Subtensor` client doesn't depend on the RNG state, so we can
# perform the lift # perform the lift
fg = FunctionGraph([rng], [z], clone=False) fg = FunctionGraph([rng], [z], clone=False)
EquilibriumOptimizer([local_subtensor_rv_lift], max_use_ratio=100).apply(fg) EquilibriumGraphRewriter([local_subtensor_rv_lift], max_use_ratio=100).apply(fg)
rv_node = fg.outputs[0].owner.inputs[1].owner.inputs[0].owner rv_node = fg.outputs[0].owner.inputs[1].owner.inputs[0].owner
assert rv_node.op == normal assert rv_node.op == normal
...@@ -557,7 +557,9 @@ def test_Dimshuffle_lift_restrictions(): ...@@ -557,7 +557,9 @@ def test_Dimshuffle_lift_restrictions():
z = x - y z = x - y
fg = FunctionGraph([rng], [z, y], clone=False) fg = FunctionGraph([rng], [z, y], clone=False)
_ = EquilibriumOptimizer([local_dimshuffle_rv_lift], max_use_ratio=100).apply(fg) _ = EquilibriumGraphRewriter([local_dimshuffle_rv_lift], max_use_ratio=100).apply(
fg
)
dimshuffle_node = fg.outputs[0].owner.inputs[1].owner dimshuffle_node = fg.outputs[0].owner.inputs[1].owner
assert dimshuffle_node == y.owner assert dimshuffle_node == y.owner
...@@ -569,7 +571,7 @@ def test_Dimshuffle_lift_restrictions(): ...@@ -569,7 +571,7 @@ def test_Dimshuffle_lift_restrictions():
# We add `x` as an output to make sure that `is_rv_used_in_graph` handles # We add `x` as an output to make sure that `is_rv_used_in_graph` handles
# `"output"` "nodes" correctly. # `"output"` "nodes" correctly.
fg = FunctionGraph([rng], [z, x], clone=False) fg = FunctionGraph([rng], [z, x], clone=False)
EquilibriumOptimizer([local_dimshuffle_rv_lift], max_use_ratio=100).apply(fg) EquilibriumGraphRewriter([local_dimshuffle_rv_lift], max_use_ratio=100).apply(fg)
assert fg.outputs[0] == z assert fg.outputs[0] == z
assert fg.outputs[1] == x assert fg.outputs[1] == x
...@@ -577,7 +579,7 @@ def test_Dimshuffle_lift_restrictions(): ...@@ -577,7 +579,7 @@ def test_Dimshuffle_lift_restrictions():
# The non-`Dimshuffle` client doesn't depend on the RNG state, so we can # The non-`Dimshuffle` client doesn't depend on the RNG state, so we can
# perform the lift # perform the lift
fg = FunctionGraph([rng], [z], clone=False) fg = FunctionGraph([rng], [z], clone=False)
EquilibriumOptimizer([local_dimshuffle_rv_lift], max_use_ratio=100).apply(fg) EquilibriumGraphRewriter([local_dimshuffle_rv_lift], max_use_ratio=100).apply(fg)
rv_node = fg.outputs[0].owner.inputs[1].owner rv_node = fg.outputs[0].owner.inputs[1].owner
assert rv_node.op == normal assert rv_node.op == normal
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论