提交 27d79707 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Use OpPattern in tracks

上级 19f1486b
......@@ -1228,6 +1228,8 @@ class ScalarOp(COp):
f"(got: {output_types_preference})"
)
self.output_types_preference = output_types_preference
elif not hasattr(self, "output_types_preference"):
self.output_types_preference = None
def make_node(self, *inputs):
if self.nin >= 0:
......@@ -1247,7 +1249,7 @@ class ScalarOp(COp):
return Apply(self, inputs, outputs)
def output_types(self, types):
if hasattr(self, "output_types_preference"):
if self.output_types_preference is not None:
variables = self.output_types_preference(*types)
if not isinstance(variables, list | tuple) or any(
not isinstance(x, CType) for x in variables
......@@ -2696,7 +2698,7 @@ class Sign(UnaryScalarOp):
nfunc_spec = ("sign", 1, 1)
@staticmethod
def output_types_preference(x):
def _output_types_preference(x):
if x == bool:
raise TypeError(x)
return same_out_nocomplex(x)
......@@ -2737,7 +2739,7 @@ class Sign(UnaryScalarOp):
return s
sign = Sign(name="sign")
sign = Sign(name="sign", output_types_preference=Sign._output_types_preference)
class Ceil(UnaryScalarOp):
......
......@@ -14,6 +14,7 @@ from pytensor.tensor.basic import atleast_Nd
from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.elemwise import DimShuffle
from pytensor.tensor.rewriting.basic import register_specialize
from pytensor.tensor.rewriting.blockwise import blockwise_of
from pytensor.tensor.rewriting.linalg import is_matrix_transpose
from pytensor.tensor.slinalg import Solve, cho_solve, cholesky, lu_factor, lu_solve
from pytensor.tensor.variable import TensorVariable
......@@ -227,7 +228,7 @@ def _scan_split_non_sequence_decomposition_and_solve(
@register_specialize
@node_rewriter([Blockwise])
@node_rewriter([blockwise_of(Solve)])
def reuse_decomposition_multiple_solves(fgraph, node):
return _split_decomp_and_solve_steps(
fgraph, node, eager=False, allowed_assume_a={"gen", "tridiagonal", "pos"}
......
......@@ -26,10 +26,9 @@ import logging
import numpy as np
import pytensor.scalar.basic as ps
from pytensor import compile, config
from pytensor.compile.ops import ViewOp
from pytensor.graph import FunctionGraph
from pytensor.graph import FunctionGraph, Op
from pytensor.graph.basic import Constant
from pytensor.graph.rewriting.basic import (
NodeProcessingGraphRewriter,
......@@ -40,9 +39,24 @@ from pytensor.graph.rewriting.basic import (
node_rewriter,
)
from pytensor.graph.rewriting.db import RewriteDatabase
from pytensor.graph.rewriting.unify import OpPattern, OpPatternOpTypeType
from pytensor.npy_2_compat import normalize_axis_index
from pytensor.raise_op import Assert, CheckAndRaise, assert_op
from pytensor.scalar.basic import Second
from pytensor.scalar import (
AND,
EQ,
LE,
NEQ,
OR,
XOR,
Add,
BinaryScalarOp,
Cast,
Identity,
Mul,
Second,
Switch,
)
from pytensor.tensor.basic import (
Alloc,
AllocEmpty,
......@@ -225,6 +239,12 @@ def register_uncanonicalize(
return node_rewriter
def elemwise_of(scalar_op: OpPatternOpTypeType | OpPattern) -> OpPattern:
    """Build an `OpPattern` that matches an `Elemwise` by its scalar op.

    ``scalar_op`` may be an `Op` instance, an existing `OpPattern`, or an
    op type / union of types; the latter is wrapped in an `OpPattern` first.
    """
    pattern = (
        scalar_op if isinstance(scalar_op, Op | OpPattern) else OpPattern(scalar_op)
    )
    return OpPattern(Elemwise, scalar_op=pattern)
@register_canonicalize
@register_specialize
@node_rewriter([TensorFromScalar])
......@@ -551,7 +571,7 @@ def local_useless_elemwise(fgraph, node):
dtype = node.outputs[0].type.dtype
scalar_op = node.op.scalar_op
if isinstance(scalar_op, ps.EQ) and len(node.inputs) == 2:
if isinstance(scalar_op, EQ) and len(node.inputs) == 2:
if node.inputs[0] is node.inputs[1]:
# it is the same var in the graph. That will always be true
ret = ones_like(node.inputs[0], dtype=dtype, opt=True)
......@@ -559,7 +579,7 @@ def local_useless_elemwise(fgraph, node):
# Copy stack trace from input to constant output
copy_stack_trace(node.outputs[0], ret)
return [ret]
elif isinstance(scalar_op, ps.NEQ | ps.XOR) and len(node.inputs) == 2:
elif isinstance(scalar_op, NEQ | XOR) and len(node.inputs) == 2:
if node.inputs[0] is node.inputs[1]:
# it is the same var in the graph. That will always be false
ret = zeros_like(node.inputs[0], dtype=dtype, opt=True)
......@@ -568,14 +588,11 @@ def local_useless_elemwise(fgraph, node):
copy_stack_trace(node.outputs[0], ret)
return [ret]
elif (
isinstance(node.op.scalar_op, ps.Mul | ps.Add | ps.Identity)
and len(node.inputs) == 1
):
elif isinstance(node.op.scalar_op, Mul | Add | Identity) and len(node.inputs) == 1:
# No need to copy over any stack trace
return [node.inputs[0]]
elif isinstance(node.op.scalar_op, ps.AND) and len(node.inputs) == 2:
elif isinstance(node.op.scalar_op, AND) and len(node.inputs) == 2:
if (
isinstance(node.inputs[0], TensorConstant)
and node.inputs[1].type.broadcastable == out_bcast
......@@ -602,7 +619,7 @@ def local_useless_elemwise(fgraph, node):
# and this rewrite would be wrong
return [node.inputs[0].astype(node.outputs[0].dtype)]
elif isinstance(node.op.scalar_op, ps.OR) and len(node.inputs) == 2:
elif isinstance(node.op.scalar_op, OR) and len(node.inputs) == 2:
if (
isinstance(node.inputs[0], TensorConstant)
and node.inputs[1].type.broadcastable == out_bcast
......@@ -653,7 +670,7 @@ def local_alloc_unary(fgraph, node):
@register_canonicalize
@register_specialize
@node_rewriter([Elemwise])
@node_rewriter([elemwise_of(Cast)])
def local_cast_cast(fgraph, node):
"""cast(cast(x, dtype1), dtype2)
......@@ -663,13 +680,11 @@ def local_cast_cast(fgraph, node):
and the first cast cause an upcast.
"""
if not (isinstance(node.op, Elemwise) and isinstance(node.op.scalar_op, ps.Cast)):
return
x = node.inputs[0]
if not (
x.owner
and isinstance(x.owner.op, Elemwise)
and isinstance(x.owner.op.scalar_op, ps.Cast)
and isinstance(x.owner.op.scalar_op, Cast)
):
return
......@@ -1009,7 +1024,7 @@ def local_useless_switch(fgraph, node):
node.outputs[0].type.ndim == 0
and cond_var.owner
and isinstance(cond_var.owner.op, Elemwise)
and isinstance(cond_var.owner.op.scalar_op, ps.LE)
and isinstance(cond_var.owner.op.scalar_op, LE)
and cond_var.owner.inputs[0].owner
and isinstance(cond_var.owner.inputs[0].owner.op, Shape_i)
and get_scalar_constant_value(
......@@ -1031,24 +1046,18 @@ def local_useless_switch(fgraph, node):
@register_canonicalize
@node_rewriter([Elemwise])
@node_rewriter([elemwise_of(BinaryScalarOp | Add | Mul)])
def local_merge_switch_same_cond(fgraph, node):
"""
Merge add/sub/mul/div/minimum/maximum/... of switches sharing the same
condition, to enable further simplification of their branches
Example: switch(c, a, b) + switch(c, x, y) -> switch(c, a+x, b+y)
"""
# node must be binary elemwise or add or mul
if not (
isinstance(node.op, Elemwise)
and isinstance(node.op.scalar_op, ps.BinaryScalarOp | ps.Add | ps.Mul)
):
return
# all inputs must be switch
if not all(
s.owner
and isinstance(s.owner.op, Elemwise)
and isinstance(s.owner.op.scalar_op, ps.Switch)
and isinstance(s.owner.op.scalar_op, Switch)
for s in node.inputs
):
return
......@@ -1174,10 +1183,9 @@ register_specialize(topo_constant_folding, "fast_compile", final_rewriter=True)
@register_infer_shape
@register_canonicalize("fast_compile")
@register_useless("fast_compile")
@node_rewriter(None)
@node_rewriter([ViewOp])
def local_view_op(fgraph, node):
if isinstance(node.op, ViewOp):
return node.inputs
return node.inputs
@register_infer_shape
......
from pytensor.compile.mode import optdb
from pytensor.graph import Constant, node_rewriter
from pytensor.graph import Constant, Op, node_rewriter
from pytensor.graph.destroyhandler import inplace_candidates
from pytensor.graph.replace import vectorize_node
from pytensor.graph.rewriting.basic import copy_stack_trace, out2in
from pytensor.graph.rewriting.unify import OpPattern, OpPatternOpTypeType
from pytensor.tensor.basic import Alloc, ARange, alloc, shape_padleft
from pytensor.tensor.blockwise import Blockwise, _squeeze_left
from pytensor.tensor.math import Dot
......@@ -20,6 +21,12 @@ from pytensor.tensor.subtensor import (
)
def blockwise_of(core_op: OpPatternOpTypeType | OpPattern) -> OpPattern:
    """Build an `OpPattern` that matches a `Blockwise` by its core op.

    ``core_op`` may be an `Op` instance, an existing `OpPattern`, or an
    op type / union of types; the latter is wrapped in an `OpPattern` first.
    """
    if isinstance(core_op, Op | OpPattern):
        core_pattern = core_op
    else:
        core_pattern = OpPattern(core_op)
    return OpPattern(Blockwise, core_op=core_pattern)
@node_rewriter([Blockwise])
def local_useless_blockwise(fgraph, node):
"""
......@@ -71,22 +78,24 @@ optdb.register(
@register_canonicalize
@register_stabilize
@register_specialize
@node_rewriter(tracks=[Blockwise])
@node_rewriter(
tracks=[
blockwise_of(
Dot
| Alloc
| ARange
| Subtensor
| AdvancedSubtensor
| AdvancedIncSubtensor
| Reshape
)
]
)
def local_eager_useless_unbatched_blockwise(fgraph, node):
if isinstance(
node.op.core_op,
Dot
| Alloc
| ARange
| Subtensor
| AdvancedSubtensor
| AdvancedIncSubtensor
| Reshape,
):
# Many Dot-related rewrites (eg, all of BlasOpt) happen before specialize
# These other Ops can't always be trivially vectorized at runtime,
# since their inputs may imply non-rectangular shapes.
return local_useless_unbatched_blockwise.fn(fgraph, node)
# Many Dot-related rewrites (eg, all of BlasOpt) happen before specialize
# These other Ops can't always be trivially vectorized at runtime,
# since their inputs may imply non-rectangular shapes.
return local_useless_unbatched_blockwise.fn(fgraph, node)
@register_specialize("shape_unsafe")
......@@ -204,7 +213,7 @@ def local_blockwise_alloc(fgraph, node):
@register_specialize
@node_rewriter([Blockwise])
@node_rewriter([blockwise_of(Reshape)])
def local_blockwise_reshape(fgraph, node):
"""Rewrite away square Blockwise reshapes.
......@@ -215,9 +224,6 @@ def local_blockwise_reshape(fgraph, node):
For the square Reshape case, we must wait for all the intermediate
operations to be lifted as Allocs
"""
if not isinstance(node.op.core_op, Reshape):
return None
x, output_shape = node.inputs
batch_ndim = node.op.batch_ndim(node)
if all(output_shape.type.broadcastable[:batch_ndim]):
......
......@@ -26,6 +26,7 @@ from pytensor.graph.rewriting.basic import (
out2in,
)
from pytensor.graph.rewriting.db import SequenceDB
from pytensor.graph.rewriting.unify import OpPattern
from pytensor.graph.utils import InconsistencyError, MethodNotDefined
from pytensor.scalar.math import Grad2F1Loop, _grad_2f1_loop
from pytensor.tensor.basic import (
......@@ -37,6 +38,7 @@ from pytensor.tensor.math import add, exp, mul
from pytensor.tensor.rewriting.basic import (
alloc_like,
broadcasted_by,
elemwise_of,
register_canonicalize,
register_specialize,
register_stabilize,
......@@ -422,7 +424,14 @@ def local_useless_dimshuffle_makevector(fgraph, node):
@register_canonicalize
@node_rewriter([Elemwise])
@node_rewriter(
[
elemwise_of(
OpPattern(ps.ScalarOp, output_types_preference=ps.upgrade_to_float)
),
elemwise_of(OpPattern(ps.ScalarOp, output_types_preference=ps.upcast_out)),
]
)
def local_upcast_elemwise_constant_inputs(fgraph, node):
"""This explicitly upcasts constant inputs to elemwise Ops, when
those Ops do implicit upcasting anyway.
......@@ -433,12 +442,6 @@ def local_upcast_elemwise_constant_inputs(fgraph, node):
if len(node.outputs) > 1:
return None
if getattr(node.op.scalar_op, "output_types_preference", None) not in (
ps.upgrade_to_float,
ps.upcast_out,
):
return None
# this is the kind of op that we can screw with the input
# dtypes by upcasting explicitly
[old_out] = node.outputs
......@@ -988,13 +991,9 @@ class FusionOptimizer(GraphRewriter):
@register_canonicalize
@register_specialize
@node_rewriter([Elemwise])
@node_rewriter([elemwise_of(ps.Composite)])
def local_useless_composite_outputs(fgraph, node):
"""Remove inputs and outputs of Composite Ops that are not used anywhere."""
if not (
isinstance(node.op, Elemwise) and isinstance(node.op.scalar_op, ps.Composite)
):
return
comp = node.op.scalar_op
used_outputs_idxs = [
i for i, o_extern in enumerate(node.outputs) if fgraph.clients[o_extern]
......@@ -1104,14 +1103,10 @@ def local_careduce_fusion(fgraph, node):
return [new_car_op(*elm_inputs)]
@node_rewriter([Elemwise])
@node_rewriter([elemwise_of(ps.Composite)])
def local_inline_composite_constants(fgraph, node):
"""Inline scalar constants in Composite graphs."""
composite_op = node.op.scalar_op
if not isinstance(composite_op, ps.Composite):
return None
new_outer_inputs = []
new_inner_inputs = []
inner_replacements = {}
......@@ -1287,14 +1282,9 @@ def _rebuild_partial_2f1grad_loop(node, wrt):
@register_specialize
@node_rewriter([Elemwise])
@node_rewriter([elemwise_of(Grad2F1Loop)])
def local_useless_2f1grad_loop(fgraph, node):
# Remove unused terms from the hyp2f1 grad loop
loop_op = node.op.scalar_op
if not isinstance(loop_op, Grad2F1Loop):
return
grad_related_vars = node.outputs[:-4]
# Rewrite was already applied
if len(grad_related_vars) // 3 != 3:
......@@ -1326,18 +1316,13 @@ def local_useless_2f1grad_loop(fgraph, node):
return replacements
@node_rewriter([Elemwise])
@node_rewriter([elemwise_of(Grad2F1Loop)])
def split_2f1grad_loop(fgraph, node):
"""
2f1grad loop has too many operands for Numpy frompyfunc code used by Elemwise nodes on python mode.
This rewrite splits it across 3 different operations. It is not needed if `local_useless_2f1grad_loop` was applied
"""
loop_op = node.op.scalar_op
if not isinstance(loop_op, Grad2F1Loop):
return None
grad_related_vars = node.outputs[:-4]
# local_useless_2f1grad_loop was used, we should be safe
if len(grad_related_vars) // 3 != 3:
......
......@@ -37,7 +37,6 @@ from pytensor.tensor.basic import (
zeros,
zeros_like,
)
from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise
from pytensor.tensor.exceptions import NotScalarConstantError
from pytensor.tensor.extra_ops import broadcast_arrays, concat_with_broadcast
......@@ -49,6 +48,11 @@ from pytensor.tensor.math import (
_dot,
_matmul,
add,
arccosh,
arcsinh,
arctanh,
cosh,
deg2rad,
digamma,
dot,
erf,
......@@ -70,13 +74,16 @@ from pytensor.tensor.math import (
neg,
polygamma,
prod,
rad2deg,
reciprocal,
sigmoid,
sign,
sinh,
softplus,
sqr,
sqrt,
sub,
tanh,
tri_gamma,
true_div,
variadic_add,
......@@ -96,6 +103,7 @@ from pytensor.tensor.rewriting.basic import (
register_uncanonicalize,
register_useless,
)
from pytensor.tensor.rewriting.blockwise import blockwise_of
from pytensor.tensor.rewriting.elemwise import apply_local_dimshuffle_lift
from pytensor.tensor.rewriting.linalg import is_matrix_transpose
from pytensor.tensor.shape import Shape, Shape_i
......@@ -151,7 +159,7 @@ def local_0_dot_x(fgraph, node):
@register_stabilize
@node_rewriter([Blockwise])
@node_rewriter([blockwise_of(BlockDiagonal)])
def local_block_diag_dot_to_dot_block_diag(fgraph, node):
r"""
Perform the rewrite ``dot(block_diag(A, B), C) -> concat(dot(A, C), dot(B, C))``
......@@ -160,9 +168,6 @@ def local_block_diag_dot_to_dot_block_diag(fgraph, node):
of approximately O(n^3), it's always better to perform two dot products on the smaller matrices, rather than
a single dot on the larger matrix.
"""
if not isinstance(node.op.core_op, BlockDiagonal):
return
# Check that the BlockDiagonal is an input to a Dot node:
for client in itertools.chain.from_iterable(
get_clients_at_depth(fgraph, node, depth=i) for i in [1, 2]
......@@ -424,60 +429,30 @@ def local_dot_to_mul(fgraph, node):
return [new_out]
def is_inverse_pair(node_op, prev_op, inv_pair):
    """Check whether two consecutive operations form the given inverse pair.

    Returns True when one of ``node_op``/``prev_op`` is an instance of
    ``inv_pair[0]`` and the other an instance of ``inv_pair[1]``, in either
    order (the pair may also be a self-inverse, e.g. ``(Neg, Neg)``).
    """
    fwd, inv = inv_pair
    forward_then_inverse = isinstance(node_op, fwd) and isinstance(prev_op, inv)
    inverse_then_forward = isinstance(node_op, inv) and isinstance(prev_op, fwd)
    return forward_then_inverse or inverse_then_forward
@register_canonicalize
@register_specialize
@node_rewriter([Elemwise])
def local_func_inv(fgraph, node):
"""
Check for two consecutive operations that are functional inverses
and remove them from the function graph.
"""
inv_pairs = (
(ps.Deg2Rad, ps.Rad2Deg),
(ps.Cosh, ps.ArcCosh),
(ps.Tanh, ps.ArcTanh),
(ps.Sinh, ps.ArcSinh),
(ps.Conj, ps.Conj),
(ps.Neg, ps.Neg),
(ps.Reciprocal, ps.Reciprocal),
)
x = node.inputs[0]
if not isinstance(node.op, Elemwise):
return
if not (x.owner and isinstance(x.owner.op, Elemwise)):
return
prev_op = x.owner.op.scalar_op
node_op = node.op.scalar_op
for inv_pair in inv_pairs:
if is_inverse_pair(node_op, prev_op, inv_pair):
# We don't need to copy stack trace, because the rewrite
# is trivial and maintains the earlier stack trace
ottype = node.out.dtype
inp = x.owner.inputs[0]
# Functions may have casted integer input to float
if inp.dtype != ottype:
inp = cast(inp, ottype)
return [inp]
for pair in (
(deg2rad, rad2deg),
(cosh, arccosh),
(tanh, arctanh),
(sinh, arcsinh),
(_conj, _conj),
(neg, neg),
(reciprocal, reciprocal),
):
# Create a simple PatternNodeRewriter for each pair of opposite ops
# instead of a general Op that is called to often for very few hits
for op, inv_op in (pair, reversed(pair)):
rewrite = PatternNodeRewriter(
(op, (inv_op, "x")),
"x",
allow_multiple_clients=True,
allow_cast=True,
name=f"useless_{op}_of_{inv_op}",
)
register_canonicalize(rewrite)
register_specialize(rewrite)
return
if op is inv_op:
break # Same Op, no need to define two rewrites
@register_canonicalize
......
......@@ -35,7 +35,7 @@ from pytensor.tensor.basic import (
switch,
)
from pytensor.tensor.basic import constant as tensor_constant
from pytensor.tensor.blockwise import Blockwise, _squeeze_left
from pytensor.tensor.blockwise import _squeeze_left
from pytensor.tensor.elemwise import Elemwise
from pytensor.tensor.exceptions import NotScalarConstantError
from pytensor.tensor.extra_ops import broadcast_to
......@@ -58,6 +58,7 @@ from pytensor.tensor.rewriting.basic import (
register_specialize,
register_stabilize,
)
from pytensor.tensor.rewriting.blockwise import blockwise_of
from pytensor.tensor.shape import (
shape_padleft,
shape_padright,
......@@ -974,33 +975,30 @@ def local_IncSubtensor_serialize(fgraph, node):
and not i.owner.op.set_instead_of_inc
)
if node.op == add:
o_type = node.outputs[0].type
o_type = node.outputs[0].type
movable_inputs = [i for i in node.inputs if movable(i)]
movable_inputs = [i for i in node.inputs if movable(i)]
if movable_inputs:
new_inputs = [i for i in node.inputs if not movable(i)] + [
mi.owner.inputs[0] for mi in movable_inputs
]
new_add = variadic_add(*new_inputs)
# Copy over stacktrace from original output, as an error
# (e.g. an index error) in this add operation should
# correspond to an error in the original add operation.
copy_stack_trace(node.outputs[0], new_add)
# stack up the new incsubtensors
tip = new_add
for mi in movable_inputs:
assert o_type.is_super(tip.type)
tip = mi.owner.op(tip, *mi.owner.inputs[1:])
# Copy over stacktrace from outputs of the original
# "movable" operation to the new operation.
copy_stack_trace(node.outputs + mi.owner.outputs, tip)
if movable_inputs:
new_inputs = [i for i in node.inputs if not movable(i)] + [
mi.owner.inputs[0] for mi in movable_inputs
]
new_add = variadic_add(*new_inputs)
# Copy over stacktrace from original output, as an error
# (e.g. an index error) in this add operation should
# correspond to an error in the original add operation.
copy_stack_trace(node.outputs[0], new_add)
return [tip]
# stack up the new incsubtensors
tip = new_add
for mi in movable_inputs:
assert o_type.is_super(tip.type)
tip = mi.owner.op(tip, *mi.owner.inputs[1:])
# Copy over stacktrace from outputs of the original
# "movable" operation to the new operation.
copy_stack_trace(node.outputs + mi.owner.outputs, tip)
# print incsub_inputs, [id(i.owner.inputs[0]) for i in incsub_inputs]
return [tip]
# We register it in a WalkingGraphRewriter inside the canonizer EQ optimizer.
......@@ -1576,7 +1574,7 @@ compile.optdb.register(
@register_stabilize
@register_specialize
@node_rewriter([Blockwise])
@node_rewriter([blockwise_of(Subtensor)])
def local_blockwise_of_subtensor(fgraph, node):
"""Rewrite Blockwise of Subtensor, where the only batch input is the indexed tensor.
......@@ -1585,9 +1583,6 @@ def local_blockwise_of_subtensor(fgraph, node):
TODO: Handle batched indices like we do with blockwise of inc_subtensor
TODO: Extend to AdvanceSubtensor
"""
if not isinstance(node.op.core_op, Subtensor):
return
x, *idxs = node.inputs
if not all(all(idx.type.broadcastable) for idx in idxs):
return
......@@ -1603,7 +1598,7 @@ def local_blockwise_of_subtensor(fgraph, node):
@register_canonicalize("shape_unsafe")
@register_stabilize("shape_unsafe")
@register_specialize("shape_unsafe")
@node_rewriter([Blockwise])
@node_rewriter([blockwise_of(IncSubtensor | AdvancedIncSubtensor)])
def local_blockwise_inc_subtensor(fgraph, node):
"""Rewrite blockwised inc_subtensors.
......@@ -1614,12 +1609,9 @@ def local_blockwise_inc_subtensor(fgraph, node):
and can be safely rewritten without Blockwise.
"""
core_op = node.op.core_op
if not isinstance(core_op, AdvancedIncSubtensor | IncSubtensor):
return None
x, y, *idxs = node.inputs
[out] = node.outputs
if isinstance(node.op.core_op, AdvancedIncSubtensor):
if isinstance(core_op, AdvancedIncSubtensor):
if any(
(
# Blockwise requires all inputs to be tensors so it is not possible
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论