Commit 550a6e98 authored by Brandon T. Willard, committed by Brandon T. Willard

Rename LocalOptimizer to NodeRewriter

Parent 214ef4cf
......@@ -24,7 +24,7 @@ from aesara.graph.basic import (
from aesara.graph.fg import FunctionGraph
from aesara.graph.null_type import NullType
from aesara.graph.op import HasInnerGraph, Op
from aesara.graph.opt import in2out, local_optimizer
from aesara.graph.opt import in2out, node_rewriter
from aesara.graph.utils import MissingInputError
from aesara.tensor.basic_opt import ShapeFeature
......@@ -928,7 +928,7 @@ class OpFromGraph(Op, HasInnerGraph):
output[0] = variable
@local_optimizer([OpFromGraph])
@node_rewriter([OpFromGraph])
def inline_ofg_expansion(fgraph, node):
"""
This optimization expands the internal graph of an `OpFromGraph`.
......
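For reference, the renamed decorator is used exactly like the old `local_optimizer`: it takes a list of `Op`s (or `Op` instances) to track, and the decorated function receives the `FunctionGraph` and an `Apply` node and returns replacement outputs or `False`. A minimal sketch of the new spelling (the log/exp simplification below is illustrative only; Aesara ships its own version of it):

from aesara.graph.opt import in2out, node_rewriter
from aesara.tensor.math import exp, log


@node_rewriter([log])  # only visit applications of `log`
def local_log_exp(fgraph, node):
    (arg,) = node.inputs
    if arg.owner is not None and arg.owner.op == exp:
        # Return a list matching node.outputs, or False/None to decline.
        return [arg.owner.inputs[0]]
    return False


# `in2out` wraps one or more node rewriters into a graph-wide rewriter.
log_exp_everywhere = in2out(local_log_exp)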
......@@ -13,7 +13,7 @@ from aesara.graph.basic import (
from aesara.graph.op import Op
from aesara.graph.type import Type
from aesara.graph.fg import FunctionGraph
from aesara.graph.opt import local_optimizer, optimizer
from aesara.graph.opt import node_rewriter, optimizer
from aesara.graph.opt_utils import optimize_graph
from aesara.graph.optdb import OptimizationQuery
......
......@@ -6,11 +6,11 @@ from unification import var
from unification.variable import Var
from aesara.graph.basic import Apply, Variable
from aesara.graph.opt import LocalOptimizer
from aesara.graph.opt import NodeRewriter
from aesara.graph.unify import eval_if_etuple
class KanrenRelationSub(LocalOptimizer):
class KanrenRelationSub(NodeRewriter):
r"""A local optimizer that uses `kanren` to match and replace terms.
See `kanren <https://github.com/pythological/kanren>`__ for more information
......
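As a usage sketch of the renamed class: `KanrenRelationSub` wraps a binary `kanren` relation over the (etuplized) input and output expressions. The goal below encodes the `x + x = 2 * x` rewrite that the documentation uses as its running example; the module path `aesara.graph.kanren` and the exact goal construction are assumptions here, so treat this as illustrative rather than canonical:

from etuples import etuple
from kanren import eq, lall
from unification import var

import aesara.tensor as at
from aesara.graph.kanren import KanrenRelationSub


def add_to_mul(in_expr, out_expr):
    # Relate `x + x` on the input side to `2 * x` on the output side.
    x = var()
    return lall(
        eq(in_expr, etuple(at.add, x, x)),
        eq(out_expr, etuple(at.mul, 2, x)),
    )


add_to_mul_rewrite = KanrenRelationSub(add_to_mul)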
Diff collapsed.
......@@ -11,14 +11,14 @@ from aesara.misc.ordered_set import OrderedSet
from aesara.utils import DefaultOrderedDict
OptimizersType = Union[aesara_opt.GraphRewriter, aesara_opt.LocalOptimizer]
OptimizersType = Union[aesara_opt.GraphRewriter, aesara_opt.NodeRewriter]
class OptimizationDatabase:
r"""A class that represents a collection/database of optimizations.
These databases are used to logically organize collections of optimizers
(i.e. `GraphRewriter`\s and `LocalOptimizer`).
(i.e. `GraphRewriter`\s and `NodeRewriter`).
"""
def __init__(self):
......@@ -62,7 +62,7 @@ class OptimizationDatabase:
(
OptimizationDatabase,
aesara_opt.GraphRewriter,
aesara_opt.LocalOptimizer,
aesara_opt.NodeRewriter,
),
):
raise TypeError(f"{optimizer} is not a valid optimizer type.")
......@@ -311,7 +311,7 @@ class EquilibriumDB(OptimizationDatabase):
Notes
-----
We can use `LocalOptimizer` and `GraphRewriter` since `EquilibriumOptimizer`
We can use `NodeRewriter` and `GraphRewriter` since `EquilibriumOptimizer`
supports both.
It is probably not a good idea to have ignore_newtrees=False and
......@@ -474,24 +474,18 @@ class SequenceDB(OptimizationDatabase):
class LocalGroupDB(SequenceDB):
"""
Generate a local optimizer of type LocalOptGroup instead
of a global optimizer.
It supports the tracks, to only get applied to some Op.
"""
r"""A database that generates `NodeRewriter`\s of type `LocalOptGroup`."""
def __init__(
self,
apply_all_opts: bool = False,
profile: bool = False,
local_opt=aesara_opt.LocalOptGroup,
node_rewriter=aesara_opt.LocalOptGroup,
):
super().__init__(failure_callback=None)
self.apply_all_opts = apply_all_opts
self.profile = profile
self.local_opt = local_opt
self.node_rewriter = node_rewriter
self.__name__: str = ""
def register(self, name, obj, *tags, position="last", **kwargs):
......@@ -499,7 +493,7 @@ class LocalGroupDB(SequenceDB):
def query(self, *tags, **kwtags):
opts = list(super().query(*tags, **kwtags))
ret = self.local_opt(
ret = self.node_rewriter(
*opts, apply_all_opts=self.apply_all_opts, profile=self.profile
)
return ret
......
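A short sketch of how the renamed attribute surfaces when the database is used; `local_log_exp` stands in for any `@node_rewriter`-decorated function (e.g. the log/exp sketch earlier in this diff):

from aesara.graph.optdb import LocalGroupDB

db = LocalGroupDB()
db.register("local_log_exp", local_log_exp, "fast_run")

# `query` collects every registered rewriter matching the tags and wraps
# them in a single grouped node rewriter (a `LocalOptGroup` by default).
grouped_rewrite = db.query("+fast_run")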
......@@ -22,7 +22,7 @@ from aesara.compile import optdb
from aesara.configdefaults import config
from aesara.graph.basic import Apply, Variable, clone_replace, is_in_ancestors
from aesara.graph.op import _NoPythonOp
from aesara.graph.opt import GraphRewriter, in2out, local_optimizer
from aesara.graph.opt import GraphRewriter, in2out, node_rewriter
from aesara.graph.type import HasDataType, HasShape
from aesara.tensor.shape import Reshape, Shape, SpecifyShape, Unbroadcast
......@@ -404,7 +404,7 @@ def ifelse(
return tuple(rval)
@local_optimizer([IfElse])
@node_rewriter([IfElse])
def cond_make_inplace(fgraph, node):
op = node.op
if (
......@@ -482,7 +482,7 @@ acceptable_ops = (
)
@local_optimizer(acceptable_ops)
@node_rewriter(acceptable_ops)
def ifelse_lift_single_if_through_acceptable_ops(fgraph, main_node):
"""This optimization lifts up certain ifelse instances.
......@@ -529,7 +529,7 @@ def ifelse_lift_single_if_through_acceptable_ops(fgraph, main_node):
return nw_outs
@local_optimizer([IfElse])
@node_rewriter([IfElse])
def cond_merge_ifs_true(fgraph, node):
op = node.op
if not isinstance(op, IfElse):
......@@ -556,7 +556,7 @@ def cond_merge_ifs_true(fgraph, node):
return op(*old_ins, return_list=True)
@local_optimizer([IfElse])
@node_rewriter([IfElse])
def cond_merge_ifs_false(fgraph, node):
op = node.op
if not isinstance(op, IfElse):
......@@ -635,7 +635,7 @@ class CondMerge(GraphRewriter):
fgraph.replace_all_validate(pairs, reason="cond_merge")
@local_optimizer([IfElse])
@node_rewriter([IfElse])
def cond_remove_identical(fgraph, node):
op = node.op
......@@ -681,7 +681,7 @@ def cond_remove_identical(fgraph, node):
return rval
@local_optimizer([IfElse])
@node_rewriter([IfElse])
def cond_merge_random_op(fgraph, main_node):
if isinstance(main_node.op, IfElse):
return False
......
import logging
from aesara.graph.opt import local_optimizer
from aesara.graph.opt import node_rewriter
from aesara.tensor import basic as at
from aesara.tensor.basic_opt import (
register_canonicalize,
......@@ -20,7 +20,7 @@ logger = logging.getLogger(__name__)
@register_canonicalize
@local_optimizer([DimShuffle])
@node_rewriter([DimShuffle])
def transinv_to_invtrans(fgraph, node):
if isinstance(node.op, DimShuffle):
if node.op.new_order == (1, 0):
......@@ -32,7 +32,7 @@ def transinv_to_invtrans(fgraph, node):
@register_stabilize
@local_optimizer([Dot, Dot22])
@node_rewriter([Dot, Dot22])
def inv_as_solve(fgraph, node):
"""
This utilizes a boolean `symmetric` tag on the matrices.
......@@ -51,7 +51,7 @@ def inv_as_solve(fgraph, node):
@register_stabilize
@register_canonicalize
@local_optimizer([Solve])
@node_rewriter([Solve])
def tag_solve_triangular(fgraph, node):
"""
If a general solve() is applied to the output of a cholesky op, then
......@@ -82,7 +82,7 @@ def tag_solve_triangular(fgraph, node):
@register_canonicalize
@register_stabilize
@register_specialize
@local_optimizer([DimShuffle])
@node_rewriter([DimShuffle])
def no_transpose_symmetric(fgraph, node):
if isinstance(node.op, DimShuffle):
x = node.inputs[0]
......@@ -92,7 +92,7 @@ def no_transpose_symmetric(fgraph, node):
@register_stabilize
@local_optimizer([Solve])
@node_rewriter([Solve])
def psd_solve_with_chol(fgraph, node):
"""
This utilizes a boolean `psd` tag on matrices.
......@@ -111,7 +111,7 @@ def psd_solve_with_chol(fgraph, node):
@register_stabilize
@register_specialize
@local_optimizer([Det])
@node_rewriter([Det])
def local_det_chol(fgraph, node):
"""
If we have det(X) and there is already an L=cholesky(X)
......@@ -129,7 +129,7 @@ def local_det_chol(fgraph, node):
@register_canonicalize
@register_stabilize
@register_specialize
@local_optimizer([log])
@node_rewriter([log])
def local_log_prod_sqr(fgraph, node):
"""
This utilizes a boolean `positive` tag on matrices.
......
......@@ -25,7 +25,7 @@ from aesara.compile import optdb
from aesara.configdefaults import config
from aesara.gradient import undefined_grad
from aesara.graph.basic import Apply, Constant, Variable
from aesara.graph.opt import in2out, local_optimizer
from aesara.graph.opt import in2out, node_rewriter
from aesara.link.c.op import COp, Op
from aesara.link.c.params_type import ParamsType
from aesara.sandbox import multinomial
......@@ -1343,7 +1343,7 @@ def _check_size(size):
return at.as_tensor_variable(size, ndim=1)
@local_optimizer((mrg_uniform_base,))
@node_rewriter((mrg_uniform_base,))
def mrg_random_make_inplace(fgraph, node):
op = node.op
......
......@@ -28,7 +28,7 @@ from aesara.graph.destroyhandler import DestroyHandler
from aesara.graph.features import ReplaceValidate
from aesara.graph.fg import FunctionGraph
from aesara.graph.op import compute_test_value
from aesara.graph.opt import GraphRewriter, in2out, local_optimizer
from aesara.graph.opt import GraphRewriter, in2out, node_rewriter
from aesara.graph.optdb import EquilibriumDB, SequenceDB
from aesara.graph.type import HasShape
from aesara.graph.utils import InconsistencyError
......@@ -67,7 +67,7 @@ list_opt_slice = [
]
@local_optimizer([Scan])
@node_rewriter([Scan])
def remove_constants_and_unused_inputs_scan(fgraph, node):
"""Move constants into the inner graph, and remove unused inputs.
......@@ -192,7 +192,7 @@ def remove_constants_and_unused_inputs_scan(fgraph, node):
return False
@local_optimizer([Scan])
@node_rewriter([Scan])
def push_out_non_seq_scan(fgraph, node):
r"""Push out the variables inside the `Scan` that depend only on non-sequences.
......@@ -400,7 +400,7 @@ def push_out_non_seq_scan(fgraph, node):
return False
@local_optimizer([Scan])
@node_rewriter([Scan])
def push_out_seq_scan(fgraph, node):
r"""Push out the variables inside the `Scan` that depend only on constants and sequences.
......@@ -812,7 +812,7 @@ def add_nitsot_outputs(
return new_scan_node, {}
@local_optimizer([Scan])
@node_rewriter([Scan])
def push_out_add_scan(fgraph, node):
r"""Push `Add` operations performed at the end of the inner graph to the outside.
......@@ -1113,7 +1113,7 @@ def sanitize(x):
return at.as_tensor_variable(x)
@local_optimizer([Scan])
@node_rewriter([Scan])
def save_mem_new_scan(fgraph, node):
r"""Graph optimizer that reduces scan memory consumption.
......@@ -1950,7 +1950,7 @@ def make_equiv(lo, li):
return left, right
@local_optimizer([Scan])
@node_rewriter([Scan])
def scan_merge_inouts(fgraph, node):
"""
This optimization attempts to merge a `Scan` `Op`'s identical outer inputs as well
......@@ -2154,7 +2154,7 @@ def scan_merge_inouts(fgraph, node):
return na.outer_outputs
@local_optimizer([Scan])
@node_rewriter([Scan])
def push_out_dot1_scan(fgraph, node):
r"""
This is another optimization that attempts to detect certain patterns of
......
......@@ -4,7 +4,7 @@ import aesara
import aesara.scalar as aes
from aesara.configdefaults import config
from aesara.graph.basic import Apply
from aesara.graph.opt import PatternSub, TopoOptimizer, local_optimizer
from aesara.graph.opt import PatternSub, TopoOptimizer, node_rewriter
from aesara.link.c.op import COp, _NoPythonCOp
from aesara.misc.safe_asarray import _asarray
from aesara.sparse import basic as sparse
......@@ -32,7 +32,7 @@ _is_dense = sparse._is_dense
# This is tested in tests/test_opt.py:test_local_csm_properties_csm
@local_optimizer([csm_properties])
@node_rewriter([csm_properties])
def local_csm_properties_csm(fgraph, node):
"""
If we find csm_properties(CSM(*args)), then we can replace that with the
......@@ -51,7 +51,7 @@ register_specialize(local_csm_properties_csm)
# This is tested in tests/test_basic.py:test_remove0
@local_optimizer([sparse.Remove0])
@node_rewriter([sparse.Remove0])
def local_inplace_remove0(fgraph, node):
"""
Optimization to insert inplace versions of Remove0.
......@@ -188,7 +188,7 @@ class AddSD_ccode(_NoPythonCOp):
return (2,)
@local_optimizer([sparse.AddSD])
@node_rewriter([sparse.AddSD])
def local_inplace_addsd_ccode(fgraph, node):
"""
Optimization to insert inplace versions of AddSD.
......@@ -218,7 +218,7 @@ aesara.compile.optdb.register(
@register_canonicalize("fast_compile")
@register_specialize
@local_optimizer([sparse.DenseFromSparse])
@node_rewriter([sparse.DenseFromSparse])
def local_dense_from_sparse_sparse_from_dense(fgraph, node):
if isinstance(node.op, sparse.DenseFromSparse):
inp = node.inputs[0]
......@@ -226,7 +226,7 @@ def local_dense_from_sparse_sparse_from_dense(fgraph, node):
return inp.owner.inputs
@local_optimizer([sparse.AddSD])
@node_rewriter([sparse.AddSD])
def local_addsd_ccode(fgraph, node):
"""
Convert AddSD to faster AddSD_ccode.
......@@ -638,7 +638,7 @@ sd_csr = StructuredDotCSR()
# register a specialization to replace StructuredDot -> StructuredDotCSx
# This is tested in tests/test_basic.py:792
@local_optimizer([sparse._structured_dot])
@node_rewriter([sparse._structured_dot])
def local_structured_dot(fgraph, node):
if node.op == sparse._structured_dot:
a, b = node.inputs
......@@ -950,7 +950,7 @@ register_specialize(local_usmm, name="local_usmm")
# register a specialization to replace usmm_csc_dense -> usmm_csc_dense_inplace
# This is tested in tests/test_basic.py:UsmmTests
@local_optimizer([usmm_csc_dense])
@node_rewriter([usmm_csc_dense])
def local_usmm_csc_dense_inplace(fgraph, node):
if node.op == usmm_csc_dense:
return [usmm_csc_dense_inplace(*node.inputs)]
......@@ -960,7 +960,7 @@ register_specialize(local_usmm_csc_dense_inplace, "cxx_only", "inplace")
# This is tested in tests/test_basic.py:UsmmTests
@local_optimizer([usmm])
@node_rewriter([usmm])
def local_usmm_csx(fgraph, node):
"""
usmm -> usmm_csc_dense
......@@ -1120,7 +1120,7 @@ csm_grad_c = CSMGradC()
# register a specialization to replace csm_grad -> csm_grad_c
# This is tested in tests/test_opt.py:test_local_csm_grad_c
@local_optimizer([csm_grad(None)])
@node_rewriter([csm_grad(None)])
def local_csm_grad_c(fgraph, node):
"""
csm_grad(None) -> csm_grad_c
......@@ -1404,7 +1404,7 @@ mul_s_d_csr = MulSDCSR()
# register a specialization to replace MulSD -> MulSDCSX
@local_optimizer([sparse.mul_s_d])
@node_rewriter([sparse.mul_s_d])
def local_mul_s_d(fgraph, node):
if node.op == sparse.mul_s_d:
x, y = node.inputs
......@@ -1584,7 +1584,7 @@ mul_s_v_csr = MulSVCSR()
# register a specialization to replace MulSV -> MulSVCSR
@local_optimizer([sparse.mul_s_v])
@node_rewriter([sparse.mul_s_v])
def local_mul_s_v(fgraph, node):
if node.op == sparse.mul_s_v:
x, y = node.inputs
......@@ -1762,7 +1762,7 @@ structured_add_s_v_csr = StructuredAddSVCSR()
# register a specialization to replace
# structured_add_s_v -> structured_add_s_v_csr
@local_optimizer([sparse.structured_add_s_v])
@node_rewriter([sparse.structured_add_s_v])
def local_structured_add_s_v(fgraph, node):
if node.op == sparse.structured_add_s_v:
x, y = node.inputs
......@@ -2051,7 +2051,7 @@ sampling_dot_csr = SamplingDotCSR()
# register a specialization to replace SamplingDot -> SamplingDotCsr
@local_optimizer([sparse.sampling_dot])
@node_rewriter([sparse.sampling_dot])
def local_sampling_dot_csr(fgraph, node):
if not config.blas__ldflags:
# The C implementation of SamplingDotCsr relies on BLAS routines
......
Diff collapsed.
......@@ -150,7 +150,7 @@ from aesara.graph.opt import (
GraphRewriter,
copy_stack_trace,
in2out,
local_optimizer,
node_rewriter,
)
from aesara.graph.optdb import SequenceDB
from aesara.graph.utils import InconsistencyError, MethodNotDefined, TestValueError
......@@ -1733,7 +1733,7 @@ class Dot22(GemmRelated):
_dot22 = Dot22()
@local_optimizer([Dot])
@node_rewriter([Dot])
def local_dot_to_dot22(fgraph, node):
# This works for tensor.outer too because basic.outer is a macro that
# produces a dot(dimshuffle,dimshuffle) of form 4 below
......@@ -1766,7 +1766,7 @@ def local_dot_to_dot22(fgraph, node):
_logger.info(f"Not optimizing dot with inputs {x} {y} {x.type} {y.type}")
@local_optimizer([gemm_no_inplace], inplace=True)
@node_rewriter([gemm_no_inplace], inplace=True)
def local_inplace_gemm(fgraph, node):
if node.op == gemm_no_inplace:
new_out = [gemm_inplace(*node.inputs)]
......@@ -1774,7 +1774,7 @@ def local_inplace_gemm(fgraph, node):
return new_out
@local_optimizer([gemv_no_inplace], inplace=True)
@node_rewriter([gemv_no_inplace], inplace=True)
def local_inplace_gemv(fgraph, node):
if node.op == gemv_no_inplace:
new_out = [gemv_inplace(*node.inputs)]
......@@ -1782,7 +1782,7 @@ def local_inplace_gemv(fgraph, node):
return new_out
@local_optimizer([ger], inplace=True)
@node_rewriter([ger], inplace=True)
def local_inplace_ger(fgraph, node):
if node.op == ger:
new_out = [ger_destructive(*node.inputs)]
......@@ -1790,7 +1790,7 @@ def local_inplace_ger(fgraph, node):
return new_out
@local_optimizer([gemm_no_inplace])
@node_rewriter([gemm_no_inplace])
def local_gemm_to_gemv(fgraph, node):
"""GEMM acting on row or column matrices -> GEMV."""
if node.op == gemm_no_inplace:
......@@ -1807,7 +1807,7 @@ def local_gemm_to_gemv(fgraph, node):
return new_out
@local_optimizer([gemm_no_inplace])
@node_rewriter([gemm_no_inplace])
def local_gemm_to_ger(fgraph, node):
"""GEMM computing an outer-product -> GER."""
if node.op == gemm_no_inplace:
......@@ -1839,7 +1839,7 @@ def local_gemm_to_ger(fgraph, node):
# TODO: delete this optimization when we have the proper dot->gemm->ger pipeline
# working
@local_optimizer([_dot22])
@node_rewriter([_dot22])
def local_dot22_to_ger_or_gemv(fgraph, node):
"""dot22 computing an outer-product -> GER."""
if node.op == _dot22:
......@@ -2033,7 +2033,7 @@ class Dot22Scalar(GemmRelated):
_dot22scalar = Dot22Scalar()
@local_optimizer([mul])
@node_rewriter([mul])
def local_dot22_to_dot22scalar(fgraph, node):
"""
Notes
......@@ -2651,7 +2651,7 @@ _batched_dot = BatchedDot()
# from opt import register_specialize, register_canonicalize
# @register_specialize
@local_optimizer([sub, add])
@node_rewriter([sub, add])
def local_print_as_we_go_along(fgraph, node):
if node.op in (sub, add):
debugprint(node)
......
......@@ -15,7 +15,7 @@ from aesara.tensor.blas import (
ger,
ger_destructive,
ldflags,
local_optimizer,
node_rewriter,
optdb,
)
......@@ -344,7 +344,7 @@ cger_inplace = CGer(True)
cger_no_inplace = CGer(False)
@local_optimizer([ger, ger_destructive])
@node_rewriter([ger, ger_destructive])
def use_c_ger(fgraph, node):
if not config.blas__ldflags:
return
......@@ -355,7 +355,7 @@ def use_c_ger(fgraph, node):
return [CGer(True)(*node.inputs)]
@local_optimizer([CGer(False)])
@node_rewriter([CGer(False)])
def make_c_ger_destructive(fgraph, node):
if isinstance(node.op, CGer) and not node.op.destructive:
return [cger_inplace(*node.inputs)]
......@@ -699,7 +699,7 @@ int main() {
check_force_gemv_init._force_init_beta = None
@local_optimizer([gemv_inplace, gemv_no_inplace])
@node_rewriter([gemv_inplace, gemv_no_inplace])
def use_c_gemv(fgraph, node):
if not config.blas__ldflags:
return
......@@ -710,7 +710,7 @@ def use_c_gemv(fgraph, node):
return [cgemv_inplace(*node.inputs)]
@local_optimizer([CGemv(inplace=False)])
@node_rewriter([CGemv(inplace=False)])
def make_c_gemv_destructive(fgraph, node):
if isinstance(node.op, CGemv) and not node.op.inplace:
inputs = list(node.inputs)
......
......@@ -11,7 +11,7 @@ from aesara.tensor.blas import (
ger,
ger_destructive,
have_fblas,
local_optimizer,
node_rewriter,
optdb,
)
......@@ -58,13 +58,13 @@ scipy_ger_no_inplace = ScipyGer(False)
scipy_ger_inplace = ScipyGer(True)
@local_optimizer([ger, ger_destructive])
@node_rewriter([ger, ger_destructive])
def use_scipy_ger(fgraph, node):
if node.op == ger:
return [scipy_ger_no_inplace(*node.inputs)]
@local_optimizer([scipy_ger_no_inplace])
@node_rewriter([scipy_ger_no_inplace])
def make_ger_destructive(fgraph, node):
if node.op == scipy_ger_no_inplace:
return [scipy_ger_inplace(*node.inputs)]
......
Diff collapsed.
......@@ -18,7 +18,7 @@ from aesara.compile import optdb
from aesara.gradient import DisconnectedType, grad_not_implemented
from aesara.graph.basic import Apply
from aesara.graph.op import Op
from aesara.graph.opt import copy_stack_trace, local_optimizer, optimizer
from aesara.graph.opt import copy_stack_trace, node_rewriter, optimizer
from aesara.link.c.op import COp
from aesara.raise_op import Assert
from aesara.scalar import UnaryScalarOp
......@@ -1046,7 +1046,7 @@ class LogSoftmax(COp):
# This is not registered in stabilize, as it causes some crossentropy
# optimizations to not be inserted.
@register_specialize("stabilize", "fast_compile")
@local_optimizer([Elemwise])
@node_rewriter([Elemwise])
def local_logsoftmax(fgraph, node):
"""
Detect Log(Softmax(x)) and replace it with LogSoftmax(x)
......@@ -1071,7 +1071,7 @@ def local_logsoftmax(fgraph, node):
# This is not registered in stabilize, as it causes some crossentropy
# optimizations to not be inserted.
@register_specialize("stabilize", "fast_compile")
@local_optimizer([SoftmaxGrad])
@node_rewriter([SoftmaxGrad])
def local_logsoftmax_grad(fgraph, node):
"""
Detect Log(Softmax(x))'s grad and replace it with LogSoftmax(x)'s grad
......@@ -1150,7 +1150,7 @@ def logsoftmax(c, axis=UNSET_AXIS):
@register_specialize("fast_compile")
@local_optimizer([softmax_legacy])
@node_rewriter([softmax_legacy])
def local_softmax_with_bias(fgraph, node):
"""
Try to turn softmax(sum_of_stuff) -> softmax_w_bias(matrix, bias).
......@@ -1954,7 +1954,7 @@ optdb.register(
@register_specialize(
"fast_compile", "local_crossentropy_to_crossentropy_with_softmax_grad"
) # old name
@local_optimizer([softmax_grad_legacy])
@node_rewriter([softmax_grad_legacy])
def local_softmax_grad_to_crossentropy_with_softmax_grad(fgraph, node):
if node.op == softmax_grad_legacy and node.inputs[1].ndim == 2:
g_coding_dist, coding_dist = node.inputs
......@@ -1971,7 +1971,7 @@ def local_softmax_grad_to_crossentropy_with_softmax_grad(fgraph, node):
@register_specialize("fast_compile")
@local_optimizer([MaxAndArgmax])
@node_rewriter([MaxAndArgmax])
def local_argmax_pushdown(fgraph, node):
if (
isinstance(node.op, MaxAndArgmax)
......@@ -2060,7 +2060,7 @@ def _is_const(z, val, approx=False):
@register_specialize("fast_compile")
@local_optimizer([AdvancedSubtensor, log])
@node_rewriter([AdvancedSubtensor, log])
def local_advanced_indexing_crossentropy_onehot(fgraph, node):
log_op = None
sm = None
......@@ -2108,7 +2108,7 @@ def local_advanced_indexing_crossentropy_onehot(fgraph, node):
@register_specialize("fast_compile")
@local_optimizer([softmax_grad_legacy])
@node_rewriter([softmax_grad_legacy])
def local_advanced_indexing_crossentropy_onehot_grad(fgraph, node):
if not (node.op == softmax_grad_legacy and node.inputs[1].ndim == 2):
return
......@@ -2323,7 +2323,7 @@ def local_advanced_indexing_crossentropy_onehot_grad(fgraph, node):
@register_specialize("fast_compile")
@local_optimizer([softmax_with_bias])
@node_rewriter([softmax_with_bias])
def graph_merge_softmax_with_crossentropy_softmax(fgraph, node):
if node.op == softmax_with_bias:
x, b = node.inputs
......@@ -2340,7 +2340,7 @@ def graph_merge_softmax_with_crossentropy_softmax(fgraph, node):
@register_specialize
@register_stabilize
@register_canonicalize
@local_optimizer([CrossentropySoftmax1HotWithBiasDx])
@node_rewriter([CrossentropySoftmax1HotWithBiasDx])
def local_useless_crossentropy_softmax_1hot_with_bias_dx_alloc(fgraph, node):
"""
Replace a CrossentropySoftmax1HotWithBiasDx op, whose incoming gradient is
......
......@@ -4,7 +4,7 @@ import aesara
from aesara.configdefaults import config
from aesara.graph.basic import Apply
from aesara.graph.op import Op
from aesara.graph.opt import copy_stack_trace, local_optimizer
from aesara.graph.opt import copy_stack_trace, node_rewriter
from aesara.scalar import Composite, add, as_common_dtype, mul, sub, true_div
from aesara.tensor import basic as at
from aesara.tensor.basic import as_tensor_variable
......@@ -778,7 +778,7 @@ class AbstractBatchNormTrainGrad(Op):
output_storage[2][0] = g_wrt_bias
@local_optimizer([AbstractBatchNormTrain])
@node_rewriter([AbstractBatchNormTrain])
def local_abstract_batch_norm_train(fgraph, node):
if not isinstance(node.op, AbstractBatchNormTrain):
return None
......@@ -832,7 +832,7 @@ def local_abstract_batch_norm_train(fgraph, node):
return results
@local_optimizer([AbstractBatchNormTrainGrad])
@node_rewriter([AbstractBatchNormTrainGrad])
def local_abstract_batch_norm_train_grad(fgraph, node):
if not isinstance(node.op, AbstractBatchNormTrainGrad):
return None
......@@ -866,7 +866,7 @@ def local_abstract_batch_norm_train_grad(fgraph, node):
return results
@local_optimizer([AbstractBatchNormInference])
@node_rewriter([AbstractBatchNormInference])
def local_abstract_batch_norm_inference(fgraph, node):
if not isinstance(node.op, AbstractBatchNormInference):
return None
......
......@@ -3,7 +3,7 @@ from aesara import tensor as at
from aesara.gradient import DisconnectedType
from aesara.graph.basic import Apply
from aesara.graph.op import Op
from aesara.graph.opt import TopoOptimizer, copy_stack_trace, local_optimizer
from aesara.graph.opt import TopoOptimizer, copy_stack_trace, node_rewriter
def get_diagonal_subtensor_view(x, i0, i1):
......@@ -296,7 +296,7 @@ def conv3d(
return out_5d
@local_optimizer([DiagonalSubtensor, IncDiagonalSubtensor])
@node_rewriter([DiagonalSubtensor, IncDiagonalSubtensor])
def local_inplace_DiagonalSubtensor(fgraph, node):
"""Also work for IncDiagonalSubtensor."""
if (
......
......@@ -5,7 +5,7 @@ import aesara.tensor as at
from aesara.configdefaults import config
from aesara.gradient import grad_undefined
from aesara.graph.basic import Apply
from aesara.graph.opt import local_optimizer
from aesara.graph.opt import node_rewriter
from aesara.link.c.cmodule import GCC_compiler
from aesara.link.c.op import ExternalCOp, OpenMPOp
from aesara.tensor.basic_opt import register_canonicalize
......@@ -249,7 +249,7 @@ def ctc(activations, labels, input_lengths):
# Disable gradient computation if not needed
@register_canonicalize("fast_compile")
@local_optimizer([ConnectionistTemporalClassification])
@node_rewriter([ConnectionistTemporalClassification])
def local_ctc_no_grad(fgraph, node):
if isinstance(node.op, ConnectionistTemporalClassification):
if len(node.outputs) > 1:
......
......@@ -11,7 +11,7 @@ from aesara.graph.opt import (
TopoOptimizer,
copy_stack_trace,
in2out,
local_optimizer,
node_rewriter,
)
from aesara.tensor.basic_opt import register_specialize_device
from aesara.tensor.nnet.abstract_conv import (
......@@ -37,7 +37,7 @@ from aesara.tensor.nnet.corr3d import Corr3dMM, Corr3dMMGradInputs, Corr3dMMGrad
from aesara.tensor.type import TensorType
@local_optimizer([SparseBlockGemv], inplace=True)
@node_rewriter([SparseBlockGemv], inplace=True)
def local_inplace_sparse_block_gemv(fgraph, node):
"""
SparseBlockGemv(inplace=False) -> SparseBlockGemv(inplace=True)
......@@ -60,7 +60,7 @@ compile.optdb.register(
) # DEBUG
@local_optimizer([SparseBlockOuter], inplace=True)
@node_rewriter([SparseBlockOuter], inplace=True)
def local_inplace_sparse_block_outer(fgraph, node):
"""
SparseBlockOuter(inplace=False) -> SparseBlockOuter(inplace=True)
......@@ -85,7 +85,7 @@ compile.optdb.register(
# Conv opts
@local_optimizer([AbstractConv2d])
@node_rewriter([AbstractConv2d])
def local_abstractconv_gemm(fgraph, node):
# If config.blas__ldflags is empty, Aesara will use
# a NumPy C implementation of [sd]gemm_.
......@@ -113,7 +113,7 @@ def local_abstractconv_gemm(fgraph, node):
return [rval]
@local_optimizer([AbstractConv3d])
@node_rewriter([AbstractConv3d])
def local_abstractconv3d_gemm(fgraph, node):
# If config.blas__ldflags is empty, Aesara will use
# a NumPy C implementation of [sd]gemm_.
......@@ -139,7 +139,7 @@ def local_abstractconv3d_gemm(fgraph, node):
return [rval]
@local_optimizer([AbstractConv2d_gradWeights])
@node_rewriter([AbstractConv2d_gradWeights])
def local_abstractconv_gradweight_gemm(fgraph, node):
# If config.blas__ldflags is empty, Aesara will use
# a NumPy C implementation of [sd]gemm_.
......@@ -169,7 +169,7 @@ def local_abstractconv_gradweight_gemm(fgraph, node):
return [rval]
@local_optimizer([AbstractConv3d_gradWeights])
@node_rewriter([AbstractConv3d_gradWeights])
def local_abstractconv3d_gradweight_gemm(fgraph, node):
# If config.blas__ldflags is empty, Aesara will use
# a NumPy C implementation of [sd]gemm_.
......@@ -197,7 +197,7 @@ def local_abstractconv3d_gradweight_gemm(fgraph, node):
return [rval]
@local_optimizer([AbstractConv2d_gradInputs])
@node_rewriter([AbstractConv2d_gradInputs])
def local_abstractconv_gradinputs_gemm(fgraph, node):
# If config.blas__ldflags is empty, Aesara will use
# a NumPy C implementation of [sd]gemm_.
......@@ -227,7 +227,7 @@ def local_abstractconv_gradinputs_gemm(fgraph, node):
return [rval]
@local_optimizer([AbstractConv3d_gradInputs])
@node_rewriter([AbstractConv3d_gradInputs])
def local_abstractconv3d_gradinputs_gemm(fgraph, node):
# If config.blas__ldflags is empty, Aesara will use
# a NumPy C implementation of [sd]gemm_.
......@@ -255,7 +255,7 @@ def local_abstractconv3d_gradinputs_gemm(fgraph, node):
return [rval]
@local_optimizer([AbstractConv2d])
@node_rewriter([AbstractConv2d])
def local_conv2d_cpu(fgraph, node):
if not isinstance(node.op, AbstractConv2d) or node.inputs[0].dtype == "float16":
......@@ -287,7 +287,7 @@ def local_conv2d_cpu(fgraph, node):
return [rval]
@local_optimizer([AbstractConv2d_gradWeights])
@node_rewriter([AbstractConv2d_gradWeights])
def local_conv2d_gradweight_cpu(fgraph, node):
if (
not isinstance(node.op, AbstractConv2d_gradWeights)
......@@ -396,7 +396,7 @@ def local_conv2d_gradweight_cpu(fgraph, node):
return [res]
@local_optimizer([AbstractConv2d_gradInputs])
@node_rewriter([AbstractConv2d_gradInputs])
def local_conv2d_gradinputs_cpu(fgraph, node):
if (
not isinstance(node.op, AbstractConv2d_gradInputs)
......@@ -561,7 +561,7 @@ conv_groupopt.register(
# Verify that no AbstractConv are present in the graph
@local_optimizer(
@node_rewriter(
[
AbstractConv2d,
AbstractConv2d_gradWeights,
......
......@@ -9,7 +9,7 @@ stability.
import aesara
from aesara import printing
from aesara import scalar as aes
from aesara.graph.opt import copy_stack_trace, local_optimizer
from aesara.graph.opt import copy_stack_trace, node_rewriter
from aesara.printing import pprint
from aesara.scalar import sigmoid as scalar_sigmoid
from aesara.scalar.math import Sigmoid
......@@ -99,7 +99,7 @@ pprint.assign(ultra_fast_sigmoid, printing.FunctionPrinter(["ultra_fast_sigmoid"
# @opt.register_uncanonicalize
@local_optimizer(None)
@node_rewriter(None)
def local_ultra_fast_sigmoid(fgraph, node):
"""
When enabled, change all sigmoid to ultra_fast_sigmoid.
......@@ -159,7 +159,7 @@ def hard_sigmoid(x):
# @opt.register_uncanonicalize
@local_optimizer([sigmoid])
@node_rewriter([sigmoid])
def local_hard_sigmoid(fgraph, node):
if isinstance(node.op, Elemwise) and node.op.scalar_op == scalar_sigmoid:
out = hard_sigmoid(node.inputs[0])
......
......@@ -34,7 +34,7 @@ supposed to be canonical.
import logging
from aesara import scalar as aes
from aesara.graph.opt import copy_stack_trace, local_optimizer
from aesara.graph.opt import copy_stack_trace, node_rewriter
from aesara.tensor.basic import Alloc, alloc, constant
from aesara.tensor.basic_opt import register_uncanonicalize
from aesara.tensor.elemwise import CAReduce, DimShuffle
......@@ -47,7 +47,7 @@ _logger = logging.getLogger("aesara.tensor.opt_uncanonicalize")
@register_uncanonicalize
@local_optimizer([MaxAndArgmax])
@node_rewriter([MaxAndArgmax])
def local_max_and_argmax(fgraph, node):
"""
If we don't use the argmax, change it to a max only.
......@@ -66,7 +66,7 @@ def local_max_and_argmax(fgraph, node):
@register_uncanonicalize
@local_optimizer([neg])
@node_rewriter([neg])
def local_max_to_min(fgraph, node):
"""
Change -(max(-x)) to min.
......@@ -95,7 +95,7 @@ def local_max_to_min(fgraph, node):
@register_uncanonicalize
@local_optimizer([Alloc])
@node_rewriter([Alloc])
def local_alloc_dimshuffle(fgraph, node):
"""
If a dimshuffle is inside an alloc and only adds dimensions to the
......@@ -118,7 +118,7 @@ def local_alloc_dimshuffle(fgraph, node):
@register_uncanonicalize
@local_optimizer([Reshape])
@node_rewriter([Reshape])
def local_reshape_dimshuffle(fgraph, node):
"""
If a dimshuffle is inside a reshape and does not change the order
......@@ -147,7 +147,7 @@ def local_reshape_dimshuffle(fgraph, node):
@register_uncanonicalize
@local_optimizer([DimShuffle])
@node_rewriter([DimShuffle])
def local_dimshuffle_alloc(fgraph, node):
"""
If an alloc is inside a dimshuffle which only adds dimensions to the left,
......@@ -175,7 +175,7 @@ def local_dimshuffle_alloc(fgraph, node):
@register_uncanonicalize
@local_optimizer([DimShuffle])
@node_rewriter([DimShuffle])
def local_dimshuffle_subtensor(fgraph, node):
"""If a subtensor is inside a dimshuffle which only drop
broadcastable dimensions, scrap the dimshuffle and index the
......
from aesara.compile import optdb
from aesara.configdefaults import config
from aesara.graph.op import compute_test_value
from aesara.graph.opt import in2out, local_optimizer
from aesara.graph.opt import in2out, node_rewriter
from aesara.tensor.basic import constant, get_vector_length
from aesara.tensor.elemwise import DimShuffle
from aesara.tensor.extra_ops import broadcast_to
......@@ -39,7 +39,7 @@ def is_rv_used_in_graph(base_rv, node, fgraph):
return not all(_node_check(n, i) for n, i in fgraph.clients.get(base_rv, ()))
@local_optimizer([RandomVariable], inplace=True)
@node_rewriter([RandomVariable], inplace=True)
def random_make_inplace(fgraph, node):
op = node.op
......@@ -61,7 +61,7 @@ optdb.register(
)
@local_optimizer(tracks=None)
@node_rewriter(tracks=None)
def local_rv_size_lift(fgraph, node):
"""Lift the ``size`` parameter in a ``RandomVariable``.
......@@ -109,7 +109,7 @@ def local_rv_size_lift(fgraph, node):
return new_node.outputs
@local_optimizer([DimShuffle])
@node_rewriter([DimShuffle])
def local_dimshuffle_rv_lift(fgraph, node):
"""Lift a ``DimShuffle`` through ``RandomVariable`` inputs.
......@@ -266,7 +266,7 @@ def local_dimshuffle_rv_lift(fgraph, node):
return False
@local_optimizer([Subtensor, AdvancedSubtensor1, AdvancedSubtensor])
@node_rewriter([Subtensor, AdvancedSubtensor1, AdvancedSubtensor])
def local_subtensor_rv_lift(fgraph, node):
"""Lift a ``*Subtensor`` through ``RandomVariable`` inputs.
......
......@@ -7,7 +7,7 @@ import aesara
import aesara.scalar.basic as aes
from aesara import compile
from aesara.graph.basic import Constant, Variable
from aesara.graph.opt import TopoOptimizer, copy_stack_trace, in2out, local_optimizer
from aesara.graph.opt import TopoOptimizer, copy_stack_trace, in2out, node_rewriter
from aesara.raise_op import Assert
from aesara.tensor.basic import (
Alloc,
......@@ -202,7 +202,7 @@ def get_advsubtensor_axis(indices):
@register_specialize
@local_optimizer([AdvancedSubtensor])
@node_rewriter([AdvancedSubtensor])
def local_replace_AdvancedSubtensor(fgraph, node):
r"""
This rewrite converts expressions like ``X[..., y]`` into ``X.T[y].T``, for
......@@ -231,7 +231,7 @@ def local_replace_AdvancedSubtensor(fgraph, node):
@register_specialize
@local_optimizer([AdvancedIncSubtensor])
@node_rewriter([AdvancedIncSubtensor])
def local_AdvancedIncSubtensor_to_AdvancedIncSubtensor1(fgraph, node):
r"""Replace `AdvancedIncSubtensor`\s with `AdvancedIncSubtensor1`\s.
......@@ -268,7 +268,7 @@ def local_AdvancedIncSubtensor_to_AdvancedIncSubtensor1(fgraph, node):
@register_canonicalize
@register_stabilize
@register_specialize
@local_optimizer([Subtensor])
@node_rewriter([Subtensor])
def local_subtensor_of_dot(fgraph, node):
"""Rewrite ``at.dot(A, B)[idxs]`` into ``at.dot(A[idxs_a], B[idxs_b])``.
``idxs_a`` is the first ``A.ndim-1`` entries of ``idxs``, and ``idxs_b`` is
......@@ -326,7 +326,7 @@ def local_subtensor_of_dot(fgraph, node):
@register_useless
@register_canonicalize
@register_specialize
@local_optimizer([Subtensor])
@node_rewriter([Subtensor])
def local_useless_slice(fgraph, node):
"""
Remove Subtensor of the form X[0, :] -> X[0]
......@@ -362,7 +362,7 @@ def local_useless_slice(fgraph, node):
# fast_compile to allow opt subtensor(cast{float32}(make_vector))
@register_canonicalize("fast_compile")
@local_optimizer([Subtensor])
@node_rewriter([Subtensor])
def local_subtensor_lift(fgraph, node):
"""
unary(x)[idx] -> unary(x[idx])#any broadcast pattern.
......@@ -466,7 +466,7 @@ def local_subtensor_lift(fgraph, node):
@register_canonicalize
@register_specialize
@local_optimizer([Subtensor])
@node_rewriter([Subtensor])
def local_subtensor_merge(fgraph, node):
"""
Refactored optimization to deal with all cases of tensor merging.
......@@ -537,7 +537,7 @@ def local_subtensor_merge(fgraph, node):
@register_specialize
@register_canonicalize
@local_optimizer([Subtensor])
@node_rewriter([Subtensor])
def local_subtensor_remove_broadcastable_index(fgraph, node):
"""
Remove broadcastable dimension with index 0 or -1
......@@ -586,7 +586,7 @@ def local_subtensor_remove_broadcastable_index(fgraph, node):
@register_useless
@register_canonicalize
@register_specialize
@local_optimizer([Subtensor])
@node_rewriter([Subtensor])
def local_subtensor_of_alloc(fgraph, node):
"""
......@@ -654,7 +654,7 @@ def local_subtensor_of_alloc(fgraph, node):
@register_specialize
@register_canonicalize
@local_optimizer([Subtensor])
@node_rewriter([Subtensor])
def local_subtensor_inc_subtensor(fgraph, node):
"""
Subtensor(SetSubtensor(x, y, idx), idx) -> y
......@@ -694,7 +694,7 @@ def local_subtensor_inc_subtensor(fgraph, node):
@register_specialize
@register_canonicalize("fast_compile")
@register_useless
@local_optimizer([Subtensor, AdvancedSubtensor1])
@node_rewriter([Subtensor, AdvancedSubtensor1])
def local_subtensor_make_vector(fgraph, node):
"""Perform ``*Subtensor*`` operations on ``MakeVector`` outputs when the indices are constant.
......@@ -770,7 +770,7 @@ def local_subtensor_make_vector(fgraph, node):
@register_useless
@register_canonicalize
@register_specialize
@local_optimizer([IncSubtensor])
@node_rewriter([IncSubtensor])
def local_useless_inc_subtensor(fgraph, node):
r"""Remove redundant `IncSubtensor`\s.
......@@ -834,7 +834,7 @@ def local_useless_inc_subtensor(fgraph, node):
@register_canonicalize
@register_specialize
@local_optimizer([AdvancedIncSubtensor1])
@node_rewriter([AdvancedIncSubtensor1])
def local_set_to_inc_subtensor(fgraph, node):
r"""
AdvancedIncSubtensor1(x, x[ilist]+other, ilist, set_instead_of_inc=True) ->
......@@ -878,7 +878,7 @@ def local_set_to_inc_subtensor(fgraph, node):
@register_canonicalize
@register_specialize
@local_optimizer([Subtensor])
@node_rewriter([Subtensor])
def local_useless_subtensor(fgraph, node):
"""Remove `Subtensor` if it takes the full input."""
# This optimization needs ShapeOpt and fgraph.shape_feature
......@@ -960,7 +960,7 @@ def local_useless_subtensor(fgraph, node):
@register_canonicalize
@register_specialize
@local_optimizer([AdvancedSubtensor1])
@node_rewriter([AdvancedSubtensor1])
def local_useless_AdvancedSubtensor1(fgraph, node):
"""Remove `AdvancedSubtensor1` if it takes the full input.
......@@ -1116,7 +1116,7 @@ def merge_two_slices(fgraph, slice1, len1, slice2, len2):
@register_canonicalize
@local_optimizer([add])
@node_rewriter([add])
def local_IncSubtensor_serialize(fgraph, node):
"""
When using Subtensor, gradient graphs can be ugly.
......@@ -1216,7 +1216,7 @@ compile.optdb.register(
# gemm is the first one now, at priority 70
@local_optimizer([IncSubtensor], inplace=True)
@node_rewriter([IncSubtensor], inplace=True)
def local_inplace_setsubtensor(fgraph, node):
if isinstance(node.op, IncSubtensor) and not node.op.inplace:
dta = node.op.destroyhandler_tolerate_aliased
......@@ -1249,7 +1249,7 @@ compile.optdb.register(
)
@local_optimizer([AdvancedIncSubtensor1], inplace=True)
@node_rewriter([AdvancedIncSubtensor1], inplace=True)
def local_inplace_AdvancedIncSubtensor1(fgraph, node):
if isinstance(node.op, AdvancedIncSubtensor1) and not node.op.inplace:
new_op = node.op.clone_inplace()
......@@ -1270,7 +1270,7 @@ compile.optdb.register(
)
@local_optimizer([AdvancedIncSubtensor], inplace=True)
@node_rewriter([AdvancedIncSubtensor], inplace=True)
def local_inplace_AdvancedIncSubtensor(fgraph, node):
if isinstance(node.op, AdvancedIncSubtensor) and not node.op.inplace:
new_op = type(node.op)(
......@@ -1298,7 +1298,7 @@ compile.optdb.register(
# Register old name
@register_canonicalize("local_incsubtensor_of_allocs")
@register_stabilize("local_incsubtensor_of_allocs")
@local_optimizer([IncSubtensor, AdvancedIncSubtensor, AdvancedIncSubtensor1])
@node_rewriter([IncSubtensor, AdvancedIncSubtensor, AdvancedIncSubtensor1])
def local_incsubtensor_of_zeros(fgraph, node):
"""
IncSubtensor(x, zeros, idx) -> x
......@@ -1323,7 +1323,7 @@ def local_incsubtensor_of_zeros(fgraph, node):
@register_canonicalize
@register_specialize
@local_optimizer([IncSubtensor])
@node_rewriter([IncSubtensor])
def local_incsubtensor_of_zeros_to_setsubtensor(fgraph, node):
"""
IncSubtensor(zeros, x, ...) -> SetSubtensor(zeros, x, ...)
......@@ -1344,7 +1344,7 @@ def local_incsubtensor_of_zeros_to_setsubtensor(fgraph, node):
@register_canonicalize("local_setsubtensor_of_allocs")
@register_stabilize("local_setsubtensor_of_allocs")
@local_optimizer([IncSubtensor])
@node_rewriter([IncSubtensor])
def local_setsubtensor_of_constants(fgraph, node):
"""
SetSubtensor(x, x[idx], idx) -> x
......@@ -1379,7 +1379,7 @@ def local_setsubtensor_of_constants(fgraph, node):
@register_canonicalize
@register_specialize
@local_optimizer([AdvancedSubtensor1])
@node_rewriter([AdvancedSubtensor1])
def local_adv_sub1_adv_inc_sub1(fgraph, node):
"""Optimize the possible AdvSub1(AdvSetSub1(...), ...).
......@@ -1446,7 +1446,7 @@ def local_adv_sub1_adv_inc_sub1(fgraph, node):
@register_stabilize
@register_canonicalize
@register_useless
@local_optimizer([IncSubtensor, AdvancedIncSubtensor, AdvancedIncSubtensor1])
@node_rewriter([IncSubtensor, AdvancedIncSubtensor, AdvancedIncSubtensor1])
def local_useless_inc_subtensor_alloc(fgraph, node):
"""
Replaces an [Advanced]IncSubtensor[1], whose increment is an `alloc` of
......@@ -1552,7 +1552,7 @@ def local_useless_inc_subtensor_alloc(fgraph, node):
@register_specialize
@register_canonicalize
@local_optimizer([Subtensor])
@node_rewriter([Subtensor])
def local_subtensor_shape_constant(fgraph, node):
r"""Simplify constant `Subtensor`\s on `Shape`\s dimensions that are known.
......@@ -1606,7 +1606,7 @@ def local_subtensor_shape_constant(fgraph, node):
@register_canonicalize
@local_optimizer([Subtensor])
@node_rewriter([Subtensor])
def local_subtensor_SpecifyShape_lift(fgraph, node):
"""Lift ``specify_shape(x, s)[i_1, ..., i_n]`` to ``specify_shape(x[i1, ... , i_n], s[n:])``."""
......@@ -1640,7 +1640,7 @@ def local_subtensor_SpecifyShape_lift(fgraph, node):
@register_specialize
@local_optimizer([Join])
@node_rewriter([Join])
def local_join_subtensors(fgraph, node):
r"""Simplify contiguous :class:`Subtensor`\s inside a :class:`Join`.
......
from aesara.compile import optdb
from aesara.graph.opt import TopoOptimizer, local_optimizer
from aesara.graph.opt import TopoOptimizer, node_rewriter
from aesara.typed_list.basic import Append, Extend, Insert, Remove, Reverse
@local_optimizer([Append, Extend, Insert, Reverse, Remove], inplace=True)
@node_rewriter([Append, Extend, Insert, Reverse, Remove], inplace=True)
def typed_list_inplace_opt(fgraph, node):
if (
isinstance(node.op, (Append, Extend, Insert, Reverse, Remove))
......
......@@ -67,15 +67,15 @@ Local optimization
A local optimization is an object which defines the following methods:
.. class:: LocalOptimizer
.. class:: NodeRewriter
.. method:: transform(fgraph, node)
This method takes a :class:`FunctionGraph` and an :class:`Apply` node and
returns either ``False`` to signify that no changes are to be done or a
list of :class:`Variable`\s which matches the length of the node's ``outputs``
list. When the :class:`LocalOptimizer` is applied by a :class:`NavigatorOptimizer`, the outputs
of the node passed as argument to the :class:`LocalOptimizer` will be replaced by
list. When the :class:`NodeRewriter` is applied by a :class:`NavigatorOptimizer`, the outputs
of the node passed as argument to the :class:`NodeRewriter` will be replaced by
the list returned.
......@@ -218,10 +218,10 @@ The local version of the above code would be the following:
.. testcode::
from aesara.graph.opt import LocalOptimizer
from aesara.graph.opt import NodeRewriter
class LocalSimplify(LocalOptimizer):
class LocalSimplify(NodeRewriter):
def transform(self, fgraph, node):
if node.op == true_div:
x, y = node.inputs
......@@ -234,7 +234,7 @@ The local version of the above code would be the following:
return False
def tracks(self):
# This tells certain navigators to only apply this `LocalOptimizer`
# This tells certain navigators to only apply this `NodeRewriter`
# on these kinds of `Op`s
return [true_div]
......@@ -242,7 +242,7 @@ The local version of the above code would be the following:
In this case, the transformation is defined in the
:meth:`LocalOptimizer.transform` method, which is given an explicit
:meth:`NodeRewriter.transform` method, which is given an explicit
:class:`Apply` node on which to work. The entire graph--as a ``fgraph``--is
also provided, in case global information is needed.
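To actually apply a node rewriter like the ``LocalSimplify`` above, wrap it in a navigator such as :class:`TopoOptimizer`, which walks the graph node by node and calls :meth:`NodeRewriter.transform` on each one. A sketch, assuming ``fg`` is the :class:`FunctionGraph` built earlier on this page:

.. testcode::

    from aesara.graph.opt import TopoOptimizer

    simplify = TopoOptimizer(LocalSimplify())
    simplify.optimize(fg)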
......@@ -273,7 +273,7 @@ FunctionGraph(add(z, mul(x, true_div(z, x))))
:class:`OpSub`, :class:`OpRemove`, :class:`PatternSub`
++++++++++++++++++++++++++++++++++++++++++++++++++++++
Aesara defines some shortcuts to make :class:`LocalOptimizer`\s:
Aesara defines some shortcuts to make :class:`NodeRewriter`\s:
.. function:: OpSub(op1, op2)
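Of these, :class:`PatternSub` is the most general: it builds a node rewriter from an input pattern and an output pattern, where a nested tuple means ``(op, *args)`` and strings act as wildcards. A sketch of the classic double-negation simplification (illustrative only; Aesara registers its own equivalents):

.. testcode::

    from aesara.graph.opt import PatternSub
    from aesara.tensor.math import neg

    # Rewrite neg(neg(x)) -> x; "x" is a wildcard bound during matching.
    remove_double_neg = PatternSub((neg, (neg, "x")), "x")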
......@@ -433,7 +433,7 @@ This means that a relation that--say--represents :math:`x + x = 2 x` can be
utilized in both directions.
Currently, the local optimizer :class:`KanrenRelationSub` provides a means of
turning :mod:`kanren` relations into :class:`LocalOptimizer`\s; however,
turning :mod:`kanren` relations into :class:`NodeRewriter`\s; however,
:mod:`kanren` can always be used directly from within a custom :class:`Rewriter`, so
:class:`KanrenRelationSub` is not necessary.
......@@ -561,7 +561,7 @@ serve as a basis for filtering.
The point of :obj:`optdb` is that you might want to apply many optimizations
to a computation graph in many unique patterns. For example, you might
want to do optimization X, then optimization Y, then optimization Z. And then
maybe optimization Y is an :class:`EquilibriumOptimizer` containing :class:`LocalOptimizer`\s A, B
maybe optimization Y is an :class:`EquilibriumOptimizer` containing :class:`NodeRewriter`\s A, B
and C which are applied on every node of the graph until they all fail to change
it. If some optimizations act up, we want an easy way to turn them off. Ditto if
some optimizations are very CPU-intensive and we don't want to take the time to
......@@ -596,14 +596,14 @@ is returned. If the :class:`SequenceDB` contains :class:`OptimizationDatabase`
instances, the :class:`OptimizationQuery` will be passed to them as well and the
optimizers they return will be put in their places.
An :class:`EquilibriumDB` contains :class:`LocalOptimizer` or :class:`OptimizationDatabase` objects. Each of them
An :class:`EquilibriumDB` contains :class:`NodeRewriter` or :class:`OptimizationDatabase` objects. Each of them
has a name and an arbitrary number of tags. When a :class:`OptimizationQuery` is applied to
an :class:`EquilibriumDB`, all :class:`LocalOptimizer`\s that match the query are
an :class:`EquilibriumDB`, all :class:`NodeRewriter`\s that match the query are
inserted into an :class:`EquilibriumOptimizer`, which is returned. If the
:class:`SequenceDB` contains :class:`OptimizationDatabase` instances, the
:class:`OptimizationQuery` will be passed to them as well and the
:class:`LocalOptimizer`\s they return will be put in their places
(note that as of yet no :class:`OptimizationDatabase` can produce :class:`LocalOptimizer` objects, so this
:class:`NodeRewriter`\s they return will be put in their places
(note that as of yet no :class:`OptimizationDatabase` can produce :class:`NodeRewriter` objects, so this
is a moot point).
Aesara contains one principal :class:`OptimizationDatabase` object, :class:`optdb`, which
......@@ -697,10 +697,10 @@ already-compiled functions will see no change. The 'order' parameter
Registering a :class:`LocalOptimizer`
-------------------------------------
Registering a :class:`NodeRewriter`
-----------------------------------
:class:`LocalOptimizer`\s may be registered in two ways:
:class:`NodeRewriter`\s may be registered in two ways:
* Wrap them in a :class:`NavigatorOptimizer` and insert them like a global optimizer
(see previous section).
......
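The pattern used throughout this commit registers a rewrite into a database with a decorator stacked directly on top of ``node_rewriter``. A sketch of that pattern, with a deliberately inert rewrite body:

.. testcode::

    from aesara.graph.opt import node_rewriter
    from aesara.tensor.basic_opt import register_canonicalize
    from aesara.tensor.math import add

    @register_canonicalize
    @node_rewriter([add])
    def local_noop_example(fgraph, node):
        # A real rewrite would return replacement outputs here.
        return False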
......@@ -18,7 +18,7 @@ from aesara.configdefaults import config
from aesara.graph.basic import Apply, Variable
from aesara.graph.features import BadOptimization
from aesara.graph.op import Op
from aesara.graph.opt import local_optimizer
from aesara.graph.opt import node_rewriter
from aesara.graph.optdb import EquilibriumDB
from aesara.link.c.op import COp
from aesara.tensor.math import add, dot, log
......@@ -237,7 +237,7 @@ def test_badthunkoutput():
def test_badoptimization():
@local_optimizer([add])
@node_rewriter([add])
def insert_broken_add(fgraph, node):
if node.op == add:
return [off_by_half(*node.inputs)]
......@@ -263,7 +263,7 @@ def test_badoptimization():
def test_badoptimization_opt_err():
# This variant of test_badoptimization() replaces the working code
# with a new apply node that will raise an error.
@local_optimizer([add])
@node_rewriter([add])
def insert_bigger_b_add(fgraph, node):
if node.op == add:
inputs = list(node.inputs)
......@@ -272,7 +272,7 @@ def test_badoptimization_opt_err():
return [node.op(*inputs)]
return False
@local_optimizer([add])
@node_rewriter([add])
def insert_bad_dtype(fgraph, node):
if node.op == add:
inputs = list(node.inputs)
......@@ -326,7 +326,7 @@ def test_stochasticoptimization():
last_time_replaced = [False]
@local_optimizer([add])
@node_rewriter([add])
def insert_broken_add_sometimes(fgraph, node):
if node.op == add:
last_time_replaced[0] = not last_time_replaced[0]
......
......@@ -15,10 +15,10 @@ from aesara.graph.opt import (
PatternSub,
TopoOptimizer,
in2out,
local_optimizer,
logging,
node_rewriter,
pre_constant_merge,
pre_greedy_local_optimizer,
pre_greedy_node_rewriter,
)
from aesara.raise_op import assert_op
from aesara.tensor.basic_opt import constant_folding
......@@ -547,7 +547,7 @@ def test_pre_constant_merge():
assert res == [adv]
def test_pre_greedy_local_optimizer():
def test_pre_greedy_node_rewriter():
empty_fgraph = FunctionGraph([], [])
......@@ -564,7 +564,7 @@ def test_pre_greedy_local_optimizer():
# This should fold `o1`, because it has only `Constant` arguments, and
# replace it with the `Constant` result
cst = pre_greedy_local_optimizer(empty_fgraph, [constant_folding], o2)
cst = pre_greedy_node_rewriter(empty_fgraph, [constant_folding], o2)
assert cst.owner.inputs[0].owner is None
assert cst.owner.inputs[1] is c2
......@@ -577,14 +577,14 @@ def test_pre_greedy_local_optimizer():
fg = FunctionGraph([], [o1], clone=False)
o2 = op1(o1, c2, x, o3, o1)
cst = pre_greedy_local_optimizer(fg, [constant_folding], o2)
cst = pre_greedy_node_rewriter(fg, [constant_folding], o2)
assert cst.owner.inputs[0] is o1
assert cst.owner.inputs[4] is cst.owner.inputs[0]
# What exactly is this supposed to test?
ms = MakeSlice()(1)
cst = pre_greedy_local_optimizer(empty_fgraph, [constant_folding], ms)
cst = pre_greedy_node_rewriter(empty_fgraph, [constant_folding], ms)
assert isinstance(cst, SliceConstant)
......@@ -673,13 +673,13 @@ class TestLocalOptGroup:
fgraph = FunctionGraph([x, y], [o1], clone=False)
@local_optimizer(None)
@node_rewriter(None)
def local_opt_1(fgraph, node):
if node.inputs[0] == x:
res = op2(y, *node.inputs[1:])
return [res]
@local_optimizer(None)
@node_rewriter(None)
def local_opt_2(fgraph, node):
if node.inputs[0] == y:
res = op2(x, *node.inputs[1:])
......@@ -703,8 +703,8 @@ class TestLocalOptGroup:
)
def test_local_optimizer_str():
@local_optimizer([op1, MyOp])
def test_node_rewriter_str():
@node_rewriter([op1, MyOp])
def local_opt_1(fgraph, node):
pass
......@@ -715,17 +715,17 @@ def test_local_optimizer_str():
assert "local_opt_1" in res
def test_local_optimizer():
def test_node_rewriter():
with pytest.raises(ValueError):
@local_optimizer([])
@node_rewriter([])
def local_bad_1(fgraph, node):
return node.outputs
with pytest.raises(TypeError):
@local_optimizer([None])
@node_rewriter([None])
def local_bad_2(fgraph, node):
return node.outputs
......@@ -748,7 +748,7 @@ def test_local_optimizer():
hits = [0]
@local_optimizer([op1, MyNewOp])
@node_rewriter([op1, MyNewOp])
def local_opt_1(fgraph, node, hits=hits):
hits[0] += 1
return node.outputs
......@@ -766,24 +766,24 @@ def test_local_optimizer():
assert hits[0] == 2
def test_TrackingLocalOptimizer():
@local_optimizer(None)
def test_TrackingNodeRewriter():
@node_rewriter(None)
def local_opt_1(fgraph, node):
pass
@local_optimizer([op1])
@node_rewriter([op1])
def local_opt_2(fgraph, node):
pass
@local_optimizer([Op])
@node_rewriter([Op])
def local_opt_3(fgraph, node):
pass
@local_optimizer([MyOp])
@node_rewriter([MyOp])
def local_opt_4(fgraph, node):
pass
@local_optimizer([MyOp])
@node_rewriter([MyOp])
def local_opt_5(fgraph, node):
pass
......
......@@ -16,7 +16,7 @@ from aesara.configdefaults import config
from aesara.graph.basic import Apply, Constant, Variable
from aesara.graph.fg import FunctionGraph
from aesara.graph.op import Op
from aesara.graph.opt import check_stack_trace, local_optimizer, out2in
from aesara.graph.opt import check_stack_trace, node_rewriter, out2in
from aesara.graph.opt_utils import optimize_graph
from aesara.graph.optdb import OptimizationQuery
from aesara.graph.type import Type
......@@ -1752,7 +1752,7 @@ class TestShapeOptimizer:
identity_shape = IdentityShape()
@local_optimizer([IdentityNoShape])
@node_rewriter([IdentityNoShape])
def local_identity_noshape_to_identity_shape(fgraph, node):
"""Optimization transforming the first Op into the second"""
if isinstance(node.op, IdentityNoShape):
......