提交 27d79707 作者: Ricardo Vieira 提交者: Ricardo Vieira

Use OpPattern in tracks

上级 19f1486b
...@@ -1228,6 +1228,8 @@ class ScalarOp(COp): ...@@ -1228,6 +1228,8 @@ class ScalarOp(COp):
f"(got: {output_types_preference})" f"(got: {output_types_preference})"
) )
self.output_types_preference = output_types_preference self.output_types_preference = output_types_preference
elif not hasattr(self, "output_types_preference"):
self.output_types_preference = None
def make_node(self, *inputs): def make_node(self, *inputs):
if self.nin >= 0: if self.nin >= 0:
...@@ -1247,7 +1249,7 @@ class ScalarOp(COp): ...@@ -1247,7 +1249,7 @@ class ScalarOp(COp):
return Apply(self, inputs, outputs) return Apply(self, inputs, outputs)
def output_types(self, types): def output_types(self, types):
if hasattr(self, "output_types_preference"): if self.output_types_preference is not None:
variables = self.output_types_preference(*types) variables = self.output_types_preference(*types)
if not isinstance(variables, list | tuple) or any( if not isinstance(variables, list | tuple) or any(
not isinstance(x, CType) for x in variables not isinstance(x, CType) for x in variables
...@@ -2696,7 +2698,7 @@ class Sign(UnaryScalarOp): ...@@ -2696,7 +2698,7 @@ class Sign(UnaryScalarOp):
nfunc_spec = ("sign", 1, 1) nfunc_spec = ("sign", 1, 1)
@staticmethod @staticmethod
def output_types_preference(x): def _output_types_preference(x):
if x == bool: if x == bool:
raise TypeError(x) raise TypeError(x)
return same_out_nocomplex(x) return same_out_nocomplex(x)
...@@ -2737,7 +2739,7 @@ class Sign(UnaryScalarOp): ...@@ -2737,7 +2739,7 @@ class Sign(UnaryScalarOp):
return s return s
sign = Sign(name="sign") sign = Sign(name="sign", output_types_preference=Sign._output_types_preference)
class Ceil(UnaryScalarOp): class Ceil(UnaryScalarOp):
......
...@@ -14,6 +14,7 @@ from pytensor.tensor.basic import atleast_Nd ...@@ -14,6 +14,7 @@ from pytensor.tensor.basic import atleast_Nd
from pytensor.tensor.blockwise import Blockwise from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.elemwise import DimShuffle from pytensor.tensor.elemwise import DimShuffle
from pytensor.tensor.rewriting.basic import register_specialize from pytensor.tensor.rewriting.basic import register_specialize
from pytensor.tensor.rewriting.blockwise import blockwise_of
from pytensor.tensor.rewriting.linalg import is_matrix_transpose from pytensor.tensor.rewriting.linalg import is_matrix_transpose
from pytensor.tensor.slinalg import Solve, cho_solve, cholesky, lu_factor, lu_solve from pytensor.tensor.slinalg import Solve, cho_solve, cholesky, lu_factor, lu_solve
from pytensor.tensor.variable import TensorVariable from pytensor.tensor.variable import TensorVariable
...@@ -227,7 +228,7 @@ def _scan_split_non_sequence_decomposition_and_solve( ...@@ -227,7 +228,7 @@ def _scan_split_non_sequence_decomposition_and_solve(
@register_specialize @register_specialize
@node_rewriter([Blockwise]) @node_rewriter([blockwise_of(Solve)])
def reuse_decomposition_multiple_solves(fgraph, node): def reuse_decomposition_multiple_solves(fgraph, node):
return _split_decomp_and_solve_steps( return _split_decomp_and_solve_steps(
fgraph, node, eager=False, allowed_assume_a={"gen", "tridiagonal", "pos"} fgraph, node, eager=False, allowed_assume_a={"gen", "tridiagonal", "pos"}
......
...@@ -26,10 +26,9 @@ import logging ...@@ -26,10 +26,9 @@ import logging
import numpy as np import numpy as np
import pytensor.scalar.basic as ps
from pytensor import compile, config from pytensor import compile, config
from pytensor.compile.ops import ViewOp from pytensor.compile.ops import ViewOp
from pytensor.graph import FunctionGraph from pytensor.graph import FunctionGraph, Op
from pytensor.graph.basic import Constant from pytensor.graph.basic import Constant
from pytensor.graph.rewriting.basic import ( from pytensor.graph.rewriting.basic import (
NodeProcessingGraphRewriter, NodeProcessingGraphRewriter,
...@@ -40,9 +39,24 @@ from pytensor.graph.rewriting.basic import ( ...@@ -40,9 +39,24 @@ from pytensor.graph.rewriting.basic import (
node_rewriter, node_rewriter,
) )
from pytensor.graph.rewriting.db import RewriteDatabase from pytensor.graph.rewriting.db import RewriteDatabase
from pytensor.graph.rewriting.unify import OpPattern, OpPatternOpTypeType
from pytensor.npy_2_compat import normalize_axis_index from pytensor.npy_2_compat import normalize_axis_index
from pytensor.raise_op import Assert, CheckAndRaise, assert_op from pytensor.raise_op import Assert, CheckAndRaise, assert_op
from pytensor.scalar.basic import Second from pytensor.scalar import (
AND,
EQ,
LE,
NEQ,
OR,
XOR,
Add,
BinaryScalarOp,
Cast,
Identity,
Mul,
Second,
Switch,
)
from pytensor.tensor.basic import ( from pytensor.tensor.basic import (
Alloc, Alloc,
AllocEmpty, AllocEmpty,
...@@ -225,6 +239,12 @@ def register_uncanonicalize( ...@@ -225,6 +239,12 @@ def register_uncanonicalize(
return node_rewriter return node_rewriter
def elemwise_of(scalar_op: OpPatternOpTypeType | OpPattern) -> OpPattern:
    """Build an ``OpPattern`` for ``Elemwise`` parametrized by its ``scalar_op``.

    Intended for use in ``node_rewriter`` tracks, e.g.
    ``@node_rewriter([elemwise_of(Cast)])``.

    Parameters
    ----------
    scalar_op
        Specification of the inner scalar op. An ``Op`` instance or an
        existing ``OpPattern`` is forwarded unchanged; anything else (e.g. an
        Op type or a union of Op types) is first wrapped in an ``OpPattern``.

    Returns
    -------
    OpPattern
        ``OpPattern(Elemwise, scalar_op=...)`` with the normalized spec.
    """
    if isinstance(scalar_op, Op | OpPattern):
        # Already a concrete Op or a pattern: use as-is.
        scalar_spec = scalar_op
    else:
        # Op type(s): wrap so the parameter is itself expressed as a pattern.
        scalar_spec = OpPattern(scalar_op)
    return OpPattern(Elemwise, scalar_op=scalar_spec)
@register_canonicalize @register_canonicalize
@register_specialize @register_specialize
@node_rewriter([TensorFromScalar]) @node_rewriter([TensorFromScalar])
...@@ -551,7 +571,7 @@ def local_useless_elemwise(fgraph, node): ...@@ -551,7 +571,7 @@ def local_useless_elemwise(fgraph, node):
dtype = node.outputs[0].type.dtype dtype = node.outputs[0].type.dtype
scalar_op = node.op.scalar_op scalar_op = node.op.scalar_op
if isinstance(scalar_op, ps.EQ) and len(node.inputs) == 2: if isinstance(scalar_op, EQ) and len(node.inputs) == 2:
if node.inputs[0] is node.inputs[1]: if node.inputs[0] is node.inputs[1]:
# it is the same var in the graph. That will always be true # it is the same var in the graph. That will always be true
ret = ones_like(node.inputs[0], dtype=dtype, opt=True) ret = ones_like(node.inputs[0], dtype=dtype, opt=True)
...@@ -559,7 +579,7 @@ def local_useless_elemwise(fgraph, node): ...@@ -559,7 +579,7 @@ def local_useless_elemwise(fgraph, node):
# Copy stack trace from input to constant output # Copy stack trace from input to constant output
copy_stack_trace(node.outputs[0], ret) copy_stack_trace(node.outputs[0], ret)
return [ret] return [ret]
elif isinstance(scalar_op, ps.NEQ | ps.XOR) and len(node.inputs) == 2: elif isinstance(scalar_op, NEQ | XOR) and len(node.inputs) == 2:
if node.inputs[0] is node.inputs[1]: if node.inputs[0] is node.inputs[1]:
# it is the same var in the graph. That will always be false # it is the same var in the graph. That will always be false
ret = zeros_like(node.inputs[0], dtype=dtype, opt=True) ret = zeros_like(node.inputs[0], dtype=dtype, opt=True)
...@@ -568,14 +588,11 @@ def local_useless_elemwise(fgraph, node): ...@@ -568,14 +588,11 @@ def local_useless_elemwise(fgraph, node):
copy_stack_trace(node.outputs[0], ret) copy_stack_trace(node.outputs[0], ret)
return [ret] return [ret]
elif ( elif isinstance(node.op.scalar_op, Mul | Add | Identity) and len(node.inputs) == 1:
isinstance(node.op.scalar_op, ps.Mul | ps.Add | ps.Identity)
and len(node.inputs) == 1
):
# No need to copy over any stack trace # No need to copy over any stack trace
return [node.inputs[0]] return [node.inputs[0]]
elif isinstance(node.op.scalar_op, ps.AND) and len(node.inputs) == 2: elif isinstance(node.op.scalar_op, AND) and len(node.inputs) == 2:
if ( if (
isinstance(node.inputs[0], TensorConstant) isinstance(node.inputs[0], TensorConstant)
and node.inputs[1].type.broadcastable == out_bcast and node.inputs[1].type.broadcastable == out_bcast
...@@ -602,7 +619,7 @@ def local_useless_elemwise(fgraph, node): ...@@ -602,7 +619,7 @@ def local_useless_elemwise(fgraph, node):
# and this rewrite would be wrong # and this rewrite would be wrong
return [node.inputs[0].astype(node.outputs[0].dtype)] return [node.inputs[0].astype(node.outputs[0].dtype)]
elif isinstance(node.op.scalar_op, ps.OR) and len(node.inputs) == 2: elif isinstance(node.op.scalar_op, OR) and len(node.inputs) == 2:
if ( if (
isinstance(node.inputs[0], TensorConstant) isinstance(node.inputs[0], TensorConstant)
and node.inputs[1].type.broadcastable == out_bcast and node.inputs[1].type.broadcastable == out_bcast
...@@ -653,7 +670,7 @@ def local_alloc_unary(fgraph, node): ...@@ -653,7 +670,7 @@ def local_alloc_unary(fgraph, node):
@register_canonicalize @register_canonicalize
@register_specialize @register_specialize
@node_rewriter([Elemwise]) @node_rewriter([elemwise_of(Cast)])
def local_cast_cast(fgraph, node): def local_cast_cast(fgraph, node):
"""cast(cast(x, dtype1), dtype2) """cast(cast(x, dtype1), dtype2)
...@@ -663,13 +680,11 @@ def local_cast_cast(fgraph, node): ...@@ -663,13 +680,11 @@ def local_cast_cast(fgraph, node):
and the first cast causes an upcast. and the first cast causes an upcast.
""" """
if not (isinstance(node.op, Elemwise) and isinstance(node.op.scalar_op, ps.Cast)):
return
x = node.inputs[0] x = node.inputs[0]
if not ( if not (
x.owner x.owner
and isinstance(x.owner.op, Elemwise) and isinstance(x.owner.op, Elemwise)
and isinstance(x.owner.op.scalar_op, ps.Cast) and isinstance(x.owner.op.scalar_op, Cast)
): ):
return return
...@@ -1009,7 +1024,7 @@ def local_useless_switch(fgraph, node): ...@@ -1009,7 +1024,7 @@ def local_useless_switch(fgraph, node):
node.outputs[0].type.ndim == 0 node.outputs[0].type.ndim == 0
and cond_var.owner and cond_var.owner
and isinstance(cond_var.owner.op, Elemwise) and isinstance(cond_var.owner.op, Elemwise)
and isinstance(cond_var.owner.op.scalar_op, ps.LE) and isinstance(cond_var.owner.op.scalar_op, LE)
and cond_var.owner.inputs[0].owner and cond_var.owner.inputs[0].owner
and isinstance(cond_var.owner.inputs[0].owner.op, Shape_i) and isinstance(cond_var.owner.inputs[0].owner.op, Shape_i)
and get_scalar_constant_value( and get_scalar_constant_value(
...@@ -1031,24 +1046,18 @@ def local_useless_switch(fgraph, node): ...@@ -1031,24 +1046,18 @@ def local_useless_switch(fgraph, node):
@register_canonicalize @register_canonicalize
@node_rewriter([Elemwise]) @node_rewriter([elemwise_of(BinaryScalarOp | Add | Mul)])
def local_merge_switch_same_cond(fgraph, node): def local_merge_switch_same_cond(fgraph, node):
""" """
Merge add/sub/mul/div/minimum/maximum/... of switches sharing the same Merge add/sub/mul/div/minimum/maximum/... of switches sharing the same
condition, to enable further simplification of their branches condition, to enable further simplification of their branches
Example: switch(c, a, b) + switch(c, x, y) -> switch(c, a+x, b+y) Example: switch(c, a, b) + switch(c, x, y) -> switch(c, a+x, b+y)
""" """
# node must be binary elemwise or add or mul
if not (
isinstance(node.op, Elemwise)
and isinstance(node.op.scalar_op, ps.BinaryScalarOp | ps.Add | ps.Mul)
):
return
# all inputs must be switch # all inputs must be switch
if not all( if not all(
s.owner s.owner
and isinstance(s.owner.op, Elemwise) and isinstance(s.owner.op, Elemwise)
and isinstance(s.owner.op.scalar_op, ps.Switch) and isinstance(s.owner.op.scalar_op, Switch)
for s in node.inputs for s in node.inputs
): ):
return return
...@@ -1174,10 +1183,9 @@ register_specialize(topo_constant_folding, "fast_compile", final_rewriter=True) ...@@ -1174,10 +1183,9 @@ register_specialize(topo_constant_folding, "fast_compile", final_rewriter=True)
@register_infer_shape @register_infer_shape
@register_canonicalize("fast_compile") @register_canonicalize("fast_compile")
@register_useless("fast_compile") @register_useless("fast_compile")
@node_rewriter(None) @node_rewriter([ViewOp])
def local_view_op(fgraph, node): def local_view_op(fgraph, node):
if isinstance(node.op, ViewOp): return node.inputs
return node.inputs
@register_infer_shape @register_infer_shape
......
from pytensor.compile.mode import optdb from pytensor.compile.mode import optdb
from pytensor.graph import Constant, node_rewriter from pytensor.graph import Constant, Op, node_rewriter
from pytensor.graph.destroyhandler import inplace_candidates from pytensor.graph.destroyhandler import inplace_candidates
from pytensor.graph.replace import vectorize_node from pytensor.graph.replace import vectorize_node
from pytensor.graph.rewriting.basic import copy_stack_trace, out2in from pytensor.graph.rewriting.basic import copy_stack_trace, out2in
from pytensor.graph.rewriting.unify import OpPattern, OpPatternOpTypeType
from pytensor.tensor.basic import Alloc, ARange, alloc, shape_padleft from pytensor.tensor.basic import Alloc, ARange, alloc, shape_padleft
from pytensor.tensor.blockwise import Blockwise, _squeeze_left from pytensor.tensor.blockwise import Blockwise, _squeeze_left
from pytensor.tensor.math import Dot from pytensor.tensor.math import Dot
...@@ -20,6 +21,12 @@ from pytensor.tensor.subtensor import ( ...@@ -20,6 +21,12 @@ from pytensor.tensor.subtensor import (
) )
def blockwise_of(core_op: OpPatternOpTypeType | OpPattern) -> OpPattern:
    """Build an ``OpPattern`` for ``Blockwise`` parametrized by its ``core_op``.

    Intended for use in ``node_rewriter`` tracks, e.g.
    ``@node_rewriter([blockwise_of(Reshape)])``.

    Parameters
    ----------
    core_op
        Specification of the wrapped core op. An ``Op`` instance or an
        existing ``OpPattern`` is forwarded unchanged; anything else (e.g. an
        Op type or a union of Op types) is first wrapped in an ``OpPattern``.

    Returns
    -------
    OpPattern
        ``OpPattern(Blockwise, core_op=...)`` with the normalized spec.
    """
    if isinstance(core_op, Op | OpPattern):
        # Already a concrete Op or a pattern: use as-is.
        core_spec = core_op
    else:
        # Op type(s): wrap so the parameter is itself expressed as a pattern.
        core_spec = OpPattern(core_op)
    return OpPattern(Blockwise, core_op=core_spec)
@node_rewriter([Blockwise]) @node_rewriter([Blockwise])
def local_useless_blockwise(fgraph, node): def local_useless_blockwise(fgraph, node):
""" """
...@@ -71,22 +78,24 @@ optdb.register( ...@@ -71,22 +78,24 @@ optdb.register(
@register_canonicalize @register_canonicalize
@register_stabilize @register_stabilize
@register_specialize @register_specialize
@node_rewriter(tracks=[Blockwise]) @node_rewriter(
tracks=[
blockwise_of(
Dot
| Alloc
| ARange
| Subtensor
| AdvancedSubtensor
| AdvancedIncSubtensor
| Reshape
)
]
)
def local_eager_useless_unbatched_blockwise(fgraph, node): def local_eager_useless_unbatched_blockwise(fgraph, node):
if isinstance( # Many Dot-related rewrites (eg, all of BlasOpt) happen before specialize
node.op.core_op, # These other Ops can't always be trivially vectorized at runtime,
Dot # since their inputs may imply non-rectangular shapes.
| Alloc return local_useless_unbatched_blockwise.fn(fgraph, node)
| ARange
| Subtensor
| AdvancedSubtensor
| AdvancedIncSubtensor
| Reshape,
):
# Many Dot-related rewrites (eg, all of BlasOpt) happen before specialize
# These other Ops can't always be trivially vectorized at runtime,
# since their inputs may imply non-rectangular shapes.
return local_useless_unbatched_blockwise.fn(fgraph, node)
@register_specialize("shape_unsafe") @register_specialize("shape_unsafe")
...@@ -204,7 +213,7 @@ def local_blockwise_alloc(fgraph, node): ...@@ -204,7 +213,7 @@ def local_blockwise_alloc(fgraph, node):
@register_specialize @register_specialize
@node_rewriter([Blockwise]) @node_rewriter([blockwise_of(Reshape)])
def local_blockwise_reshape(fgraph, node): def local_blockwise_reshape(fgraph, node):
"""Rewrite away square Blockwise reshapes. """Rewrite away square Blockwise reshapes.
...@@ -215,9 +224,6 @@ def local_blockwise_reshape(fgraph, node): ...@@ -215,9 +224,6 @@ def local_blockwise_reshape(fgraph, node):
For the square Reshape case, we must wait for all the intermediate For the square Reshape case, we must wait for all the intermediate
operations to be lifted as Allocs operations to be lifted as Allocs
""" """
if not isinstance(node.op.core_op, Reshape):
return None
x, output_shape = node.inputs x, output_shape = node.inputs
batch_ndim = node.op.batch_ndim(node) batch_ndim = node.op.batch_ndim(node)
if all(output_shape.type.broadcastable[:batch_ndim]): if all(output_shape.type.broadcastable[:batch_ndim]):
......
...@@ -26,6 +26,7 @@ from pytensor.graph.rewriting.basic import ( ...@@ -26,6 +26,7 @@ from pytensor.graph.rewriting.basic import (
out2in, out2in,
) )
from pytensor.graph.rewriting.db import SequenceDB from pytensor.graph.rewriting.db import SequenceDB
from pytensor.graph.rewriting.unify import OpPattern
from pytensor.graph.utils import InconsistencyError, MethodNotDefined from pytensor.graph.utils import InconsistencyError, MethodNotDefined
from pytensor.scalar.math import Grad2F1Loop, _grad_2f1_loop from pytensor.scalar.math import Grad2F1Loop, _grad_2f1_loop
from pytensor.tensor.basic import ( from pytensor.tensor.basic import (
...@@ -37,6 +38,7 @@ from pytensor.tensor.math import add, exp, mul ...@@ -37,6 +38,7 @@ from pytensor.tensor.math import add, exp, mul
from pytensor.tensor.rewriting.basic import ( from pytensor.tensor.rewriting.basic import (
alloc_like, alloc_like,
broadcasted_by, broadcasted_by,
elemwise_of,
register_canonicalize, register_canonicalize,
register_specialize, register_specialize,
register_stabilize, register_stabilize,
...@@ -422,7 +424,14 @@ def local_useless_dimshuffle_makevector(fgraph, node): ...@@ -422,7 +424,14 @@ def local_useless_dimshuffle_makevector(fgraph, node):
@register_canonicalize @register_canonicalize
@node_rewriter([Elemwise]) @node_rewriter(
[
elemwise_of(
OpPattern(ps.ScalarOp, output_types_preference=ps.upgrade_to_float)
),
elemwise_of(OpPattern(ps.ScalarOp, output_types_preference=ps.upcast_out)),
]
)
def local_upcast_elemwise_constant_inputs(fgraph, node): def local_upcast_elemwise_constant_inputs(fgraph, node):
"""This explicitly upcasts constant inputs to elemwise Ops, when """This explicitly upcasts constant inputs to elemwise Ops, when
those Ops do implicit upcasting anyway. those Ops do implicit upcasting anyway.
...@@ -433,12 +442,6 @@ def local_upcast_elemwise_constant_inputs(fgraph, node): ...@@ -433,12 +442,6 @@ def local_upcast_elemwise_constant_inputs(fgraph, node):
if len(node.outputs) > 1: if len(node.outputs) > 1:
return None return None
if getattr(node.op.scalar_op, "output_types_preference", None) not in (
ps.upgrade_to_float,
ps.upcast_out,
):
return None
# this is the kind of op that we can screw with the input # this is the kind of op that we can screw with the input
# dtypes by upcasting explicitly # dtypes by upcasting explicitly
[old_out] = node.outputs [old_out] = node.outputs
...@@ -988,13 +991,9 @@ class FusionOptimizer(GraphRewriter): ...@@ -988,13 +991,9 @@ class FusionOptimizer(GraphRewriter):
@register_canonicalize @register_canonicalize
@register_specialize @register_specialize
@node_rewriter([Elemwise]) @node_rewriter([elemwise_of(ps.Composite)])
def local_useless_composite_outputs(fgraph, node): def local_useless_composite_outputs(fgraph, node):
"""Remove inputs and outputs of Composite Ops that are not used anywhere.""" """Remove inputs and outputs of Composite Ops that are not used anywhere."""
if not (
isinstance(node.op, Elemwise) and isinstance(node.op.scalar_op, ps.Composite)
):
return
comp = node.op.scalar_op comp = node.op.scalar_op
used_outputs_idxs = [ used_outputs_idxs = [
i for i, o_extern in enumerate(node.outputs) if fgraph.clients[o_extern] i for i, o_extern in enumerate(node.outputs) if fgraph.clients[o_extern]
...@@ -1104,14 +1103,10 @@ def local_careduce_fusion(fgraph, node): ...@@ -1104,14 +1103,10 @@ def local_careduce_fusion(fgraph, node):
return [new_car_op(*elm_inputs)] return [new_car_op(*elm_inputs)]
@node_rewriter([Elemwise]) @node_rewriter([elemwise_of(ps.Composite)])
def local_inline_composite_constants(fgraph, node): def local_inline_composite_constants(fgraph, node):
"""Inline scalar constants in Composite graphs.""" """Inline scalar constants in Composite graphs."""
composite_op = node.op.scalar_op composite_op = node.op.scalar_op
if not isinstance(composite_op, ps.Composite):
return None
new_outer_inputs = [] new_outer_inputs = []
new_inner_inputs = [] new_inner_inputs = []
inner_replacements = {} inner_replacements = {}
...@@ -1287,14 +1282,9 @@ def _rebuild_partial_2f1grad_loop(node, wrt): ...@@ -1287,14 +1282,9 @@ def _rebuild_partial_2f1grad_loop(node, wrt):
@register_specialize @register_specialize
@node_rewriter([Elemwise]) @node_rewriter([elemwise_of(Grad2F1Loop)])
def local_useless_2f1grad_loop(fgraph, node): def local_useless_2f1grad_loop(fgraph, node):
# Remove unused terms from the hyp2f1 grad loop # Remove unused terms from the hyp2f1 grad loop
loop_op = node.op.scalar_op
if not isinstance(loop_op, Grad2F1Loop):
return
grad_related_vars = node.outputs[:-4] grad_related_vars = node.outputs[:-4]
# Rewrite was already applied # Rewrite was already applied
if len(grad_related_vars) // 3 != 3: if len(grad_related_vars) // 3 != 3:
...@@ -1326,18 +1316,13 @@ def local_useless_2f1grad_loop(fgraph, node): ...@@ -1326,18 +1316,13 @@ def local_useless_2f1grad_loop(fgraph, node):
return replacements return replacements
@node_rewriter([Elemwise]) @node_rewriter([elemwise_of(Grad2F1Loop)])
def split_2f1grad_loop(fgraph, node): def split_2f1grad_loop(fgraph, node):
""" """
2f1grad loop has too many operands for Numpy frompyfunc code used by Elemwise nodes on python mode. 2f1grad loop has too many operands for Numpy frompyfunc code used by Elemwise nodes on python mode.
This rewrite splits it across 3 different operations. It is not needed if `local_useless_2f1grad_loop` was applied This rewrite splits it across 3 different operations. It is not needed if `local_useless_2f1grad_loop` was applied
""" """
loop_op = node.op.scalar_op
if not isinstance(loop_op, Grad2F1Loop):
return None
grad_related_vars = node.outputs[:-4] grad_related_vars = node.outputs[:-4]
# local_useless_2f1grad_loop was used, we should be safe # local_useless_2f1grad_loop was used, we should be safe
if len(grad_related_vars) // 3 != 3: if len(grad_related_vars) // 3 != 3:
......
...@@ -37,7 +37,6 @@ from pytensor.tensor.basic import ( ...@@ -37,7 +37,6 @@ from pytensor.tensor.basic import (
zeros, zeros,
zeros_like, zeros_like,
) )
from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise
from pytensor.tensor.exceptions import NotScalarConstantError from pytensor.tensor.exceptions import NotScalarConstantError
from pytensor.tensor.extra_ops import broadcast_arrays, concat_with_broadcast from pytensor.tensor.extra_ops import broadcast_arrays, concat_with_broadcast
...@@ -49,6 +48,11 @@ from pytensor.tensor.math import ( ...@@ -49,6 +48,11 @@ from pytensor.tensor.math import (
_dot, _dot,
_matmul, _matmul,
add, add,
arccosh,
arcsinh,
arctanh,
cosh,
deg2rad,
digamma, digamma,
dot, dot,
erf, erf,
...@@ -70,13 +74,16 @@ from pytensor.tensor.math import ( ...@@ -70,13 +74,16 @@ from pytensor.tensor.math import (
neg, neg,
polygamma, polygamma,
prod, prod,
rad2deg,
reciprocal, reciprocal,
sigmoid, sigmoid,
sign, sign,
sinh,
softplus, softplus,
sqr, sqr,
sqrt, sqrt,
sub, sub,
tanh,
tri_gamma, tri_gamma,
true_div, true_div,
variadic_add, variadic_add,
...@@ -96,6 +103,7 @@ from pytensor.tensor.rewriting.basic import ( ...@@ -96,6 +103,7 @@ from pytensor.tensor.rewriting.basic import (
register_uncanonicalize, register_uncanonicalize,
register_useless, register_useless,
) )
from pytensor.tensor.rewriting.blockwise import blockwise_of
from pytensor.tensor.rewriting.elemwise import apply_local_dimshuffle_lift from pytensor.tensor.rewriting.elemwise import apply_local_dimshuffle_lift
from pytensor.tensor.rewriting.linalg import is_matrix_transpose from pytensor.tensor.rewriting.linalg import is_matrix_transpose
from pytensor.tensor.shape import Shape, Shape_i from pytensor.tensor.shape import Shape, Shape_i
...@@ -151,7 +159,7 @@ def local_0_dot_x(fgraph, node): ...@@ -151,7 +159,7 @@ def local_0_dot_x(fgraph, node):
@register_stabilize @register_stabilize
@node_rewriter([Blockwise]) @node_rewriter([blockwise_of(BlockDiagonal)])
def local_block_diag_dot_to_dot_block_diag(fgraph, node): def local_block_diag_dot_to_dot_block_diag(fgraph, node):
r""" r"""
Perform the rewrite ``dot(block_diag(A, B), C) -> concat(dot(A, C), dot(B, C))`` Perform the rewrite ``dot(block_diag(A, B), C) -> concat(dot(A, C), dot(B, C))``
...@@ -160,9 +168,6 @@ def local_block_diag_dot_to_dot_block_diag(fgraph, node): ...@@ -160,9 +168,6 @@ def local_block_diag_dot_to_dot_block_diag(fgraph, node):
of approximately O(n^3), it's always better to perform two dot products on the smaller matrices, rather than of approximately O(n^3), it's always better to perform two dot products on the smaller matrices, rather than
a single dot on the larger matrix. a single dot on the larger matrix.
""" """
if not isinstance(node.op.core_op, BlockDiagonal):
return
# Check that the BlockDiagonal is an input to a Dot node: # Check that the BlockDiagonal is an input to a Dot node:
for client in itertools.chain.from_iterable( for client in itertools.chain.from_iterable(
get_clients_at_depth(fgraph, node, depth=i) for i in [1, 2] get_clients_at_depth(fgraph, node, depth=i) for i in [1, 2]
...@@ -424,60 +429,30 @@ def local_dot_to_mul(fgraph, node): ...@@ -424,60 +429,30 @@ def local_dot_to_mul(fgraph, node):
return [new_out] return [new_out]
def is_inverse_pair(node_op, prev_op, inv_pair): for pair in (
""" (deg2rad, rad2deg),
Given two consecutive operations, check if they are the (cosh, arccosh),
provided pair of inverse functions. (tanh, arctanh),
(sinh, arcsinh),
""" (_conj, _conj),
node_is_op0 = isinstance(node_op, inv_pair[0]) (neg, neg),
node_is_op1 = isinstance(node_op, inv_pair[1]) (reciprocal, reciprocal),
prev_is_op0 = isinstance(prev_op, inv_pair[0]) ):
prev_is_op1 = isinstance(prev_op, inv_pair[1]) # Create a simple PatternNodeRewriter for each pair of opposite ops
# instead of a general Op that is called to often for very few hits
return (node_is_op0 and prev_is_op1) or (node_is_op1 and prev_is_op0) for op, inv_op in (pair, reversed(pair)):
rewrite = PatternNodeRewriter(
(op, (inv_op, "x")),
@register_canonicalize "x",
@register_specialize allow_multiple_clients=True,
@node_rewriter([Elemwise]) allow_cast=True,
def local_func_inv(fgraph, node): name=f"useless_{op}_of_{inv_op}",
""" )
Check for two consecutive operations that are functional inverses register_canonicalize(rewrite)
and remove them from the function graph. register_specialize(rewrite)
"""
inv_pairs = (
(ps.Deg2Rad, ps.Rad2Deg),
(ps.Cosh, ps.ArcCosh),
(ps.Tanh, ps.ArcTanh),
(ps.Sinh, ps.ArcSinh),
(ps.Conj, ps.Conj),
(ps.Neg, ps.Neg),
(ps.Reciprocal, ps.Reciprocal),
)
x = node.inputs[0]
if not isinstance(node.op, Elemwise):
return
if not (x.owner and isinstance(x.owner.op, Elemwise)):
return
prev_op = x.owner.op.scalar_op
node_op = node.op.scalar_op
for inv_pair in inv_pairs:
if is_inverse_pair(node_op, prev_op, inv_pair):
# We don't need to copy stack trace, because the rewrite
# is trivial and maintains the earlier stack trace
ottype = node.out.dtype
inp = x.owner.inputs[0]
# Functions may have casted integer input to float
if inp.dtype != ottype:
inp = cast(inp, ottype)
return [inp]
return if op is inv_op:
break # Same Op, no need to define two rewrites
@register_canonicalize @register_canonicalize
......
...@@ -35,7 +35,7 @@ from pytensor.tensor.basic import ( ...@@ -35,7 +35,7 @@ from pytensor.tensor.basic import (
switch, switch,
) )
from pytensor.tensor.basic import constant as tensor_constant from pytensor.tensor.basic import constant as tensor_constant
from pytensor.tensor.blockwise import Blockwise, _squeeze_left from pytensor.tensor.blockwise import _squeeze_left
from pytensor.tensor.elemwise import Elemwise from pytensor.tensor.elemwise import Elemwise
from pytensor.tensor.exceptions import NotScalarConstantError from pytensor.tensor.exceptions import NotScalarConstantError
from pytensor.tensor.extra_ops import broadcast_to from pytensor.tensor.extra_ops import broadcast_to
...@@ -58,6 +58,7 @@ from pytensor.tensor.rewriting.basic import ( ...@@ -58,6 +58,7 @@ from pytensor.tensor.rewriting.basic import (
register_specialize, register_specialize,
register_stabilize, register_stabilize,
) )
from pytensor.tensor.rewriting.blockwise import blockwise_of
from pytensor.tensor.shape import ( from pytensor.tensor.shape import (
shape_padleft, shape_padleft,
shape_padright, shape_padright,
...@@ -974,33 +975,30 @@ def local_IncSubtensor_serialize(fgraph, node): ...@@ -974,33 +975,30 @@ def local_IncSubtensor_serialize(fgraph, node):
and not i.owner.op.set_instead_of_inc and not i.owner.op.set_instead_of_inc
) )
if node.op == add: o_type = node.outputs[0].type
o_type = node.outputs[0].type
movable_inputs = [i for i in node.inputs if movable(i)] movable_inputs = [i for i in node.inputs if movable(i)]
if movable_inputs: if movable_inputs:
new_inputs = [i for i in node.inputs if not movable(i)] + [ new_inputs = [i for i in node.inputs if not movable(i)] + [
mi.owner.inputs[0] for mi in movable_inputs mi.owner.inputs[0] for mi in movable_inputs
] ]
new_add = variadic_add(*new_inputs) new_add = variadic_add(*new_inputs)
# Copy over stacktrace from original output, as an error # Copy over stacktrace from original output, as an error
# (e.g. an index error) in this add operation should # (e.g. an index error) in this add operation should
# correspond to an error in the original add operation. # correspond to an error in the original add operation.
copy_stack_trace(node.outputs[0], new_add) copy_stack_trace(node.outputs[0], new_add)
# stack up the new incsubtensors
tip = new_add
for mi in movable_inputs:
assert o_type.is_super(tip.type)
tip = mi.owner.op(tip, *mi.owner.inputs[1:])
# Copy over stacktrace from outputs of the original
# "movable" operation to the new operation.
copy_stack_trace(node.outputs + mi.owner.outputs, tip)
return [tip] # stack up the new incsubtensors
tip = new_add
for mi in movable_inputs:
assert o_type.is_super(tip.type)
tip = mi.owner.op(tip, *mi.owner.inputs[1:])
# Copy over stacktrace from outputs of the original
# "movable" operation to the new operation.
copy_stack_trace(node.outputs + mi.owner.outputs, tip)
# print incsub_inputs, [id(i.owner.inputs[0]) for i in incsub_inputs] return [tip]
# We register it in a WalkingGraphRewriter inside the canonizer EQ optimizer. # We register it in a WalkingGraphRewriter inside the canonizer EQ optimizer.
...@@ -1576,7 +1574,7 @@ compile.optdb.register( ...@@ -1576,7 +1574,7 @@ compile.optdb.register(
@register_stabilize @register_stabilize
@register_specialize @register_specialize
@node_rewriter([Blockwise]) @node_rewriter([blockwise_of(Subtensor)])
def local_blockwise_of_subtensor(fgraph, node): def local_blockwise_of_subtensor(fgraph, node):
"""Rewrite Blockwise of Subtensor, where the only batch input is the indexed tensor. """Rewrite Blockwise of Subtensor, where the only batch input is the indexed tensor.
...@@ -1585,9 +1583,6 @@ def local_blockwise_of_subtensor(fgraph, node): ...@@ -1585,9 +1583,6 @@ def local_blockwise_of_subtensor(fgraph, node):
TODO: Handle batched indices like we do with blockwise of inc_subtensor TODO: Handle batched indices like we do with blockwise of inc_subtensor
TODO: Extend to AdvancedSubtensor TODO: Extend to AdvancedSubtensor
""" """
if not isinstance(node.op.core_op, Subtensor):
return
x, *idxs = node.inputs x, *idxs = node.inputs
if not all(all(idx.type.broadcastable) for idx in idxs): if not all(all(idx.type.broadcastable) for idx in idxs):
return return
...@@ -1603,7 +1598,7 @@ def local_blockwise_of_subtensor(fgraph, node): ...@@ -1603,7 +1598,7 @@ def local_blockwise_of_subtensor(fgraph, node):
@register_canonicalize("shape_unsafe") @register_canonicalize("shape_unsafe")
@register_stabilize("shape_unsafe") @register_stabilize("shape_unsafe")
@register_specialize("shape_unsafe") @register_specialize("shape_unsafe")
@node_rewriter([Blockwise]) @node_rewriter([blockwise_of(IncSubtensor | AdvancedIncSubtensor)])
def local_blockwise_inc_subtensor(fgraph, node): def local_blockwise_inc_subtensor(fgraph, node):
"""Rewrite blockwised inc_subtensors. """Rewrite blockwised inc_subtensors.
...@@ -1614,12 +1609,9 @@ def local_blockwise_inc_subtensor(fgraph, node): ...@@ -1614,12 +1609,9 @@ def local_blockwise_inc_subtensor(fgraph, node):
and can be safely rewritten without Blockwise. and can be safely rewritten without Blockwise.
""" """
core_op = node.op.core_op core_op = node.op.core_op
if not isinstance(core_op, AdvancedIncSubtensor | IncSubtensor):
return None
x, y, *idxs = node.inputs x, y, *idxs = node.inputs
[out] = node.outputs [out] = node.outputs
if isinstance(node.op.core_op, AdvancedIncSubtensor): if isinstance(core_op, AdvancedIncSubtensor):
if any( if any(
( (
# Blockwise requires all inputs to be tensors so it is not possible # Blockwise requires all inputs to be tensors so it is not possible
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
请 注册 或者 登录 后发表评论