提交 1d5b1d94 authored 作者: ricardoV94's avatar ricardoV94 提交者: Ricardo Vieira

Replace uses of in2out and out2in by a depth-first search rewriter

上级 aa7e4d6b
...@@ -27,7 +27,12 @@ from pytensor.graph.features import AlreadyThere, Feature ...@@ -27,7 +27,12 @@ from pytensor.graph.features import AlreadyThere, Feature
from pytensor.graph.fg import FunctionGraph, Output from pytensor.graph.fg import FunctionGraph, Output
from pytensor.graph.op import Op from pytensor.graph.op import Op
from pytensor.graph.rewriting.unify import OpPattern, Var, convert_strs_to_vars from pytensor.graph.rewriting.unify import OpPattern, Var, convert_strs_to_vars
from pytensor.graph.traversal import applys_between, toposort, vars_between from pytensor.graph.traversal import (
apply_ancestors,
applys_between,
toposort,
vars_between,
)
from pytensor.graph.utils import AssocList, InconsistencyError from pytensor.graph.utils import AssocList, InconsistencyError
from pytensor.misc.ordered_set import OrderedSet from pytensor.misc.ordered_set import OrderedSet
from pytensor.utils import flatten from pytensor.utils import flatten
...@@ -1995,12 +2000,13 @@ class WalkingGraphRewriter(NodeProcessingGraphRewriter): ...@@ -1995,12 +2000,13 @@ class WalkingGraphRewriter(NodeProcessingGraphRewriter):
def __init__( def __init__(
self, self,
node_rewriter: NodeRewriter, node_rewriter: NodeRewriter,
order: Literal["out_to_in", "in_to_out"] = "in_to_out", order: Literal["out_to_in", "in_to_out", "dfs"] = "in_to_out",
ignore_newtrees: bool = False, ignore_newtrees: bool = False,
failure_callback: FailureCallbackType | None = None, failure_callback: FailureCallbackType | None = None,
): ):
if order not in ("out_to_in", "in_to_out"): valid_orders = ("out_to_in", "in_to_out", "dfs")
raise ValueError("order must be 'out_to_in' or 'in_to_out'") if order not in valid_orders:
raise ValueError(f"order must be one of {valid_orders}, got {order}")
self.order = order self.order = order
super().__init__(node_rewriter, ignore_newtrees, failure_callback) super().__init__(node_rewriter, ignore_newtrees, failure_callback)
...@@ -2010,7 +2016,11 @@ class WalkingGraphRewriter(NodeProcessingGraphRewriter): ...@@ -2010,7 +2016,11 @@ class WalkingGraphRewriter(NodeProcessingGraphRewriter):
callback_before = fgraph.execute_callbacks_time callback_before = fgraph.execute_callbacks_time
nb_nodes_start = len(fgraph.apply_nodes) nb_nodes_start = len(fgraph.apply_nodes)
t0 = time.perf_counter() t0 = time.perf_counter()
q = deque(toposort(start_from)) q = deque(
apply_ancestors(start_from)
if (self.order == "dfs")
else toposort(start_from)
)
io_t = time.perf_counter() - t0 io_t = time.perf_counter() - t0
def importer(node): def importer(node):
...@@ -2134,6 +2144,7 @@ def walking_rewriter( ...@@ -2134,6 +2144,7 @@ def walking_rewriter(
in2out = partial(walking_rewriter, "in_to_out") in2out = partial(walking_rewriter, "in_to_out")
out2in = partial(walking_rewriter, "out_to_in") out2in = partial(walking_rewriter, "out_to_in")
dfs_rewriter = partial(walking_rewriter, "dfs")
class ChangeTracker(Feature): class ChangeTracker(Feature):
......
...@@ -29,7 +29,7 @@ from pytensor.graph.rewriting.basic import ( ...@@ -29,7 +29,7 @@ from pytensor.graph.rewriting.basic import (
EquilibriumGraphRewriter, EquilibriumGraphRewriter,
GraphRewriter, GraphRewriter,
copy_stack_trace, copy_stack_trace,
in2out, dfs_rewriter,
node_rewriter, node_rewriter,
) )
from pytensor.graph.rewriting.db import EquilibriumDB, SequenceDB from pytensor.graph.rewriting.db import EquilibriumDB, SequenceDB
...@@ -2558,7 +2558,7 @@ optdb.register("scan_eqopt2", scan_eqopt2, "fast_run", "scan", position=1.6) ...@@ -2558,7 +2558,7 @@ optdb.register("scan_eqopt2", scan_eqopt2, "fast_run", "scan", position=1.6)
# ScanSaveMem should execute only once per node. # ScanSaveMem should execute only once per node.
optdb.register( optdb.register(
"scan_save_mem_prealloc", "scan_save_mem_prealloc",
in2out(scan_save_mem_prealloc, ignore_newtrees=True), dfs_rewriter(scan_save_mem_prealloc, ignore_newtrees=True),
"fast_run", "fast_run",
"scan", "scan",
"scan_save_mem", "scan_save_mem",
...@@ -2566,7 +2566,7 @@ optdb.register( ...@@ -2566,7 +2566,7 @@ optdb.register(
) )
optdb.register( optdb.register(
"scan_save_mem_no_prealloc", "scan_save_mem_no_prealloc",
in2out(scan_save_mem_no_prealloc, ignore_newtrees=True), dfs_rewriter(scan_save_mem_no_prealloc, ignore_newtrees=True),
"numba", "numba",
"jax", "jax",
"pytorch", "pytorch",
...@@ -2587,7 +2587,7 @@ scan_eqopt1.register("all_pushout_opt", scan_seqopt1, "fast_run", "scan") ...@@ -2587,7 +2587,7 @@ scan_eqopt1.register("all_pushout_opt", scan_seqopt1, "fast_run", "scan")
scan_seqopt1.register( scan_seqopt1.register(
"scan_remove_constants_and_unused_inputs0", "scan_remove_constants_and_unused_inputs0",
in2out(remove_constants_and_unused_inputs_scan, ignore_newtrees=True), dfs_rewriter(remove_constants_and_unused_inputs_scan, ignore_newtrees=True),
"remove_constants_and_unused_inputs_scan", "remove_constants_and_unused_inputs_scan",
"fast_run", "fast_run",
"scan", "scan",
...@@ -2596,7 +2596,7 @@ scan_seqopt1.register( ...@@ -2596,7 +2596,7 @@ scan_seqopt1.register(
scan_seqopt1.register( scan_seqopt1.register(
"scan_push_out_non_seq", "scan_push_out_non_seq",
in2out(scan_push_out_non_seq, ignore_newtrees=True), dfs_rewriter(scan_push_out_non_seq, ignore_newtrees=True),
"scan_pushout_nonseqs_ops", # For backcompat: so it can be tagged with old name "scan_pushout_nonseqs_ops", # For backcompat: so it can be tagged with old name
"fast_run", "fast_run",
"scan", "scan",
...@@ -2606,7 +2606,7 @@ scan_seqopt1.register( ...@@ -2606,7 +2606,7 @@ scan_seqopt1.register(
scan_seqopt1.register( scan_seqopt1.register(
"scan_push_out_seq", "scan_push_out_seq",
in2out(scan_push_out_seq, ignore_newtrees=True), dfs_rewriter(scan_push_out_seq, ignore_newtrees=True),
"scan_pushout_seqs_ops", # For backcompat: so it can be tagged with old name "scan_pushout_seqs_ops", # For backcompat: so it can be tagged with old name
"fast_run", "fast_run",
"scan", "scan",
...@@ -2617,7 +2617,7 @@ scan_seqopt1.register( ...@@ -2617,7 +2617,7 @@ scan_seqopt1.register(
scan_seqopt1.register( scan_seqopt1.register(
"scan_push_out_dot1", "scan_push_out_dot1",
in2out(scan_push_out_dot1, ignore_newtrees=True), dfs_rewriter(scan_push_out_dot1, ignore_newtrees=True),
"scan_pushout_dot1", # For backcompat: so it can be tagged with old name "scan_pushout_dot1", # For backcompat: so it can be tagged with old name
"fast_run", "fast_run",
"more_mem", "more_mem",
...@@ -2630,7 +2630,7 @@ scan_seqopt1.register( ...@@ -2630,7 +2630,7 @@ scan_seqopt1.register(
scan_seqopt1.register( scan_seqopt1.register(
"scan_push_out_add", "scan_push_out_add",
# TODO: Perhaps this should be an `EquilibriumGraphRewriter`? # TODO: Perhaps this should be an `EquilibriumGraphRewriter`?
in2out(scan_push_out_add, ignore_newtrees=False), dfs_rewriter(scan_push_out_add, ignore_newtrees=False),
"scan_pushout_add", # For backcompat: so it can be tagged with old name "scan_pushout_add", # For backcompat: so it can be tagged with old name
"fast_run", "fast_run",
"more_mem", "more_mem",
...@@ -2641,14 +2641,14 @@ scan_seqopt1.register( ...@@ -2641,14 +2641,14 @@ scan_seqopt1.register(
scan_eqopt2.register( scan_eqopt2.register(
"while_scan_merge_subtensor_last_element", "while_scan_merge_subtensor_last_element",
in2out(while_scan_merge_subtensor_last_element, ignore_newtrees=True), dfs_rewriter(while_scan_merge_subtensor_last_element, ignore_newtrees=True),
"fast_run", "fast_run",
"scan", "scan",
) )
scan_eqopt2.register( scan_eqopt2.register(
"constant_folding_for_scan2", "constant_folding_for_scan2",
in2out(constant_folding, ignore_newtrees=True), dfs_rewriter(constant_folding, ignore_newtrees=True),
"fast_run", "fast_run",
"scan", "scan",
) )
...@@ -2656,7 +2656,7 @@ scan_eqopt2.register( ...@@ -2656,7 +2656,7 @@ scan_eqopt2.register(
scan_eqopt2.register( scan_eqopt2.register(
"scan_remove_constants_and_unused_inputs1", "scan_remove_constants_and_unused_inputs1",
in2out(remove_constants_and_unused_inputs_scan, ignore_newtrees=True), dfs_rewriter(remove_constants_and_unused_inputs_scan, ignore_newtrees=True),
"remove_constants_and_unused_inputs_scan", "remove_constants_and_unused_inputs_scan",
"fast_run", "fast_run",
"scan", "scan",
...@@ -2671,7 +2671,7 @@ scan_eqopt2.register("scan_merge", ScanMerge(), "fast_run", "scan") ...@@ -2671,7 +2671,7 @@ scan_eqopt2.register("scan_merge", ScanMerge(), "fast_run", "scan")
# After Merge optimization # After Merge optimization
scan_eqopt2.register( scan_eqopt2.register(
"scan_remove_constants_and_unused_inputs2", "scan_remove_constants_and_unused_inputs2",
in2out(remove_constants_and_unused_inputs_scan, ignore_newtrees=True), dfs_rewriter(remove_constants_and_unused_inputs_scan, ignore_newtrees=True),
"remove_constants_and_unused_inputs_scan", "remove_constants_and_unused_inputs_scan",
"fast_run", "fast_run",
"scan", "scan",
...@@ -2679,7 +2679,7 @@ scan_eqopt2.register( ...@@ -2679,7 +2679,7 @@ scan_eqopt2.register(
scan_eqopt2.register( scan_eqopt2.register(
"scan_merge_inouts", "scan_merge_inouts",
in2out(scan_merge_inouts, ignore_newtrees=True), dfs_rewriter(scan_merge_inouts, ignore_newtrees=True),
"fast_run", "fast_run",
"scan", "scan",
) )
...@@ -2687,7 +2687,7 @@ scan_eqopt2.register( ...@@ -2687,7 +2687,7 @@ scan_eqopt2.register(
# After everything else # After everything else
scan_eqopt2.register( scan_eqopt2.register(
"scan_remove_constants_and_unused_inputs3", "scan_remove_constants_and_unused_inputs3",
in2out(remove_constants_and_unused_inputs_scan, ignore_newtrees=True), dfs_rewriter(remove_constants_and_unused_inputs_scan, ignore_newtrees=True),
"remove_constants_and_unused_inputs_scan", "remove_constants_and_unused_inputs_scan",
"fast_run", "fast_run",
"scan", "scan",
......
...@@ -3,7 +3,7 @@ from copy import copy ...@@ -3,7 +3,7 @@ from copy import copy
from pytensor.compile import optdb from pytensor.compile import optdb
from pytensor.graph import Constant, graph_inputs from pytensor.graph import Constant, graph_inputs
from pytensor.graph.rewriting.basic import copy_stack_trace, in2out, node_rewriter from pytensor.graph.rewriting.basic import copy_stack_trace, dfs_rewriter, node_rewriter
from pytensor.scan.op import Scan from pytensor.scan.op import Scan
from pytensor.scan.rewriting import scan_seqopt1 from pytensor.scan.rewriting import scan_seqopt1
from pytensor.tensor._linalg.solve.tridiagonal import ( from pytensor.tensor._linalg.solve.tridiagonal import (
...@@ -244,7 +244,7 @@ def scan_split_non_sequence_decomposition_and_solve(fgraph, node): ...@@ -244,7 +244,7 @@ def scan_split_non_sequence_decomposition_and_solve(fgraph, node):
scan_seqopt1.register( scan_seqopt1.register(
scan_split_non_sequence_decomposition_and_solve.__name__, scan_split_non_sequence_decomposition_and_solve.__name__,
in2out(scan_split_non_sequence_decomposition_and_solve, ignore_newtrees=True), dfs_rewriter(scan_split_non_sequence_decomposition_and_solve, ignore_newtrees=True),
"fast_run", "fast_run",
"scan", "scan",
"scan_pushout", "scan_pushout",
...@@ -261,7 +261,7 @@ def reuse_decomposition_multiple_solves_jax(fgraph, node): ...@@ -261,7 +261,7 @@ def reuse_decomposition_multiple_solves_jax(fgraph, node):
optdb["specialize"].register( optdb["specialize"].register(
reuse_decomposition_multiple_solves_jax.__name__, reuse_decomposition_multiple_solves_jax.__name__,
in2out(reuse_decomposition_multiple_solves_jax, ignore_newtrees=True), dfs_rewriter(reuse_decomposition_multiple_solves_jax, ignore_newtrees=True),
"jax", "jax",
use_db_name_as_tag=False, use_db_name_as_tag=False,
) )
...@@ -276,7 +276,9 @@ def scan_split_non_sequence_decomposition_and_solve_jax(fgraph, node): ...@@ -276,7 +276,9 @@ def scan_split_non_sequence_decomposition_and_solve_jax(fgraph, node):
scan_seqopt1.register( scan_seqopt1.register(
scan_split_non_sequence_decomposition_and_solve_jax.__name__, scan_split_non_sequence_decomposition_and_solve_jax.__name__,
in2out(scan_split_non_sequence_decomposition_and_solve_jax, ignore_newtrees=True), dfs_rewriter(
scan_split_non_sequence_decomposition_and_solve_jax, ignore_newtrees=True
),
"jax", "jax",
use_db_name_as_tag=False, use_db_name_as_tag=False,
position=2, position=2,
......
...@@ -4,7 +4,11 @@ from pytensor.compile import optdb ...@@ -4,7 +4,11 @@ from pytensor.compile import optdb
from pytensor.configdefaults import config from pytensor.configdefaults import config
from pytensor.graph import ancestors from pytensor.graph import ancestors
from pytensor.graph.op import compute_test_value from pytensor.graph.op import compute_test_value
from pytensor.graph.rewriting.basic import copy_stack_trace, in2out, node_rewriter from pytensor.graph.rewriting.basic import (
copy_stack_trace,
dfs_rewriter,
node_rewriter,
)
from pytensor.tensor import NoneConst, TensorVariable from pytensor.tensor import NoneConst, TensorVariable
from pytensor.tensor.basic import constant from pytensor.tensor.basic import constant
from pytensor.tensor.elemwise import DimShuffle from pytensor.tensor.elemwise import DimShuffle
...@@ -57,7 +61,7 @@ def random_make_inplace(fgraph, node): ...@@ -57,7 +61,7 @@ def random_make_inplace(fgraph, node):
optdb.register( optdb.register(
"random_make_inplace", "random_make_inplace",
in2out(random_make_inplace, ignore_newtrees=True), dfs_rewriter(random_make_inplace, ignore_newtrees=True),
"fast_run", "fast_run",
"inplace", "inplace",
position=50.9, position=50.9,
......
...@@ -2,8 +2,7 @@ import re ...@@ -2,8 +2,7 @@ import re
from pytensor.compile import optdb from pytensor.compile import optdb
from pytensor.graph import Constant from pytensor.graph import Constant
from pytensor.graph.rewriting.basic import in2out, node_rewriter from pytensor.graph.rewriting.basic import dfs_rewriter, in2out, node_rewriter
from pytensor.graph.rewriting.db import SequenceDB
from pytensor.tensor import abs as abs_t from pytensor.tensor import abs as abs_t
from pytensor.tensor import broadcast_arrays, exp, floor, log, log1p, reciprocal, sqrt from pytensor.tensor import broadcast_arrays, exp, floor, log, log1p, reciprocal, sqrt
from pytensor.tensor.basic import ( from pytensor.tensor.basic import (
...@@ -179,51 +178,16 @@ def materialize_implicit_arange_choice_without_replacement(fgraph, node): ...@@ -179,51 +178,16 @@ def materialize_implicit_arange_choice_without_replacement(fgraph, node):
return new_op.make_node(rng, size, a_vector_param, *other_params).outputs return new_op.make_node(rng, size, a_vector_param, *other_params).outputs
random_vars_opt = SequenceDB() random_vars_opt = dfs_rewriter(
random_vars_opt.register( lognormal_from_normal,
"lognormal_from_normal", halfnormal_from_normal,
in2out(lognormal_from_normal), geometric_from_uniform,
"jax", negative_binomial_from_gamma_poisson,
) inverse_gamma_from_gamma,
random_vars_opt.register( generalized_gamma_from_gamma,
"halfnormal_from_normal", wald_from_normal_uniform,
in2out(halfnormal_from_normal), beta_binomial_from_beta_binomial,
"jax", materialize_implicit_arange_choice_without_replacement,
)
random_vars_opt.register(
"geometric_from_uniform",
in2out(geometric_from_uniform),
"jax",
)
random_vars_opt.register(
"negative_binomial_from_gamma_poisson",
in2out(negative_binomial_from_gamma_poisson),
"jax",
)
random_vars_opt.register(
"inverse_gamma_from_gamma",
in2out(inverse_gamma_from_gamma),
"jax",
)
random_vars_opt.register(
"generalized_gamma_from_gamma",
in2out(generalized_gamma_from_gamma),
"jax",
)
random_vars_opt.register(
"wald_from_normal_uniform",
in2out(wald_from_normal_uniform),
"jax",
)
random_vars_opt.register(
"beta_binomial_from_beta_binomial",
in2out(beta_binomial_from_beta_binomial),
"jax",
)
random_vars_opt.register(
"materialize_implicit_arange_choice_without_replacement",
in2out(materialize_implicit_arange_choice_without_replacement),
"jax",
) )
optdb.register("jax_random_vars_rewrites", random_vars_opt, "jax", position=110) optdb.register("jax_random_vars_rewrites", random_vars_opt, "jax", position=110)
......
from pytensor.compile import optdb from pytensor.compile import optdb
from pytensor.graph import node_rewriter from pytensor.graph import node_rewriter
from pytensor.graph.rewriting.basic import out2in from pytensor.graph.rewriting.basic import dfs_rewriter
from pytensor.tensor import as_tensor, constant from pytensor.tensor import as_tensor, constant
from pytensor.tensor.random.op import RandomVariable, RandomVariableWithCoreShape from pytensor.tensor.random.op import RandomVariable, RandomVariableWithCoreShape
from pytensor.tensor.rewriting.shape import ShapeFeature from pytensor.tensor.rewriting.shape import ShapeFeature
...@@ -82,7 +82,7 @@ def introduce_explicit_core_shape_rv(fgraph, node): ...@@ -82,7 +82,7 @@ def introduce_explicit_core_shape_rv(fgraph, node):
optdb.register( optdb.register(
introduce_explicit_core_shape_rv.__name__, introduce_explicit_core_shape_rv.__name__,
out2in(introduce_explicit_core_shape_rv), dfs_rewriter(introduce_explicit_core_shape_rv),
"numba", "numba",
position=100, position=100,
) )
...@@ -35,6 +35,7 @@ from pytensor.graph.rewriting.basic import ( ...@@ -35,6 +35,7 @@ from pytensor.graph.rewriting.basic import (
NodeRewriter, NodeRewriter,
Rewriter, Rewriter,
copy_stack_trace, copy_stack_trace,
dfs_rewriter,
in2out, in2out,
node_rewriter, node_rewriter,
) )
...@@ -538,7 +539,7 @@ def local_alloc_empty_to_zeros(fgraph, node): ...@@ -538,7 +539,7 @@ def local_alloc_empty_to_zeros(fgraph, node):
compile.optdb.register( compile.optdb.register(
"local_alloc_empty_to_zeros", "local_alloc_empty_to_zeros",
in2out(local_alloc_empty_to_zeros), dfs_rewriter(local_alloc_empty_to_zeros),
# After move to gpu and merge2, before inplace. # After move to gpu and merge2, before inplace.
"alloc_empty_to_zeros", "alloc_empty_to_zeros",
position=49.3, position=49.3,
......
...@@ -77,7 +77,7 @@ from pytensor.graph.rewriting.basic import ( ...@@ -77,7 +77,7 @@ from pytensor.graph.rewriting.basic import (
EquilibriumGraphRewriter, EquilibriumGraphRewriter,
GraphRewriter, GraphRewriter,
copy_stack_trace, copy_stack_trace,
in2out, dfs_rewriter,
node_rewriter, node_rewriter,
) )
from pytensor.graph.rewriting.db import SequenceDB from pytensor.graph.rewriting.db import SequenceDB
...@@ -721,7 +721,7 @@ optdb.register("BlasOpt", blas_optdb, "fast_run", "fast_compile", position=1.7) ...@@ -721,7 +721,7 @@ optdb.register("BlasOpt", blas_optdb, "fast_run", "fast_compile", position=1.7)
# fast_compile is needed to have GpuDot22 created. # fast_compile is needed to have GpuDot22 created.
blas_optdb.register( blas_optdb.register(
"local_dot_to_dot22", "local_dot_to_dot22",
in2out(local_dot_to_dot22), dfs_rewriter(local_dot_to_dot22),
"fast_run", "fast_run",
"fast_compile", "fast_compile",
position=0, position=0,
...@@ -744,7 +744,7 @@ blas_optdb.register( ...@@ -744,7 +744,7 @@ blas_optdb.register(
) )
blas_opt_inplace = in2out( blas_opt_inplace = dfs_rewriter(
local_inplace_gemm, local_inplace_gemv, local_inplace_ger, name="blas_opt_inplace" local_inplace_gemm, local_inplace_gemv, local_inplace_ger, name="blas_opt_inplace"
) )
optdb.register( optdb.register(
...@@ -883,7 +883,7 @@ def local_dot22_to_dot22scalar(fgraph, node): ...@@ -883,7 +883,7 @@ def local_dot22_to_dot22scalar(fgraph, node):
# dot22scalar and gemm give more speed up then dot22scalar # dot22scalar and gemm give more speed up then dot22scalar
blas_optdb.register( blas_optdb.register(
"local_dot22_to_dot22scalar", "local_dot22_to_dot22scalar",
in2out(local_dot22_to_dot22scalar), dfs_rewriter(local_dot22_to_dot22scalar),
"fast_run", "fast_run",
position=12, position=12,
) )
......
from pytensor.configdefaults import config from pytensor.configdefaults import config
from pytensor.graph.rewriting.basic import in2out from pytensor.graph.rewriting.basic import dfs_rewriter
from pytensor.tensor import basic as ptb from pytensor.tensor import basic as ptb
from pytensor.tensor.blas import gemv_inplace, gemv_no_inplace, ger, ger_destructive from pytensor.tensor.blas import gemv_inplace, gemv_no_inplace, ger, ger_destructive
from pytensor.tensor.blas_c import ( from pytensor.tensor.blas_c import (
...@@ -56,13 +56,15 @@ def make_c_gemv_destructive(fgraph, node): ...@@ -56,13 +56,15 @@ def make_c_gemv_destructive(fgraph, node):
blas_optdb.register( blas_optdb.register(
"use_c_blas", in2out(use_c_ger, use_c_gemv), "fast_run", "c_blas", position=20 "use_c_blas", dfs_rewriter(use_c_ger, use_c_gemv), "fast_run", "c_blas", position=20
) )
# this matches the InplaceBlasOpt defined in blas.py # this matches the InplaceBlasOpt defined in blas.py
optdb.register( optdb.register(
"c_blas_destructive", "c_blas_destructive",
in2out(make_c_ger_destructive, make_c_gemv_destructive, name="c_blas_destructive"), dfs_rewriter(
make_c_ger_destructive, make_c_gemv_destructive, name="c_blas_destructive"
),
"fast_run", "fast_run",
"inplace", "inplace",
"c_blas", "c_blas",
......
...@@ -2,7 +2,7 @@ from pytensor.compile.mode import optdb ...@@ -2,7 +2,7 @@ from pytensor.compile.mode import optdb
from pytensor.graph import Constant, Op, node_rewriter from pytensor.graph import Constant, Op, node_rewriter
from pytensor.graph.destroyhandler import inplace_candidates from pytensor.graph.destroyhandler import inplace_candidates
from pytensor.graph.replace import vectorize_node from pytensor.graph.replace import vectorize_node
from pytensor.graph.rewriting.basic import copy_stack_trace, out2in from pytensor.graph.rewriting.basic import copy_stack_trace, dfs_rewriter
from pytensor.graph.rewriting.unify import OpPattern, OpPatternOpTypeType from pytensor.graph.rewriting.unify import OpPattern, OpPatternOpTypeType
from pytensor.tensor.basic import Alloc, ARange, alloc, shape_padleft from pytensor.tensor.basic import Alloc, ARange, alloc, shape_padleft
from pytensor.tensor.blockwise import Blockwise, _squeeze_left from pytensor.tensor.blockwise import Blockwise, _squeeze_left
...@@ -66,7 +66,7 @@ def local_useless_unbatched_blockwise(fgraph, node): ...@@ -66,7 +66,7 @@ def local_useless_unbatched_blockwise(fgraph, node):
# We do it after position>=60 so that Blockwise inplace rewrites will work also on useless Blockwise Ops # We do it after position>=60 so that Blockwise inplace rewrites will work also on useless Blockwise Ops
optdb.register( optdb.register(
"local_useless_unbatched_blockwise", "local_useless_unbatched_blockwise",
out2in(local_useless_unbatched_blockwise, ignore_newtrees=True), dfs_rewriter(local_useless_unbatched_blockwise, ignore_newtrees=True),
"fast_run", "fast_run",
"fast_compile", "fast_compile",
"blockwise", "blockwise",
......
...@@ -21,6 +21,7 @@ from pytensor.graph.op import Op ...@@ -21,6 +21,7 @@ from pytensor.graph.op import Op
from pytensor.graph.rewriting.basic import ( from pytensor.graph.rewriting.basic import (
GraphRewriter, GraphRewriter,
copy_stack_trace, copy_stack_trace,
dfs_rewriter,
in2out, in2out,
node_rewriter, node_rewriter,
out2in, out2in,
...@@ -1237,21 +1238,21 @@ fuse_seqopt.register( ...@@ -1237,21 +1238,21 @@ fuse_seqopt.register(
) )
fuse_seqopt.register( fuse_seqopt.register(
"local_useless_composite_outputs", "local_useless_composite_outputs",
in2out(local_useless_composite_outputs), dfs_rewriter(local_useless_composite_outputs),
"fast_run", "fast_run",
"fusion", "fusion",
position=2, position=2,
) )
fuse_seqopt.register( fuse_seqopt.register(
"local_careduce_fusion", "local_careduce_fusion",
in2out(local_careduce_fusion), dfs_rewriter(local_careduce_fusion),
"fast_run", "fast_run",
"fusion", "fusion",
position=10, position=10,
) )
fuse_seqopt.register( fuse_seqopt.register(
"local_inline_composite_constants", "local_inline_composite_constants",
in2out(local_inline_composite_constants, ignore_newtrees=True), dfs_rewriter(local_inline_composite_constants, ignore_newtrees=True),
"fast_run", "fast_run",
"fusion", "fusion",
position=20, position=20,
......
import pytensor.tensor as pt import pytensor.tensor as pt
from pytensor.compile import optdb from pytensor.compile import optdb
from pytensor.graph.rewriting.basic import in2out, node_rewriter from pytensor.graph.rewriting.basic import dfs_rewriter, node_rewriter
from pytensor.tensor.basic import MakeVector from pytensor.tensor.basic import MakeVector
from pytensor.tensor.elemwise import DimShuffle from pytensor.tensor.elemwise import DimShuffle
from pytensor.tensor.math import Sum from pytensor.tensor.math import Sum
...@@ -46,7 +46,7 @@ def boolean_indexing_set_or_inc(fgraph, node): ...@@ -46,7 +46,7 @@ def boolean_indexing_set_or_inc(fgraph, node):
optdb.register( optdb.register(
"jax_boolean_indexing_set_or_inc", "jax_boolean_indexing_set_or_inc",
in2out(boolean_indexing_set_or_inc), dfs_rewriter(boolean_indexing_set_or_inc),
"jax", "jax",
position=100, position=100,
) )
...@@ -96,7 +96,7 @@ def boolean_indexing_sum(fgraph, node): ...@@ -96,7 +96,7 @@ def boolean_indexing_sum(fgraph, node):
optdb.register( optdb.register(
"jax_boolean_indexing_sum", in2out(boolean_indexing_sum), "jax", position=100 "jax_boolean_indexing_sum", dfs_rewriter(boolean_indexing_sum), "jax", position=100
) )
...@@ -144,7 +144,7 @@ def shape_parameter_as_tuple(fgraph, node): ...@@ -144,7 +144,7 @@ def shape_parameter_as_tuple(fgraph, node):
optdb.register( optdb.register(
"jax_shape_parameter_as_tuple", "jax_shape_parameter_as_tuple",
in2out(shape_parameter_as_tuple), dfs_rewriter(shape_parameter_as_tuple),
"jax", "jax",
position=100, position=100,
) )
...@@ -10,7 +10,7 @@ from pytensor.compile import optdb ...@@ -10,7 +10,7 @@ from pytensor.compile import optdb
from pytensor.graph import Apply, FunctionGraph from pytensor.graph import Apply, FunctionGraph
from pytensor.graph.rewriting.basic import ( from pytensor.graph.rewriting.basic import (
copy_stack_trace, copy_stack_trace,
in2out, dfs_rewriter,
node_rewriter, node_rewriter,
) )
from pytensor.graph.rewriting.unify import OpPattern from pytensor.graph.rewriting.unify import OpPattern
...@@ -905,7 +905,7 @@ def jax_bilinaer_lyapunov_to_direct(fgraph: FunctionGraph, node: Apply): ...@@ -905,7 +905,7 @@ def jax_bilinaer_lyapunov_to_direct(fgraph: FunctionGraph, node: Apply):
optdb.register( optdb.register(
"jax_bilinaer_lyapunov_to_direct", "jax_bilinaer_lyapunov_to_direct",
in2out(jax_bilinaer_lyapunov_to_direct), dfs_rewriter(jax_bilinaer_lyapunov_to_direct),
"jax", "jax",
position=0.9, # Run before canonicalization position=0.9, # Run before canonicalization
) )
......
from pytensor.compile import optdb from pytensor.compile import optdb
from pytensor.graph import node_rewriter from pytensor.graph import node_rewriter
from pytensor.graph.rewriting.basic import out2in from pytensor.graph.rewriting.basic import dfs_rewriter
from pytensor.graph.traversal import applys_between from pytensor.graph.traversal import applys_between
from pytensor.tensor.basic import as_tensor, constant from pytensor.tensor.basic import as_tensor, constant
from pytensor.tensor.blockwise import Blockwise, BlockwiseWithCoreShape from pytensor.tensor.blockwise import Blockwise, BlockwiseWithCoreShape
...@@ -102,7 +102,7 @@ def introduce_explicit_core_shape_blockwise(fgraph, node): ...@@ -102,7 +102,7 @@ def introduce_explicit_core_shape_blockwise(fgraph, node):
optdb.register( optdb.register(
introduce_explicit_core_shape_blockwise.__name__, introduce_explicit_core_shape_blockwise.__name__,
out2in(introduce_explicit_core_shape_blockwise), dfs_rewriter(introduce_explicit_core_shape_blockwise),
"numba", "numba",
position=100, position=100,
) )
...@@ -4,7 +4,7 @@ from pytensor import Variable, clone_replace ...@@ -4,7 +4,7 @@ from pytensor import Variable, clone_replace
from pytensor.compile import optdb from pytensor.compile import optdb
from pytensor.compile.builders import OpFromGraph from pytensor.compile.builders import OpFromGraph
from pytensor.graph import Apply, node_rewriter from pytensor.graph import Apply, node_rewriter
from pytensor.graph.rewriting.basic import copy_stack_trace, in2out from pytensor.graph.rewriting.basic import copy_stack_trace, dfs_rewriter
from pytensor.tensor.basic import AllocDiag from pytensor.tensor.basic import AllocDiag
from pytensor.tensor.rewriting.basic import register_specialize from pytensor.tensor.rewriting.basic import register_specialize
...@@ -37,7 +37,7 @@ def inline_ofg_expansion(fgraph, node): ...@@ -37,7 +37,7 @@ def inline_ofg_expansion(fgraph, node):
# and before the first scan optimizer. # and before the first scan optimizer.
optdb.register( optdb.register(
"inline_ofg_expansion", "inline_ofg_expansion",
in2out(inline_ofg_expansion), dfs_rewriter(inline_ofg_expansion),
"fast_compile", "fast_compile",
"fast_run", "fast_run",
position=-0.01, position=-0.01,
......
...@@ -10,9 +10,9 @@ from pytensor.graph.basic import Constant, Variable ...@@ -10,9 +10,9 @@ from pytensor.graph.basic import Constant, Variable
from pytensor.graph.rewriting.basic import ( from pytensor.graph.rewriting.basic import (
WalkingGraphRewriter, WalkingGraphRewriter,
copy_stack_trace, copy_stack_trace,
dfs_rewriter,
in2out, in2out,
node_rewriter, node_rewriter,
out2in,
) )
from pytensor.raise_op import Assert from pytensor.raise_op import Assert
from pytensor.scalar import Add, ScalarConstant, ScalarType from pytensor.scalar import Add, ScalarConstant, ScalarType
...@@ -1560,7 +1560,7 @@ def local_uint_constant_indices(fgraph, node): ...@@ -1560,7 +1560,7 @@ def local_uint_constant_indices(fgraph, node):
compile.optdb.register( compile.optdb.register(
local_uint_constant_indices.__name__, local_uint_constant_indices.__name__,
out2in(local_uint_constant_indices), dfs_rewriter(local_uint_constant_indices),
# We don't include in the Python / C because those always cast indices to int64 internally. # We don't include in the Python / C because those always cast indices to int64 internally.
"numba", "numba",
"jax", "jax",
......
...@@ -2,7 +2,7 @@ import typing ...@@ -2,7 +2,7 @@ import typing
from collections.abc import Sequence from collections.abc import Sequence
from pytensor.compile import optdb from pytensor.compile import optdb
from pytensor.graph.rewriting.basic import NodeRewriter, in2out from pytensor.graph.rewriting.basic import NodeRewriter, dfs_rewriter
from pytensor.graph.rewriting.db import EquilibriumDB, RewriteDatabase from pytensor.graph.rewriting.db import EquilibriumDB, RewriteDatabase
from pytensor.tensor.rewriting.ofg import inline_ofg_expansion from pytensor.tensor.rewriting.ofg import inline_ofg_expansion
from pytensor.tensor.variable import TensorVariable from pytensor.tensor.variable import TensorVariable
...@@ -23,7 +23,7 @@ optdb.register( ...@@ -23,7 +23,7 @@ optdb.register(
# Register OFG inline again after lowering xtensor # Register OFG inline again after lowering xtensor
optdb.register( optdb.register(
"inline_ofg_expansion_xtensor", "inline_ofg_expansion_xtensor",
in2out(inline_ofg_expansion), dfs_rewriter(inline_ofg_expansion),
"fast_run", "fast_run",
"fast_compile", "fast_compile",
position=0.11, position=0.11,
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论