提交 1d5b1d94 authored 作者: ricardoV94's avatar ricardoV94 提交者: Ricardo Vieira

Replace uses of in2out and out2in by a depth-first search rewriter

上级 aa7e4d6b
......@@ -27,7 +27,12 @@ from pytensor.graph.features import AlreadyThere, Feature
from pytensor.graph.fg import FunctionGraph, Output
from pytensor.graph.op import Op
from pytensor.graph.rewriting.unify import OpPattern, Var, convert_strs_to_vars
from pytensor.graph.traversal import applys_between, toposort, vars_between
from pytensor.graph.traversal import (
apply_ancestors,
applys_between,
toposort,
vars_between,
)
from pytensor.graph.utils import AssocList, InconsistencyError
from pytensor.misc.ordered_set import OrderedSet
from pytensor.utils import flatten
......@@ -1995,12 +2000,13 @@ class WalkingGraphRewriter(NodeProcessingGraphRewriter):
def __init__(
self,
node_rewriter: NodeRewriter,
order: Literal["out_to_in", "in_to_out"] = "in_to_out",
order: Literal["out_to_in", "in_to_out", "dfs"] = "in_to_out",
ignore_newtrees: bool = False,
failure_callback: FailureCallbackType | None = None,
):
if order not in ("out_to_in", "in_to_out"):
raise ValueError("order must be 'out_to_in' or 'in_to_out'")
valid_orders = ("out_to_in", "in_to_out", "dfs")
if order not in valid_orders:
raise ValueError(f"order must be one of {valid_orders}, got {order}")
self.order = order
super().__init__(node_rewriter, ignore_newtrees, failure_callback)
......@@ -2010,7 +2016,11 @@ class WalkingGraphRewriter(NodeProcessingGraphRewriter):
callback_before = fgraph.execute_callbacks_time
nb_nodes_start = len(fgraph.apply_nodes)
t0 = time.perf_counter()
q = deque(toposort(start_from))
q = deque(
apply_ancestors(start_from)
if (self.order == "dfs")
else toposort(start_from)
)
io_t = time.perf_counter() - t0
def importer(node):
......@@ -2134,6 +2144,7 @@ def walking_rewriter(
in2out = partial(walking_rewriter, "in_to_out")
out2in = partial(walking_rewriter, "out_to_in")
dfs_rewriter = partial(walking_rewriter, "dfs")
class ChangeTracker(Feature):
......
......@@ -29,7 +29,7 @@ from pytensor.graph.rewriting.basic import (
EquilibriumGraphRewriter,
GraphRewriter,
copy_stack_trace,
in2out,
dfs_rewriter,
node_rewriter,
)
from pytensor.graph.rewriting.db import EquilibriumDB, SequenceDB
......@@ -2558,7 +2558,7 @@ optdb.register("scan_eqopt2", scan_eqopt2, "fast_run", "scan", position=1.6)
# ScanSaveMem should execute only once per node.
optdb.register(
"scan_save_mem_prealloc",
in2out(scan_save_mem_prealloc, ignore_newtrees=True),
dfs_rewriter(scan_save_mem_prealloc, ignore_newtrees=True),
"fast_run",
"scan",
"scan_save_mem",
......@@ -2566,7 +2566,7 @@ optdb.register(
)
optdb.register(
"scan_save_mem_no_prealloc",
in2out(scan_save_mem_no_prealloc, ignore_newtrees=True),
dfs_rewriter(scan_save_mem_no_prealloc, ignore_newtrees=True),
"numba",
"jax",
"pytorch",
......@@ -2587,7 +2587,7 @@ scan_eqopt1.register("all_pushout_opt", scan_seqopt1, "fast_run", "scan")
scan_seqopt1.register(
"scan_remove_constants_and_unused_inputs0",
in2out(remove_constants_and_unused_inputs_scan, ignore_newtrees=True),
dfs_rewriter(remove_constants_and_unused_inputs_scan, ignore_newtrees=True),
"remove_constants_and_unused_inputs_scan",
"fast_run",
"scan",
......@@ -2596,7 +2596,7 @@ scan_seqopt1.register(
scan_seqopt1.register(
"scan_push_out_non_seq",
in2out(scan_push_out_non_seq, ignore_newtrees=True),
dfs_rewriter(scan_push_out_non_seq, ignore_newtrees=True),
"scan_pushout_nonseqs_ops", # For backcompat: so it can be tagged with old name
"fast_run",
"scan",
......@@ -2606,7 +2606,7 @@ scan_seqopt1.register(
scan_seqopt1.register(
"scan_push_out_seq",
in2out(scan_push_out_seq, ignore_newtrees=True),
dfs_rewriter(scan_push_out_seq, ignore_newtrees=True),
"scan_pushout_seqs_ops", # For backcompat: so it can be tagged with old name
"fast_run",
"scan",
......@@ -2617,7 +2617,7 @@ scan_seqopt1.register(
scan_seqopt1.register(
"scan_push_out_dot1",
in2out(scan_push_out_dot1, ignore_newtrees=True),
dfs_rewriter(scan_push_out_dot1, ignore_newtrees=True),
"scan_pushout_dot1", # For backcompat: so it can be tagged with old name
"fast_run",
"more_mem",
......@@ -2630,7 +2630,7 @@ scan_seqopt1.register(
scan_seqopt1.register(
"scan_push_out_add",
# TODO: Perhaps this should be an `EquilibriumGraphRewriter`?
in2out(scan_push_out_add, ignore_newtrees=False),
dfs_rewriter(scan_push_out_add, ignore_newtrees=False),
"scan_pushout_add", # For backcompat: so it can be tagged with old name
"fast_run",
"more_mem",
......@@ -2641,14 +2641,14 @@ scan_seqopt1.register(
scan_eqopt2.register(
"while_scan_merge_subtensor_last_element",
in2out(while_scan_merge_subtensor_last_element, ignore_newtrees=True),
dfs_rewriter(while_scan_merge_subtensor_last_element, ignore_newtrees=True),
"fast_run",
"scan",
)
scan_eqopt2.register(
"constant_folding_for_scan2",
in2out(constant_folding, ignore_newtrees=True),
dfs_rewriter(constant_folding, ignore_newtrees=True),
"fast_run",
"scan",
)
......@@ -2656,7 +2656,7 @@ scan_eqopt2.register(
scan_eqopt2.register(
"scan_remove_constants_and_unused_inputs1",
in2out(remove_constants_and_unused_inputs_scan, ignore_newtrees=True),
dfs_rewriter(remove_constants_and_unused_inputs_scan, ignore_newtrees=True),
"remove_constants_and_unused_inputs_scan",
"fast_run",
"scan",
......@@ -2671,7 +2671,7 @@ scan_eqopt2.register("scan_merge", ScanMerge(), "fast_run", "scan")
# After Merge optimization
scan_eqopt2.register(
"scan_remove_constants_and_unused_inputs2",
in2out(remove_constants_and_unused_inputs_scan, ignore_newtrees=True),
dfs_rewriter(remove_constants_and_unused_inputs_scan, ignore_newtrees=True),
"remove_constants_and_unused_inputs_scan",
"fast_run",
"scan",
......@@ -2679,7 +2679,7 @@ scan_eqopt2.register(
scan_eqopt2.register(
"scan_merge_inouts",
in2out(scan_merge_inouts, ignore_newtrees=True),
dfs_rewriter(scan_merge_inouts, ignore_newtrees=True),
"fast_run",
"scan",
)
......@@ -2687,7 +2687,7 @@ scan_eqopt2.register(
# After everything else
scan_eqopt2.register(
"scan_remove_constants_and_unused_inputs3",
in2out(remove_constants_and_unused_inputs_scan, ignore_newtrees=True),
dfs_rewriter(remove_constants_and_unused_inputs_scan, ignore_newtrees=True),
"remove_constants_and_unused_inputs_scan",
"fast_run",
"scan",
......
......@@ -3,7 +3,7 @@ from copy import copy
from pytensor.compile import optdb
from pytensor.graph import Constant, graph_inputs
from pytensor.graph.rewriting.basic import copy_stack_trace, in2out, node_rewriter
from pytensor.graph.rewriting.basic import copy_stack_trace, dfs_rewriter, node_rewriter
from pytensor.scan.op import Scan
from pytensor.scan.rewriting import scan_seqopt1
from pytensor.tensor._linalg.solve.tridiagonal import (
......@@ -244,7 +244,7 @@ def scan_split_non_sequence_decomposition_and_solve(fgraph, node):
scan_seqopt1.register(
scan_split_non_sequence_decomposition_and_solve.__name__,
in2out(scan_split_non_sequence_decomposition_and_solve, ignore_newtrees=True),
dfs_rewriter(scan_split_non_sequence_decomposition_and_solve, ignore_newtrees=True),
"fast_run",
"scan",
"scan_pushout",
......@@ -261,7 +261,7 @@ def reuse_decomposition_multiple_solves_jax(fgraph, node):
optdb["specialize"].register(
reuse_decomposition_multiple_solves_jax.__name__,
in2out(reuse_decomposition_multiple_solves_jax, ignore_newtrees=True),
dfs_rewriter(reuse_decomposition_multiple_solves_jax, ignore_newtrees=True),
"jax",
use_db_name_as_tag=False,
)
......@@ -276,7 +276,9 @@ def scan_split_non_sequence_decomposition_and_solve_jax(fgraph, node):
scan_seqopt1.register(
scan_split_non_sequence_decomposition_and_solve_jax.__name__,
in2out(scan_split_non_sequence_decomposition_and_solve_jax, ignore_newtrees=True),
dfs_rewriter(
scan_split_non_sequence_decomposition_and_solve_jax, ignore_newtrees=True
),
"jax",
use_db_name_as_tag=False,
position=2,
......
......@@ -4,7 +4,11 @@ from pytensor.compile import optdb
from pytensor.configdefaults import config
from pytensor.graph import ancestors
from pytensor.graph.op import compute_test_value
from pytensor.graph.rewriting.basic import copy_stack_trace, in2out, node_rewriter
from pytensor.graph.rewriting.basic import (
copy_stack_trace,
dfs_rewriter,
node_rewriter,
)
from pytensor.tensor import NoneConst, TensorVariable
from pytensor.tensor.basic import constant
from pytensor.tensor.elemwise import DimShuffle
......@@ -57,7 +61,7 @@ def random_make_inplace(fgraph, node):
optdb.register(
"random_make_inplace",
in2out(random_make_inplace, ignore_newtrees=True),
dfs_rewriter(random_make_inplace, ignore_newtrees=True),
"fast_run",
"inplace",
position=50.9,
......
......@@ -2,8 +2,7 @@ import re
from pytensor.compile import optdb
from pytensor.graph import Constant
from pytensor.graph.rewriting.basic import in2out, node_rewriter
from pytensor.graph.rewriting.db import SequenceDB
from pytensor.graph.rewriting.basic import dfs_rewriter, in2out, node_rewriter
from pytensor.tensor import abs as abs_t
from pytensor.tensor import broadcast_arrays, exp, floor, log, log1p, reciprocal, sqrt
from pytensor.tensor.basic import (
......@@ -179,51 +178,16 @@ def materialize_implicit_arange_choice_without_replacement(fgraph, node):
return new_op.make_node(rng, size, a_vector_param, *other_params).outputs
random_vars_opt = SequenceDB()
random_vars_opt.register(
"lognormal_from_normal",
in2out(lognormal_from_normal),
"jax",
)
random_vars_opt.register(
"halfnormal_from_normal",
in2out(halfnormal_from_normal),
"jax",
)
random_vars_opt.register(
"geometric_from_uniform",
in2out(geometric_from_uniform),
"jax",
)
random_vars_opt.register(
"negative_binomial_from_gamma_poisson",
in2out(negative_binomial_from_gamma_poisson),
"jax",
)
random_vars_opt.register(
"inverse_gamma_from_gamma",
in2out(inverse_gamma_from_gamma),
"jax",
)
random_vars_opt.register(
"generalized_gamma_from_gamma",
in2out(generalized_gamma_from_gamma),
"jax",
)
random_vars_opt.register(
"wald_from_normal_uniform",
in2out(wald_from_normal_uniform),
"jax",
)
random_vars_opt.register(
"beta_binomial_from_beta_binomial",
in2out(beta_binomial_from_beta_binomial),
"jax",
)
random_vars_opt.register(
"materialize_implicit_arange_choice_without_replacement",
in2out(materialize_implicit_arange_choice_without_replacement),
"jax",
random_vars_opt = dfs_rewriter(
lognormal_from_normal,
halfnormal_from_normal,
geometric_from_uniform,
negative_binomial_from_gamma_poisson,
inverse_gamma_from_gamma,
generalized_gamma_from_gamma,
wald_from_normal_uniform,
beta_binomial_from_beta_binomial,
materialize_implicit_arange_choice_without_replacement,
)
optdb.register("jax_random_vars_rewrites", random_vars_opt, "jax", position=110)
......
from pytensor.compile import optdb
from pytensor.graph import node_rewriter
from pytensor.graph.rewriting.basic import out2in
from pytensor.graph.rewriting.basic import dfs_rewriter
from pytensor.tensor import as_tensor, constant
from pytensor.tensor.random.op import RandomVariable, RandomVariableWithCoreShape
from pytensor.tensor.rewriting.shape import ShapeFeature
......@@ -82,7 +82,7 @@ def introduce_explicit_core_shape_rv(fgraph, node):
optdb.register(
introduce_explicit_core_shape_rv.__name__,
out2in(introduce_explicit_core_shape_rv),
dfs_rewriter(introduce_explicit_core_shape_rv),
"numba",
position=100,
)
......@@ -35,6 +35,7 @@ from pytensor.graph.rewriting.basic import (
NodeRewriter,
Rewriter,
copy_stack_trace,
dfs_rewriter,
in2out,
node_rewriter,
)
......@@ -538,7 +539,7 @@ def local_alloc_empty_to_zeros(fgraph, node):
compile.optdb.register(
"local_alloc_empty_to_zeros",
in2out(local_alloc_empty_to_zeros),
dfs_rewriter(local_alloc_empty_to_zeros),
# After move to gpu and merge2, before inplace.
"alloc_empty_to_zeros",
position=49.3,
......
......@@ -77,7 +77,7 @@ from pytensor.graph.rewriting.basic import (
EquilibriumGraphRewriter,
GraphRewriter,
copy_stack_trace,
in2out,
dfs_rewriter,
node_rewriter,
)
from pytensor.graph.rewriting.db import SequenceDB
......@@ -721,7 +721,7 @@ optdb.register("BlasOpt", blas_optdb, "fast_run", "fast_compile", position=1.7)
# fast_compile is needed to have GpuDot22 created.
blas_optdb.register(
"local_dot_to_dot22",
in2out(local_dot_to_dot22),
dfs_rewriter(local_dot_to_dot22),
"fast_run",
"fast_compile",
position=0,
......@@ -744,7 +744,7 @@ blas_optdb.register(
)
blas_opt_inplace = in2out(
blas_opt_inplace = dfs_rewriter(
local_inplace_gemm, local_inplace_gemv, local_inplace_ger, name="blas_opt_inplace"
)
optdb.register(
......@@ -883,7 +883,7 @@ def local_dot22_to_dot22scalar(fgraph, node):
# dot22scalar and gemm give more speed up then dot22scalar
blas_optdb.register(
"local_dot22_to_dot22scalar",
in2out(local_dot22_to_dot22scalar),
dfs_rewriter(local_dot22_to_dot22scalar),
"fast_run",
position=12,
)
......
from pytensor.configdefaults import config
from pytensor.graph.rewriting.basic import in2out
from pytensor.graph.rewriting.basic import dfs_rewriter
from pytensor.tensor import basic as ptb
from pytensor.tensor.blas import gemv_inplace, gemv_no_inplace, ger, ger_destructive
from pytensor.tensor.blas_c import (
......@@ -56,13 +56,15 @@ def make_c_gemv_destructive(fgraph, node):
blas_optdb.register(
"use_c_blas", in2out(use_c_ger, use_c_gemv), "fast_run", "c_blas", position=20
"use_c_blas", dfs_rewriter(use_c_ger, use_c_gemv), "fast_run", "c_blas", position=20
)
# this matches the InplaceBlasOpt defined in blas.py
optdb.register(
"c_blas_destructive",
in2out(make_c_ger_destructive, make_c_gemv_destructive, name="c_blas_destructive"),
dfs_rewriter(
make_c_ger_destructive, make_c_gemv_destructive, name="c_blas_destructive"
),
"fast_run",
"inplace",
"c_blas",
......
......@@ -2,7 +2,7 @@ from pytensor.compile.mode import optdb
from pytensor.graph import Constant, Op, node_rewriter
from pytensor.graph.destroyhandler import inplace_candidates
from pytensor.graph.replace import vectorize_node
from pytensor.graph.rewriting.basic import copy_stack_trace, out2in
from pytensor.graph.rewriting.basic import copy_stack_trace, dfs_rewriter
from pytensor.graph.rewriting.unify import OpPattern, OpPatternOpTypeType
from pytensor.tensor.basic import Alloc, ARange, alloc, shape_padleft
from pytensor.tensor.blockwise import Blockwise, _squeeze_left
......@@ -66,7 +66,7 @@ def local_useless_unbatched_blockwise(fgraph, node):
# We do it after position>=60 so that Blockwise inplace rewrites will work also on useless Blockwise Ops
optdb.register(
"local_useless_unbatched_blockwise",
out2in(local_useless_unbatched_blockwise, ignore_newtrees=True),
dfs_rewriter(local_useless_unbatched_blockwise, ignore_newtrees=True),
"fast_run",
"fast_compile",
"blockwise",
......
......@@ -21,6 +21,7 @@ from pytensor.graph.op import Op
from pytensor.graph.rewriting.basic import (
GraphRewriter,
copy_stack_trace,
dfs_rewriter,
in2out,
node_rewriter,
out2in,
......@@ -1237,21 +1238,21 @@ fuse_seqopt.register(
)
fuse_seqopt.register(
"local_useless_composite_outputs",
in2out(local_useless_composite_outputs),
dfs_rewriter(local_useless_composite_outputs),
"fast_run",
"fusion",
position=2,
)
fuse_seqopt.register(
"local_careduce_fusion",
in2out(local_careduce_fusion),
dfs_rewriter(local_careduce_fusion),
"fast_run",
"fusion",
position=10,
)
fuse_seqopt.register(
"local_inline_composite_constants",
in2out(local_inline_composite_constants, ignore_newtrees=True),
dfs_rewriter(local_inline_composite_constants, ignore_newtrees=True),
"fast_run",
"fusion",
position=20,
......
import pytensor.tensor as pt
from pytensor.compile import optdb
from pytensor.graph.rewriting.basic import in2out, node_rewriter
from pytensor.graph.rewriting.basic import dfs_rewriter, node_rewriter
from pytensor.tensor.basic import MakeVector
from pytensor.tensor.elemwise import DimShuffle
from pytensor.tensor.math import Sum
......@@ -46,7 +46,7 @@ def boolean_indexing_set_or_inc(fgraph, node):
optdb.register(
"jax_boolean_indexing_set_or_inc",
in2out(boolean_indexing_set_or_inc),
dfs_rewriter(boolean_indexing_set_or_inc),
"jax",
position=100,
)
......@@ -96,7 +96,7 @@ def boolean_indexing_sum(fgraph, node):
optdb.register(
"jax_boolean_indexing_sum", in2out(boolean_indexing_sum), "jax", position=100
"jax_boolean_indexing_sum", dfs_rewriter(boolean_indexing_sum), "jax", position=100
)
......@@ -144,7 +144,7 @@ def shape_parameter_as_tuple(fgraph, node):
optdb.register(
"jax_shape_parameter_as_tuple",
in2out(shape_parameter_as_tuple),
dfs_rewriter(shape_parameter_as_tuple),
"jax",
position=100,
)
......@@ -10,7 +10,7 @@ from pytensor.compile import optdb
from pytensor.graph import Apply, FunctionGraph
from pytensor.graph.rewriting.basic import (
copy_stack_trace,
in2out,
dfs_rewriter,
node_rewriter,
)
from pytensor.graph.rewriting.unify import OpPattern
......@@ -905,7 +905,7 @@ def jax_bilinaer_lyapunov_to_direct(fgraph: FunctionGraph, node: Apply):
optdb.register(
"jax_bilinaer_lyapunov_to_direct",
in2out(jax_bilinaer_lyapunov_to_direct),
dfs_rewriter(jax_bilinaer_lyapunov_to_direct),
"jax",
position=0.9, # Run before canonicalization
)
......
from pytensor.compile import optdb
from pytensor.graph import node_rewriter
from pytensor.graph.rewriting.basic import out2in
from pytensor.graph.rewriting.basic import dfs_rewriter
from pytensor.graph.traversal import applys_between
from pytensor.tensor.basic import as_tensor, constant
from pytensor.tensor.blockwise import Blockwise, BlockwiseWithCoreShape
......@@ -102,7 +102,7 @@ def introduce_explicit_core_shape_blockwise(fgraph, node):
optdb.register(
introduce_explicit_core_shape_blockwise.__name__,
out2in(introduce_explicit_core_shape_blockwise),
dfs_rewriter(introduce_explicit_core_shape_blockwise),
"numba",
position=100,
)
......@@ -4,7 +4,7 @@ from pytensor import Variable, clone_replace
from pytensor.compile import optdb
from pytensor.compile.builders import OpFromGraph
from pytensor.graph import Apply, node_rewriter
from pytensor.graph.rewriting.basic import copy_stack_trace, in2out
from pytensor.graph.rewriting.basic import copy_stack_trace, dfs_rewriter
from pytensor.tensor.basic import AllocDiag
from pytensor.tensor.rewriting.basic import register_specialize
......@@ -37,7 +37,7 @@ def inline_ofg_expansion(fgraph, node):
# and before the first scan optimizer.
optdb.register(
"inline_ofg_expansion",
in2out(inline_ofg_expansion),
dfs_rewriter(inline_ofg_expansion),
"fast_compile",
"fast_run",
position=-0.01,
......
......@@ -10,9 +10,9 @@ from pytensor.graph.basic import Constant, Variable
from pytensor.graph.rewriting.basic import (
WalkingGraphRewriter,
copy_stack_trace,
dfs_rewriter,
in2out,
node_rewriter,
out2in,
)
from pytensor.raise_op import Assert
from pytensor.scalar import Add, ScalarConstant, ScalarType
......@@ -1560,7 +1560,7 @@ def local_uint_constant_indices(fgraph, node):
compile.optdb.register(
local_uint_constant_indices.__name__,
out2in(local_uint_constant_indices),
dfs_rewriter(local_uint_constant_indices),
# We don't include in the Python / C because those always cast indices to int64 internally.
"numba",
"jax",
......
......@@ -2,7 +2,7 @@ import typing
from collections.abc import Sequence
from pytensor.compile import optdb
from pytensor.graph.rewriting.basic import NodeRewriter, in2out
from pytensor.graph.rewriting.basic import NodeRewriter, dfs_rewriter
from pytensor.graph.rewriting.db import EquilibriumDB, RewriteDatabase
from pytensor.tensor.rewriting.ofg import inline_ofg_expansion
from pytensor.tensor.variable import TensorVariable
......@@ -23,7 +23,7 @@ optdb.register(
# Register OFG inline again after lowering xtensor
optdb.register(
"inline_ofg_expansion_xtensor",
in2out(inline_ofg_expansion),
dfs_rewriter(inline_ofg_expansion),
"fast_run",
"fast_compile",
position=0.11,
......
Markdown 格式
0%
您正在将 0 人添加到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论