提交 27d79707 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Use OpPattern in tracks

上级 19f1486b
......@@ -1228,6 +1228,8 @@ class ScalarOp(COp):
f"(got: {output_types_preference})"
)
self.output_types_preference = output_types_preference
elif not hasattr(self, "output_types_preference"):
self.output_types_preference = None
def make_node(self, *inputs):
if self.nin >= 0:
......@@ -1247,7 +1249,7 @@ class ScalarOp(COp):
return Apply(self, inputs, outputs)
def output_types(self, types):
if hasattr(self, "output_types_preference"):
if self.output_types_preference is not None:
variables = self.output_types_preference(*types)
if not isinstance(variables, list | tuple) or any(
not isinstance(x, CType) for x in variables
......@@ -2696,7 +2698,7 @@ class Sign(UnaryScalarOp):
nfunc_spec = ("sign", 1, 1)
@staticmethod
def output_types_preference(x):
def _output_types_preference(x):
if x == bool:
raise TypeError(x)
return same_out_nocomplex(x)
......@@ -2737,7 +2739,7 @@ class Sign(UnaryScalarOp):
return s
sign = Sign(name="sign")
sign = Sign(name="sign", output_types_preference=Sign._output_types_preference)
class Ceil(UnaryScalarOp):
......
......@@ -14,6 +14,7 @@ from pytensor.tensor.basic import atleast_Nd
from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.elemwise import DimShuffle
from pytensor.tensor.rewriting.basic import register_specialize
from pytensor.tensor.rewriting.blockwise import blockwise_of
from pytensor.tensor.rewriting.linalg import is_matrix_transpose
from pytensor.tensor.slinalg import Solve, cho_solve, cholesky, lu_factor, lu_solve
from pytensor.tensor.variable import TensorVariable
......@@ -227,7 +228,7 @@ def _scan_split_non_sequence_decomposition_and_solve(
@register_specialize
@node_rewriter([Blockwise])
@node_rewriter([blockwise_of(Solve)])
def reuse_decomposition_multiple_solves(fgraph, node):
return _split_decomp_and_solve_steps(
fgraph, node, eager=False, allowed_assume_a={"gen", "tridiagonal", "pos"}
......
......@@ -26,10 +26,9 @@ import logging
import numpy as np
import pytensor.scalar.basic as ps
from pytensor import compile, config
from pytensor.compile.ops import ViewOp
from pytensor.graph import FunctionGraph
from pytensor.graph import FunctionGraph, Op
from pytensor.graph.basic import Constant
from pytensor.graph.rewriting.basic import (
NodeProcessingGraphRewriter,
......@@ -40,9 +39,24 @@ from pytensor.graph.rewriting.basic import (
node_rewriter,
)
from pytensor.graph.rewriting.db import RewriteDatabase
from pytensor.graph.rewriting.unify import OpPattern, OpPatternOpTypeType
from pytensor.npy_2_compat import normalize_axis_index
from pytensor.raise_op import Assert, CheckAndRaise, assert_op
from pytensor.scalar.basic import Second
from pytensor.scalar import (
AND,
EQ,
LE,
NEQ,
OR,
XOR,
Add,
BinaryScalarOp,
Cast,
Identity,
Mul,
Second,
Switch,
)
from pytensor.tensor.basic import (
Alloc,
AllocEmpty,
......@@ -225,6 +239,12 @@ def register_uncanonicalize(
return node_rewriter
def elemwise_of(scalar_op: OpPatternOpTypeType | OpPattern) -> OpPattern:
    """Build an OpPattern that tracks Elemwise nodes by their scalar_op.

    ``scalar_op`` may be an Op instance, an OpPattern, or an Op type /
    union of types; the latter is wrapped in an OpPattern first.
    """
    if isinstance(scalar_op, Op | OpPattern):
        inner_pattern = scalar_op
    else:
        inner_pattern = OpPattern(scalar_op)
    return OpPattern(Elemwise, scalar_op=inner_pattern)
@register_canonicalize
@register_specialize
@node_rewriter([TensorFromScalar])
......@@ -551,7 +571,7 @@ def local_useless_elemwise(fgraph, node):
dtype = node.outputs[0].type.dtype
scalar_op = node.op.scalar_op
if isinstance(scalar_op, ps.EQ) and len(node.inputs) == 2:
if isinstance(scalar_op, EQ) and len(node.inputs) == 2:
if node.inputs[0] is node.inputs[1]:
# it is the same var in the graph. That will always be true
ret = ones_like(node.inputs[0], dtype=dtype, opt=True)
......@@ -559,7 +579,7 @@ def local_useless_elemwise(fgraph, node):
# Copy stack trace from input to constant output
copy_stack_trace(node.outputs[0], ret)
return [ret]
elif isinstance(scalar_op, ps.NEQ | ps.XOR) and len(node.inputs) == 2:
elif isinstance(scalar_op, NEQ | XOR) and len(node.inputs) == 2:
if node.inputs[0] is node.inputs[1]:
# it is the same var in the graph. That will always be false
ret = zeros_like(node.inputs[0], dtype=dtype, opt=True)
......@@ -568,14 +588,11 @@ def local_useless_elemwise(fgraph, node):
copy_stack_trace(node.outputs[0], ret)
return [ret]
elif (
isinstance(node.op.scalar_op, ps.Mul | ps.Add | ps.Identity)
and len(node.inputs) == 1
):
elif isinstance(node.op.scalar_op, Mul | Add | Identity) and len(node.inputs) == 1:
# No need to copy over any stack trace
return [node.inputs[0]]
elif isinstance(node.op.scalar_op, ps.AND) and len(node.inputs) == 2:
elif isinstance(node.op.scalar_op, AND) and len(node.inputs) == 2:
if (
isinstance(node.inputs[0], TensorConstant)
and node.inputs[1].type.broadcastable == out_bcast
......@@ -602,7 +619,7 @@ def local_useless_elemwise(fgraph, node):
# and this rewrite would be wrong
return [node.inputs[0].astype(node.outputs[0].dtype)]
elif isinstance(node.op.scalar_op, ps.OR) and len(node.inputs) == 2:
elif isinstance(node.op.scalar_op, OR) and len(node.inputs) == 2:
if (
isinstance(node.inputs[0], TensorConstant)
and node.inputs[1].type.broadcastable == out_bcast
......@@ -653,7 +670,7 @@ def local_alloc_unary(fgraph, node):
@register_canonicalize
@register_specialize
@node_rewriter([Elemwise])
@node_rewriter([elemwise_of(Cast)])
def local_cast_cast(fgraph, node):
"""cast(cast(x, dtype1), dtype2)
......@@ -663,13 +680,11 @@ def local_cast_cast(fgraph, node):
and the first cast cause an upcast.
"""
if not (isinstance(node.op, Elemwise) and isinstance(node.op.scalar_op, ps.Cast)):
return
x = node.inputs[0]
if not (
x.owner
and isinstance(x.owner.op, Elemwise)
and isinstance(x.owner.op.scalar_op, ps.Cast)
and isinstance(x.owner.op.scalar_op, Cast)
):
return
......@@ -1009,7 +1024,7 @@ def local_useless_switch(fgraph, node):
node.outputs[0].type.ndim == 0
and cond_var.owner
and isinstance(cond_var.owner.op, Elemwise)
and isinstance(cond_var.owner.op.scalar_op, ps.LE)
and isinstance(cond_var.owner.op.scalar_op, LE)
and cond_var.owner.inputs[0].owner
and isinstance(cond_var.owner.inputs[0].owner.op, Shape_i)
and get_scalar_constant_value(
......@@ -1031,24 +1046,18 @@ def local_useless_switch(fgraph, node):
@register_canonicalize
@node_rewriter([Elemwise])
@node_rewriter([elemwise_of(BinaryScalarOp | Add | Mul)])
def local_merge_switch_same_cond(fgraph, node):
"""
Merge add/sub/mul/div/minimum/maximum/... of switches sharing the same
condition, to enable further simplification of their branches
Example: switch(c, a, b) + switch(c, x, y) -> switch(c, a+x, b+y)
"""
# node must be binary elemwise or add or mul
if not (
isinstance(node.op, Elemwise)
and isinstance(node.op.scalar_op, ps.BinaryScalarOp | ps.Add | ps.Mul)
):
return
# all inputs must be switch
if not all(
s.owner
and isinstance(s.owner.op, Elemwise)
and isinstance(s.owner.op.scalar_op, ps.Switch)
and isinstance(s.owner.op.scalar_op, Switch)
for s in node.inputs
):
return
......@@ -1174,10 +1183,9 @@ register_specialize(topo_constant_folding, "fast_compile", final_rewriter=True)
@register_infer_shape
@register_canonicalize("fast_compile")
@register_useless("fast_compile")
@node_rewriter(None)
@node_rewriter([ViewOp])
def local_view_op(fgraph, node):
if isinstance(node.op, ViewOp):
return node.inputs
return node.inputs
@register_infer_shape
......
from pytensor.compile.mode import optdb
from pytensor.graph import Constant, node_rewriter
from pytensor.graph import Constant, Op, node_rewriter
from pytensor.graph.destroyhandler import inplace_candidates
from pytensor.graph.replace import vectorize_node
from pytensor.graph.rewriting.basic import copy_stack_trace, out2in
from pytensor.graph.rewriting.unify import OpPattern, OpPatternOpTypeType
from pytensor.tensor.basic import Alloc, ARange, alloc, shape_padleft
from pytensor.tensor.blockwise import Blockwise, _squeeze_left
from pytensor.tensor.math import Dot
......@@ -20,6 +21,12 @@ from pytensor.tensor.subtensor import (
)
def blockwise_of(core_op: OpPatternOpTypeType | OpPattern) -> OpPattern:
    """Build an OpPattern that tracks Blockwise nodes by their core_op.

    ``core_op`` may be an Op instance, an OpPattern, or an Op type /
    union of types; the latter is wrapped in an OpPattern first.
    """
    if isinstance(core_op, Op | OpPattern):
        inner_pattern = core_op
    else:
        inner_pattern = OpPattern(core_op)
    return OpPattern(Blockwise, core_op=inner_pattern)
@node_rewriter([Blockwise])
def local_useless_blockwise(fgraph, node):
"""
......@@ -71,22 +78,24 @@ optdb.register(
@register_canonicalize
@register_stabilize
@register_specialize
@node_rewriter(tracks=[Blockwise])
@node_rewriter(
tracks=[
blockwise_of(
Dot
| Alloc
| ARange
| Subtensor
| AdvancedSubtensor
| AdvancedIncSubtensor
| Reshape
)
]
)
def local_eager_useless_unbatched_blockwise(fgraph, node):
if isinstance(
node.op.core_op,
Dot
| Alloc
| ARange
| Subtensor
| AdvancedSubtensor
| AdvancedIncSubtensor
| Reshape,
):
# Many Dot-related rewrites (eg, all of BlasOpt) happen before specialize
# These other Ops can't always be trivially vectorized at runtime,
# since their inputs may imply non-rectangular shapes.
return local_useless_unbatched_blockwise.fn(fgraph, node)
# Many Dot-related rewrites (eg, all of BlasOpt) happen before specialize
# These other Ops can't always be trivially vectorized at runtime,
# since their inputs may imply non-rectangular shapes.
return local_useless_unbatched_blockwise.fn(fgraph, node)
@register_specialize("shape_unsafe")
......@@ -204,7 +213,7 @@ def local_blockwise_alloc(fgraph, node):
@register_specialize
@node_rewriter([Blockwise])
@node_rewriter([blockwise_of(Reshape)])
def local_blockwise_reshape(fgraph, node):
"""Rewrite away square Blockwise reshapes.
......@@ -215,9 +224,6 @@ def local_blockwise_reshape(fgraph, node):
For the square Reshape case, we must wait for all the intermediate
operations to be lifted as Allocs
"""
if not isinstance(node.op.core_op, Reshape):
return None
x, output_shape = node.inputs
batch_ndim = node.op.batch_ndim(node)
if all(output_shape.type.broadcastable[:batch_ndim]):
......
......@@ -26,6 +26,7 @@ from pytensor.graph.rewriting.basic import (
out2in,
)
from pytensor.graph.rewriting.db import SequenceDB
from pytensor.graph.rewriting.unify import OpPattern
from pytensor.graph.utils import InconsistencyError, MethodNotDefined
from pytensor.scalar.math import Grad2F1Loop, _grad_2f1_loop
from pytensor.tensor.basic import (
......@@ -37,6 +38,7 @@ from pytensor.tensor.math import add, exp, mul
from pytensor.tensor.rewriting.basic import (
alloc_like,
broadcasted_by,
elemwise_of,
register_canonicalize,
register_specialize,
register_stabilize,
......@@ -422,7 +424,14 @@ def local_useless_dimshuffle_makevector(fgraph, node):
@register_canonicalize
@node_rewriter([Elemwise])
@node_rewriter(
[
elemwise_of(
OpPattern(ps.ScalarOp, output_types_preference=ps.upgrade_to_float)
),
elemwise_of(OpPattern(ps.ScalarOp, output_types_preference=ps.upcast_out)),
]
)
def local_upcast_elemwise_constant_inputs(fgraph, node):
"""This explicitly upcasts constant inputs to elemwise Ops, when
those Ops do implicit upcasting anyway.
......@@ -433,12 +442,6 @@ def local_upcast_elemwise_constant_inputs(fgraph, node):
if len(node.outputs) > 1:
return None
if getattr(node.op.scalar_op, "output_types_preference", None) not in (
ps.upgrade_to_float,
ps.upcast_out,
):
return None
# this is the kind of op that we can screw with the input
# dtypes by upcasting explicitly
[old_out] = node.outputs
......@@ -988,13 +991,9 @@ class FusionOptimizer(GraphRewriter):
@register_canonicalize
@register_specialize
@node_rewriter([Elemwise])
@node_rewriter([elemwise_of(ps.Composite)])
def local_useless_composite_outputs(fgraph, node):
"""Remove inputs and outputs of Composite Ops that are not used anywhere."""
if not (
isinstance(node.op, Elemwise) and isinstance(node.op.scalar_op, ps.Composite)
):
return
comp = node.op.scalar_op
used_outputs_idxs = [
i for i, o_extern in enumerate(node.outputs) if fgraph.clients[o_extern]
......@@ -1104,14 +1103,10 @@ def local_careduce_fusion(fgraph, node):
return [new_car_op(*elm_inputs)]
@node_rewriter([Elemwise])
@node_rewriter([elemwise_of(ps.Composite)])
def local_inline_composite_constants(fgraph, node):
"""Inline scalar constants in Composite graphs."""
composite_op = node.op.scalar_op
if not isinstance(composite_op, ps.Composite):
return None
new_outer_inputs = []
new_inner_inputs = []
inner_replacements = {}
......@@ -1287,14 +1282,9 @@ def _rebuild_partial_2f1grad_loop(node, wrt):
@register_specialize
@node_rewriter([Elemwise])
@node_rewriter([elemwise_of(Grad2F1Loop)])
def local_useless_2f1grad_loop(fgraph, node):
# Remove unused terms from the hyp2f1 grad loop
loop_op = node.op.scalar_op
if not isinstance(loop_op, Grad2F1Loop):
return
grad_related_vars = node.outputs[:-4]
# Rewrite was already applied
if len(grad_related_vars) // 3 != 3:
......@@ -1326,18 +1316,13 @@ def local_useless_2f1grad_loop(fgraph, node):
return replacements
@node_rewriter([Elemwise])
@node_rewriter([elemwise_of(Grad2F1Loop)])
def split_2f1grad_loop(fgraph, node):
"""
2f1grad loop has too many operands for Numpy frompyfunc code used by Elemwise nodes on python mode.
This rewrite splits it across 3 different operations. It is not needed if `local_useless_2f1grad_loop` was applied
"""
loop_op = node.op.scalar_op
if not isinstance(loop_op, Grad2F1Loop):
return None
grad_related_vars = node.outputs[:-4]
# local_useless_2f1grad_loop was used, we should be safe
if len(grad_related_vars) // 3 != 3:
......
......@@ -13,6 +13,7 @@ from pytensor.graph.rewriting.basic import (
in2out,
node_rewriter,
)
from pytensor.graph.rewriting.unify import OpPattern
from pytensor.scalar.basic import Abs, Log, Mul, Sign
from pytensor.tensor.basic import (
AllocDiag,
......@@ -43,6 +44,7 @@ from pytensor.tensor.rewriting.basic import (
register_specialize,
register_stabilize,
)
from pytensor.tensor.rewriting.blockwise import blockwise_of
from pytensor.tensor.slinalg import (
BlockDiagonal,
Cholesky,
......@@ -60,7 +62,7 @@ from pytensor.tensor.slinalg import (
logger = logging.getLogger(__name__)
ALL_INVERSE_OPS = (MatrixInverse, MatrixPinv)
MATRIX_INVERSE_OPS = (MatrixInverse, MatrixPinv)
def is_matrix_transpose(x: TensorVariable) -> bool:
......@@ -129,69 +131,48 @@ def inv_as_solve(fgraph, node):
@register_stabilize
@register_canonicalize
@node_rewriter([Blockwise])
@node_rewriter([blockwise_of(OpPattern(Solve, assume_a="gen"))])
def generic_solve_to_solve_triangular(fgraph, node):
"""
If any solve() is applied to the output of a cholesky op, then
replace it with a triangular solve.
"""
if isinstance(node.op.core_op, Solve):
if node.op.core_op.assume_a == "gen":
A, b = node.inputs # result is solution Ax=b
if (
A.owner
and isinstance(A.owner.op, Blockwise)
and isinstance(A.owner.op.core_op, Cholesky)
):
if A.owner.op.core_op.lower:
return [
solve_triangular(
A, b, lower=True, b_ndim=node.op.core_op.b_ndim
)
]
else:
return [
solve_triangular(
A, b, lower=False, b_ndim=node.op.core_op.b_ndim
)
]
if is_matrix_transpose(A):
(A_T,) = A.owner.inputs
if (
A_T.owner
and isinstance(A_T.owner.op, Blockwise)
and isinstance(A_T.owner.op, Cholesky)
):
if A_T.owner.op.lower:
return [
solve_triangular(
A, b, lower=False, b_ndim=node.op.core_op.b_ndim
)
]
else:
return [
solve_triangular(
A, b, lower=True, b_ndim=node.op.core_op.b_ndim
)
]
A, b = node.inputs # result is the solution to Ax=b
if (
A.owner
and isinstance(A.owner.op, Blockwise)
and isinstance(A.owner.op.core_op, Cholesky)
):
if A.owner.op.core_op.lower:
return [solve_triangular(A, b, lower=True, b_ndim=node.op.core_op.b_ndim)]
else:
return [solve_triangular(A, b, lower=False, b_ndim=node.op.core_op.b_ndim)]
if is_matrix_transpose(A):
(A_T,) = A.owner.inputs
if (
A_T.owner
and isinstance(A_T.owner.op, Blockwise)
and isinstance(A_T.owner.op, Cholesky)
):
if A_T.owner.op.lower:
return [
solve_triangular(A, b, lower=False, b_ndim=node.op.core_op.b_ndim)
]
else:
return [
solve_triangular(A, b, lower=True, b_ndim=node.op.core_op.b_ndim)
]
@register_specialize
@node_rewriter([Blockwise])
@node_rewriter([blockwise_of(OpPattern(SolveBase, b_ndim=1))])
def batched_vector_b_solve_to_matrix_b_solve(fgraph, node):
"""Replace a batched Solve(a, b, b_ndim=1) by Solve(a, b.T, b_ndim=2).T
`a` must have no batched dimensions, while `b` can have arbitrary batched dimensions.
"""
core_op = node.op.core_op
if not isinstance(core_op, SolveBase):
return None
if node.op.core_op.b_ndim != 1:
return None
[a, b] = node.inputs
# Check `b` is actually batched
......@@ -242,26 +223,24 @@ def no_transpose_symmetric(fgraph, node):
@register_stabilize
@node_rewriter([Blockwise])
@node_rewriter([blockwise_of(OpPattern(Solve, b_ndim=2))])
def psd_solve_with_chol(fgraph, node):
"""
This utilizes a boolean `psd` tag on matrices.
"""
if isinstance(node.op.core_op, Solve) and node.op.core_op.b_ndim == 2:
A, b = node.inputs # result is solution Ax=b
if getattr(A.tag, "psd", None) is True:
L = cholesky(A)
# N.B. this can be further reduced to a yet-unwritten cho_solve Op
# __if__ no other Op makes use of the L matrix during the
# stabilization
Li_b = solve_triangular(L, b, lower=True, b_ndim=2)
x = solve_triangular((L.mT), Li_b, lower=False, b_ndim=2)
return [x]
A, b = node.inputs # result is the solution to Ax=b
if getattr(A.tag, "psd", None) is True:
L = cholesky(A)
# N.B. this can be further reduced to cho_solve Op
# if no other Op makes use of the L matrix
Li_b = solve_triangular(L, b, lower=True, b_ndim=2)
x = solve_triangular((L.mT), Li_b, lower=False, b_ndim=2)
return [x]
@register_canonicalize
@register_stabilize
@node_rewriter([Blockwise])
@node_rewriter([blockwise_of(Cholesky)])
def cholesky_ldotlt(fgraph, node):
"""
rewrite cholesky(dot(L, L.T), lower=True) = L, where L is lower triangular,
......@@ -271,9 +250,6 @@ def cholesky_ldotlt(fgraph, node):
This utilizes a boolean `lower_triangular` or `upper_triangular` tag on matrices.
"""
if not isinstance(node.op.core_op, Cholesky):
return
A = node.inputs[0]
if not (
A.owner is not None and (isinstance(A.owner.op, Dot) or (A.owner.op == _matmul))
......@@ -342,7 +318,7 @@ def local_log_prod_sqr(fgraph, node):
@register_specialize
@node_rewriter([Blockwise])
@node_rewriter([blockwise_of(MatrixInverse | Cholesky | MatrixPinv)])
def local_lift_through_linalg(
fgraph: FunctionGraph, node: Apply
) -> list[Variable] | None:
......@@ -370,9 +346,6 @@ def local_lift_through_linalg(
"""
# TODO: Simplify this if we end up Blockwising KroneckerProduct
if not isinstance(node.op.core_op, MatrixInverse | Cholesky | MatrixPinv):
return None
y = node.inputs[0]
outer_op = node.op
......@@ -534,15 +507,12 @@ def rewrite_det_diag_to_prod_diag(fgraph, node):
@register_canonicalize
@register_stabilize
@register_specialize
@node_rewriter([Blockwise])
@node_rewriter([blockwise_of(SVD)])
def svd_uv_merge(fgraph, node):
"""If we have more than one `SVD` `Op`s and at least one has keyword argument
`compute_uv=True`, then we can change `compute_uv = False` to `True` everywhere
and allow `pytensor` to re-use the decomposition outputs instead of recomputing.
"""
if not isinstance(node.op.core_op, SVD):
return
(x,) = node.inputs
if node.op.core_op.compute_uv:
......@@ -585,7 +555,7 @@ def svd_uv_merge(fgraph, node):
@register_canonicalize
@register_stabilize
@node_rewriter([Blockwise])
@node_rewriter([blockwise_of(MATRIX_INVERSE_OPS)])
def rewrite_inv_inv(fgraph, node):
"""
This rewrite takes advantage of the fact that if there are two consecutive inverse operations (inv(inv(input))), we get back our original input without having to compute inverse once.
......@@ -607,9 +577,6 @@ def rewrite_inv_inv(fgraph, node):
# Check if its a valid inverse operation (either inv/pinv)
# In case the outer operation is an inverse, it directly goes to the next step of finding inner operation
# If the outer operation is not a valid inverse, we do not apply this rewrite
if not isinstance(node.op.core_op, ALL_INVERSE_OPS):
return None
potential_inner_inv = node.inputs[0].owner
if potential_inner_inv is None or potential_inner_inv.op is None:
return None
......@@ -618,7 +585,7 @@ def rewrite_inv_inv(fgraph, node):
if not (
potential_inner_inv
and isinstance(potential_inner_inv.op, Blockwise)
and isinstance(potential_inner_inv.op.core_op, ALL_INVERSE_OPS)
and isinstance(potential_inner_inv.op.core_op, MATRIX_INVERSE_OPS)
):
return None
return [potential_inner_inv.inputs[0]]
......@@ -626,7 +593,7 @@ def rewrite_inv_inv(fgraph, node):
@register_canonicalize
@register_stabilize
@node_rewriter([Blockwise])
@node_rewriter([blockwise_of(MATRIX_INVERSE_OPS)])
def rewrite_inv_eye_to_eye(fgraph, node):
"""
This rewrite takes advantage of the fact that the inverse of an identity matrix is the matrix itself
......@@ -642,10 +609,6 @@ def rewrite_inv_eye_to_eye(fgraph, node):
list of Variable, optional
List of optimized variables, or None if no optimization was performed
"""
core_op = node.op.core_op
if not (isinstance(core_op, ALL_INVERSE_OPS)):
return None
# Check whether input to inverse is Eye and the 1's are on main diagonal
potential_eye = node.inputs[0]
if not (
......@@ -659,7 +622,7 @@ def rewrite_inv_eye_to_eye(fgraph, node):
@register_canonicalize
@register_stabilize
@node_rewriter([Blockwise])
@node_rewriter([blockwise_of(MATRIX_INVERSE_OPS)])
def rewrite_inv_diag_to_diag_reciprocal(fgraph, node):
"""
This rewrite takes advantage of the fact that for a diagonal matrix, the inverse is a diagonal matrix with the new diagonal entries as reciprocals of the original diagonal elements.
......@@ -677,10 +640,6 @@ def rewrite_inv_diag_to_diag_reciprocal(fgraph, node):
list of Variable, optional
List of optimized variables, or None if no optimization was performed
"""
core_op = node.op.core_op
if not (isinstance(core_op, ALL_INVERSE_OPS)):
return None
inputs = node.inputs[0]
# Check for use of pt.diag first
if (
......@@ -857,7 +816,7 @@ def rewrite_det_kronecker(fgraph, node):
@register_canonicalize
@register_stabilize
@node_rewriter([Blockwise])
@node_rewriter([blockwise_of(Cholesky)])
def rewrite_remove_useless_cholesky(fgraph, node):
"""
This rewrite takes advantage of the fact that the cholesky decomposition of an identity matrix is the matrix itself
......@@ -877,8 +836,6 @@ def rewrite_remove_useless_cholesky(fgraph, node):
List of optimized variables, or None if no optimization was performed
"""
# Find whether cholesky op is being applied
if not isinstance(node.op.core_op, Cholesky):
return None
# Check whether input to Cholesky is Eye and the 1's are on main diagonal
potential_eye = node.inputs[0]
......@@ -894,12 +851,8 @@ def rewrite_remove_useless_cholesky(fgraph, node):
@register_canonicalize
@register_stabilize
@node_rewriter([Blockwise])
@node_rewriter([blockwise_of(Cholesky)])
def rewrite_cholesky_diag_to_sqrt_diag(fgraph, node):
# Find whether cholesky op is being applied
if not isinstance(node.op.core_op, Cholesky):
return None
[input] = node.inputs
# Check if input is a (1, 1) matrix
......@@ -1022,7 +975,7 @@ def slogdet_specialization(fgraph, node):
@register_stabilize
@register_canonicalize
@node_rewriter([Blockwise])
@node_rewriter([blockwise_of(SolveBase)])
def scalar_solve_to_division(fgraph, node):
"""
Replace solve(a, b) with b / a if a is a (1, 1) matrix
......
......@@ -37,7 +37,6 @@ from pytensor.tensor.basic import (
zeros,
zeros_like,
)
from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise
from pytensor.tensor.exceptions import NotScalarConstantError
from pytensor.tensor.extra_ops import broadcast_arrays, concat_with_broadcast
......@@ -49,6 +48,11 @@ from pytensor.tensor.math import (
_dot,
_matmul,
add,
arccosh,
arcsinh,
arctanh,
cosh,
deg2rad,
digamma,
dot,
erf,
......@@ -70,13 +74,16 @@ from pytensor.tensor.math import (
neg,
polygamma,
prod,
rad2deg,
reciprocal,
sigmoid,
sign,
sinh,
softplus,
sqr,
sqrt,
sub,
tanh,
tri_gamma,
true_div,
variadic_add,
......@@ -96,6 +103,7 @@ from pytensor.tensor.rewriting.basic import (
register_uncanonicalize,
register_useless,
)
from pytensor.tensor.rewriting.blockwise import blockwise_of
from pytensor.tensor.rewriting.elemwise import apply_local_dimshuffle_lift
from pytensor.tensor.rewriting.linalg import is_matrix_transpose
from pytensor.tensor.shape import Shape, Shape_i
......@@ -151,7 +159,7 @@ def local_0_dot_x(fgraph, node):
@register_stabilize
@node_rewriter([Blockwise])
@node_rewriter([blockwise_of(BlockDiagonal)])
def local_block_diag_dot_to_dot_block_diag(fgraph, node):
r"""
Perform the rewrite ``dot(block_diag(A, B), C) -> concat(dot(A, C), dot(B, C))``
......@@ -160,9 +168,6 @@ def local_block_diag_dot_to_dot_block_diag(fgraph, node):
of approximately O(n^3), it's always better to perform two dot products on the smaller matrices, rather than
a single dot on the larger matrix.
"""
if not isinstance(node.op.core_op, BlockDiagonal):
return
# Check that the BlockDiagonal is an input to a Dot node:
for client in itertools.chain.from_iterable(
get_clients_at_depth(fgraph, node, depth=i) for i in [1, 2]
......@@ -424,60 +429,30 @@ def local_dot_to_mul(fgraph, node):
return [new_out]
def is_inverse_pair(node_op, prev_op, inv_pair):
"""
Given two consecutive operations, check if they are the
provided pair of inverse functions.
"""
node_is_op0 = isinstance(node_op, inv_pair[0])
node_is_op1 = isinstance(node_op, inv_pair[1])
prev_is_op0 = isinstance(prev_op, inv_pair[0])
prev_is_op1 = isinstance(prev_op, inv_pair[1])
return (node_is_op0 and prev_is_op1) or (node_is_op1 and prev_is_op0)
@register_canonicalize
@register_specialize
@node_rewriter([Elemwise])
def local_func_inv(fgraph, node):
"""
Check for two consecutive operations that are functional inverses
and remove them from the function graph.
"""
inv_pairs = (
(ps.Deg2Rad, ps.Rad2Deg),
(ps.Cosh, ps.ArcCosh),
(ps.Tanh, ps.ArcTanh),
(ps.Sinh, ps.ArcSinh),
(ps.Conj, ps.Conj),
(ps.Neg, ps.Neg),
(ps.Reciprocal, ps.Reciprocal),
)
x = node.inputs[0]
if not isinstance(node.op, Elemwise):
return
if not (x.owner and isinstance(x.owner.op, Elemwise)):
return
prev_op = x.owner.op.scalar_op
node_op = node.op.scalar_op
for inv_pair in inv_pairs:
if is_inverse_pair(node_op, prev_op, inv_pair):
# We don't need to copy stack trace, because the rewrite
# is trivial and maintains the earlier stack trace
ottype = node.out.dtype
inp = x.owner.inputs[0]
# Functions may have casted integer input to float
if inp.dtype != ottype:
inp = cast(inp, ottype)
return [inp]
for pair in (
(deg2rad, rad2deg),
(cosh, arccosh),
(tanh, arctanh),
(sinh, arcsinh),
(_conj, _conj),
(neg, neg),
(reciprocal, reciprocal),
):
# Create a simple PatternNodeRewriter for each pair of opposite ops
# instead of a general Op that is called to often for very few hits
for op, inv_op in (pair, reversed(pair)):
rewrite = PatternNodeRewriter(
(op, (inv_op, "x")),
"x",
allow_multiple_clients=True,
allow_cast=True,
name=f"useless_{op}_of_{inv_op}",
)
register_canonicalize(rewrite)
register_specialize(rewrite)
return
if op is inv_op:
break # Same Op, no need to define two rewrites
@register_canonicalize
......
......@@ -35,7 +35,7 @@ from pytensor.tensor.basic import (
switch,
)
from pytensor.tensor.basic import constant as tensor_constant
from pytensor.tensor.blockwise import Blockwise, _squeeze_left
from pytensor.tensor.blockwise import _squeeze_left
from pytensor.tensor.elemwise import Elemwise
from pytensor.tensor.exceptions import NotScalarConstantError
from pytensor.tensor.extra_ops import broadcast_to
......@@ -58,6 +58,7 @@ from pytensor.tensor.rewriting.basic import (
register_specialize,
register_stabilize,
)
from pytensor.tensor.rewriting.blockwise import blockwise_of
from pytensor.tensor.shape import (
shape_padleft,
shape_padright,
......@@ -974,33 +975,30 @@ def local_IncSubtensor_serialize(fgraph, node):
and not i.owner.op.set_instead_of_inc
)
if node.op == add:
o_type = node.outputs[0].type
o_type = node.outputs[0].type
movable_inputs = [i for i in node.inputs if movable(i)]
movable_inputs = [i for i in node.inputs if movable(i)]
if movable_inputs:
new_inputs = [i for i in node.inputs if not movable(i)] + [
mi.owner.inputs[0] for mi in movable_inputs
]
new_add = variadic_add(*new_inputs)
# Copy over stacktrace from original output, as an error
# (e.g. an index error) in this add operation should
# correspond to an error in the original add operation.
copy_stack_trace(node.outputs[0], new_add)
# stack up the new incsubtensors
tip = new_add
for mi in movable_inputs:
assert o_type.is_super(tip.type)
tip = mi.owner.op(tip, *mi.owner.inputs[1:])
# Copy over stacktrace from outputs of the original
# "movable" operation to the new operation.
copy_stack_trace(node.outputs + mi.owner.outputs, tip)
if movable_inputs:
new_inputs = [i for i in node.inputs if not movable(i)] + [
mi.owner.inputs[0] for mi in movable_inputs
]
new_add = variadic_add(*new_inputs)
# Copy over stacktrace from original output, as an error
# (e.g. an index error) in this add operation should
# correspond to an error in the original add operation.
copy_stack_trace(node.outputs[0], new_add)
return [tip]
# stack up the new incsubtensors
tip = new_add
for mi in movable_inputs:
assert o_type.is_super(tip.type)
tip = mi.owner.op(tip, *mi.owner.inputs[1:])
# Copy over stacktrace from outputs of the original
# "movable" operation to the new operation.
copy_stack_trace(node.outputs + mi.owner.outputs, tip)
# print incsub_inputs, [id(i.owner.inputs[0]) for i in incsub_inputs]
return [tip]
# We register it in a WalkingGraphRewriter inside the canonizer EQ optimizer.
......@@ -1576,7 +1574,7 @@ compile.optdb.register(
@register_stabilize
@register_specialize
@node_rewriter([Blockwise])
@node_rewriter([blockwise_of(Subtensor)])
def local_blockwise_of_subtensor(fgraph, node):
"""Rewrite Blockwise of Subtensor, where the only batch input is the indexed tensor.
......@@ -1585,9 +1583,6 @@ def local_blockwise_of_subtensor(fgraph, node):
TODO: Handle batched indices like we do with blockwise of inc_subtensor
TODO: Extend to AdvanceSubtensor
"""
if not isinstance(node.op.core_op, Subtensor):
return
x, *idxs = node.inputs
if not all(all(idx.type.broadcastable) for idx in idxs):
return
......@@ -1603,7 +1598,7 @@ def local_blockwise_of_subtensor(fgraph, node):
@register_canonicalize("shape_unsafe")
@register_stabilize("shape_unsafe")
@register_specialize("shape_unsafe")
@node_rewriter([Blockwise])
@node_rewriter([blockwise_of(IncSubtensor | AdvancedIncSubtensor)])
def local_blockwise_inc_subtensor(fgraph, node):
"""Rewrite blockwised inc_subtensors.
......@@ -1614,12 +1609,9 @@ def local_blockwise_inc_subtensor(fgraph, node):
and can be safely rewritten without Blockwise.
"""
core_op = node.op.core_op
if not isinstance(core_op, AdvancedIncSubtensor | IncSubtensor):
return None
x, y, *idxs = node.inputs
[out] = node.outputs
if isinstance(node.op.core_op, AdvancedIncSubtensor):
if isinstance(core_op, AdvancedIncSubtensor):
if any(
(
# Blockwise requires all inputs to be tensors so it is not possible
......
Markdown 格式
0%
您将 0 人添加到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论