提交 0ce6eceb authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Refactor old global and local optimizers references and type hints

上级 550a6e98
...@@ -11,7 +11,7 @@ from aesara.graph.unify import eval_if_etuple ...@@ -11,7 +11,7 @@ from aesara.graph.unify import eval_if_etuple
class KanrenRelationSub(NodeRewriter): class KanrenRelationSub(NodeRewriter):
r"""A local optimizer that uses `kanren` to match and replace terms. r"""A rewriter that uses `kanren` to match and replace terms.
See `kanren <https://github.com/pythological/kanren>`__ for more information See `kanren <https://github.com/pythological/kanren>`__ for more information
miniKanren and the API for constructing `kanren` goals. miniKanren and the API for constructing `kanren` goals.
...@@ -56,7 +56,7 @@ class KanrenRelationSub(NodeRewriter): ...@@ -56,7 +56,7 @@ class KanrenRelationSub(NodeRewriter):
A function that takes an input graph and an output logic variable and A function that takes an input graph and an output logic variable and
returns a `kanren` goal. returns a `kanren` goal.
results_filter results_filter
A function that takes the direct output of `kanren.run(None, ...)` A function that takes the direct output of ``kanren.run(None, ...)``
and returns a single result. The default implementation returns and returns a single result. The default implementation returns
the first result. the first result.
node_filter node_filter
......
差异被折叠。
...@@ -290,55 +290,40 @@ class OptimizationQuery: ...@@ -290,55 +290,40 @@ class OptimizationQuery:
class EquilibriumDB(OptimizationDatabase): class EquilibriumDB(OptimizationDatabase):
""" """A database of rewrites that should be applied until equilibrium is reached.
A set of potential optimizations which should be applied in an arbitrary
order until equilibrium is reached.
Canonicalize, Stabilize, and Specialize are all equilibrium optimizations. Canonicalize, Stabilize, and Specialize are all equilibrium optimizations.
Parameters
----------
ignore_newtrees
If False, we will apply local opt on new node introduced during local
optimization application. This could result in less fgraph iterations,
but this doesn't mean it will be faster globally.
tracks_on_change_inputs
If True, we will re-apply local opt on nodes whose inputs
changed during local optimization application. This could
result in less fgraph iterations, but this doesn't mean it
will be faster globally.
Notes Notes
----- -----
We can use `NodeRewriter` and `GraphRewriter` since `EquilibriumOptimizer` We can use `NodeRewriter` and `GraphRewriter` since `EquilibriumOptimizer`
supports both. supports both.
It is probably not a good idea to have ignore_newtrees=False and It is probably not a good idea to have both ``ignore_newtrees == False``
tracks_on_change_inputs=True and ``tracks_on_change_inputs == True``.
""" """
def __init__(self, ignore_newtrees=True, tracks_on_change_inputs=False): def __init__(
self, ignore_newtrees: bool = True, tracks_on_change_inputs: bool = False
):
""" """
Parameters Parameters
========== ----------
ignore_newtrees: ignore_newtrees
If False, we will apply local opt on new node introduced during local If ``False``, apply rewrites to new nodes introduced during
optimization application. This could result in less fgraph iterations, rewriting.
but this doesn't mean it will be faster globally.
tracks_on_change_inputs
tracks_on_change_inputs: If ``True``, re-apply rewrites on nodes with changed inputs.
If True, we will re-apply local opt on nodes whose inputs
changed during local optimization application. This could
result in less fgraph iterations, but this doesn't mean it
will be faster globally.
""" """
super().__init__() super().__init__()
self.ignore_newtrees = ignore_newtrees self.ignore_newtrees = ignore_newtrees
self.tracks_on_change_inputs = tracks_on_change_inputs self.tracks_on_change_inputs = tracks_on_change_inputs
self.__final__ = {} self.__final__: Dict[str, aesara_opt.Rewriter] = {}
self.__cleanup__ = {} self.__cleanup__: Dict[str, aesara_opt.Rewriter] = {}
def register(self, name, obj, *tags, final_opt=False, cleanup=False, **kwargs): def register(self, name, obj, *tags, final_opt=False, cleanup=False, **kwargs):
if final_opt and cleanup: if final_opt and cleanup:
......
差异被折叠。
...@@ -1849,16 +1849,6 @@ crossentropy_categorical_1hot = CrossentropyCategorical1Hot() ...@@ -1849,16 +1849,6 @@ crossentropy_categorical_1hot = CrossentropyCategorical1Hot()
@register_specialize("fast_compile") @register_specialize("fast_compile")
@optimizer @optimizer
def crossentropy_to_crossentropy_with_softmax_with_bias(fgraph): def crossentropy_to_crossentropy_with_softmax_with_bias(fgraph):
"""
This is a stabilization optimization.
Notes
-----
Not a local optimization because we are replacing outputs
from several nodes at once.
"""
def search_make_one_sub(): def search_make_one_sub():
for node in fgraph.toposort(): for node in fgraph.toposort():
if node.op == crossentropy_categorical_1hot: if node.op == crossentropy_categorical_1hot:
...@@ -1887,18 +1877,13 @@ def crossentropy_to_crossentropy_with_softmax_with_bias(fgraph): ...@@ -1887,18 +1877,13 @@ def crossentropy_to_crossentropy_with_softmax_with_bias(fgraph):
@optimizer @optimizer
def crossentropy_to_crossentropy_with_softmax(fgraph): def crossentropy_to_crossentropy_with_softmax(fgraph):
""" """
This is a stabilization optimization that is more general than This is a stabilization rewrite that is more general than
crossentropy_to_crossentropy_with_softmax_with_bias. `crossentropy_to_crossentropy_with_softmax_with_bias`.
It must be executed after local_softmax_with_bias optimization in
specialize.
TODO : This is a stabilization optimization! How to make this more cleanly?
Notes Notes
----- -----
Not a local optimization because we are replacing outputs from several It must be executed after `local_softmax_with_bias` during the
nodes at once. specialization passes.
""" """
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论