Commit f25a624a authored by Jesse Grabowski, committed by Ricardo Vieira
Parent 23427a0a
@@ -4,6 +4,7 @@ from pytensor.link.jax.dispatch.basic import jax_funcify, jax_typify

# Load dispatch specializations
import pytensor.link.jax.dispatch.blas
import pytensor.link.jax.dispatch.blockwise
import pytensor.link.jax.dispatch.einsum
import pytensor.link.jax.dispatch.elemwise
import pytensor.link.jax.dispatch.extra_ops
import pytensor.link.jax.dispatch.pad
import jax.numpy as jnp
from pytensor.link.jax.dispatch import jax_funcify
from pytensor.tensor.einsum import Einsum
@jax_funcify.register(Einsum)
def jax_funcify_Einsum(op, **kwargs):
"""Dispatch einsum to JAX.
This dispatch is triggered only when we couldn't optimize einsum at the PyTensor level.
This happens when some of the dimension lengths are unknown. This is never a problem in JAX,
as it always compiles a function per runtime input shape.
"""
subscripts = op.subscripts
def einsum(*operands):
return jnp.einsum(subscripts, *operands, optimize="optimal")
return einsum
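# Illustrative sketch (not part of the committed file, assumes JAX is installed): with
# three operands of unknown static shape the Einsum Op is not optimized or inlined, so
# it survives rewriting and is handled by the dispatch registered above.
import numpy as np
import pytensor
import pytensor.tensor as pt

x = pt.tensor("x", shape=(None, None))
y = pt.tensor("y", shape=(None, None))
z = pt.tensor("z", shape=(None, None))
f = pytensor.function([x, y, z], pt.einsum("ij,jk,kl->il", x, y, z), mode="JAX")
f(np.ones((2, 3)), np.ones((3, 4)), np.ones((4, 5)))  # delegates to jnp.einsum at runtime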
@@ -151,6 +151,7 @@ from pytensor.tensor.variable import TensorConstant, TensorVariable

# isort: off
from pytensor.tensor.einsum import einsum
from pytensor.tensor.functional import vectorize
# isort: on
@@ -1700,21 +1700,22 @@ class Alloc(COp):
            return False
        for client, idx in clients:
            client_op = client.op
            if isinstance(client_op, Output):
                # If the output is a constant, it will have to be deepcopied
                # each time the function is called. So we do not fold.
                return False
            # Op's through which Alloc can be lifted
            elif isinstance(client_op, Elemwise | DimShuffle | Alloc | Join):
                return False
            # Same for Blockwise, unless it has no batch_dims
            elif isinstance(client_op, Blockwise) and client.op.batch_ndim(client):
                return False
            elif (
                # The following ops work inplace of their input id 0.
                idx == 0
                and isinstance(
                    client_op,
                    pytensor.tensor.subtensor.IncSubtensor
                    | pytensor.tensor.subtensor.AdvancedIncSubtensor1
                    | pytensor.tensor.subtensor.AdvancedIncSubtensor
@@ -2035,10 +2036,15 @@ def transpose(x, axes=None):
    _x = as_tensor_variable(x)

    if axes is None:
        axes = tuple(range((_x.type.ndim - 1), -1, -1))

    if tuple(axes) == tuple(range(len(axes))):
        # No-op
        return _x

    ret = DimShuffle(tuple(s == 1 for s in _x.type.shape), axes)(_x)

    if _x.name and axes == tuple(range((_x.type.ndim - 1), -1, -1)):
        ret.name = _x.name + ".T"

    return ret
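# Illustrative sketch (not part of the diff): with the new shortcut an identity
# permutation returns the input variable itself instead of building a DimShuffle.
import pytensor.tensor as pt

x = pt.tensor3("x")
print(pt.transpose(x, axes=(0, 1, 2)) is x)  # expected: True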
@@ -3950,6 +3956,10 @@ def moveaxis(
    source = normalize_axis_tuple(source, a.ndim, "source")
    destination = normalize_axis_tuple(destination, a.ndim, "destination")

    if source == destination:
        # It's a no-op
        return a

    if len(source) != len(destination):
        raise ValueError(
            "`source` and `destination` arguments must have the same number of elements"
@@ -4260,9 +4270,7 @@ atleast_2d = partial(atleast_Nd, n=2)
atleast_3d = partial(atleast_Nd, n=3)


def expand_dims(a: np.ndarray | TensorVariable, axis: Sequence[int]) -> TensorVariable:
    """Expand the shape of an array.

    Insert a new axis that will appear at the `axis` position in the expanded
@@ -4281,7 +4289,7 @@ def expand_dims(
    """
    a = as_tensor(a)

    if not isinstance(axis, Sequence):
        axis = (axis,)

    out_ndim = len(axis) + a.ndim
import collections
import warnings
from collections.abc import Sequence
from functools import partial, reduce
from itertools import pairwise
from typing import cast
import numpy as np
from numpy.core.einsumfunc import _find_contraction, _parse_einsum_input # type: ignore
from numpy.core.numeric import ( # type: ignore
normalize_axis_index,
normalize_axis_tuple,
)
from pytensor.compile.builders import OpFromGraph
from pytensor.tensor import TensorLike
from pytensor.tensor.basic import (
arange,
as_tensor,
expand_dims,
get_vector_length,
moveaxis,
stack,
transpose,
where,
)
from pytensor.tensor.extra_ops import broadcast_to
from pytensor.tensor.functional import vectorize
from pytensor.tensor.math import and_, eq, tensordot
from pytensor.tensor.shape import shape_padright
from pytensor.tensor.variable import TensorVariable
PATH = tuple[tuple[int] | tuple[int, int], ...]
class Einsum(OpFromGraph):
"""
Wrapper Op for Einsum graphs
Notes
-----
The `optimized` prop indicates whether the inner graph was optimized, which can only be done when all shapes are
statically known. This is now determined at graph creation time only. We could introduce a rewrite that tries to
optimize the graph if static shapes become known later (e.g., after use of `clone_replace` or shape inference during
rewrites).
Also, once the graph is optimized, it could be inlined for potential further optimization that considers the rest of
the graph.
This prop is different from the `optimize` kwarg in numpy that determines what kind (if any) of optimization is
desired. We haven't decided whether we want to provide this functionality.
"""
__props__ = ("subscripts", "path", "optimized")
def __init__(self, *args, subscripts: str, path: PATH, optimized: bool, **kwargs):
self.subscripts = subscripts
self.path = path
self.optimized = optimized
super().__init__(*args, **kwargs, strict=True)
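# Illustrative sketch (not part of the committed file): the `optimized` flag depends on
# whether all operand shapes are statically known when the einsum graph is built (with
# three or more operands; with one or two there is no path to optimize).
import pytensor.tensor as pt

a = pt.tensor("a", shape=(3, 4))
b = pt.tensor("b", shape=(4, 5))
c = pt.tensor("c", shape=(5, 6))
print(pt.einsum("ij,jk,kl->il", a, b, c).owner.op.optimized)  # expected: True

d = pt.tensor("d", shape=(None, 4))
print(pt.einsum("ij,jk,kl->il", d, b, c).owner.op.optimized)  # expected: False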
def _iota(shape: TensorVariable, axis: int) -> TensorVariable:
"""
Create an array with values increasing along the specified axis.
Iota is a multidimensional generalization of the `arange` function. The returned array is filled with whole numbers
increasing along the specified axis.
Parameters
----------
shape: TensorVariable
The shape of the array to be created.
axis: int
The axis along which to fill the array with increasing values.
Returns
-------
TensorVariable
An array with values increasing along the specified axis.
Examples
--------
In the simplest case where ``shape`` is 1d, the output will be equivalent to ``pt.arange``:
.. testcode::
import pytensor.tensor as pt
from pytensor.tensor.einsum import _iota
shape = pt.as_tensor((5,))
print(_iota(shape, 0).eval())
.. testoutput::
[0 1 2 3 4]
In higher dimensions, it will look like many concatenated `arange`:
.. testcode::
shape = pt.as_tensor((5, 5))
print(_iota(shape, 1).eval())
.. testoutput::
[[0 1 2 3 4]
[0 1 2 3 4]
[0 1 2 3 4]
[0 1 2 3 4]
[0 1 2 3 4]]
Setting ``axis=0`` above would result in the transpose of the output.
"""
len_shape = get_vector_length(shape)
axis = normalize_axis_index(axis, len_shape)
values = arange(shape[axis])
return broadcast_to(shape_padright(values, len_shape - axis - 1), shape)
def _delta(shape: TensorVariable, axes: Sequence[int]) -> TensorVariable:
"""
Create a Kronecker delta tensor.
The Kronecker delta function is defined:
.. math::

    \delta(i, j) = \begin{cases} 1 & \text{if } i = j \\ 0 & \text{otherwise} \end{cases}
To create a Kronecker tensor, the delta function is applied elementwise to the axes specified. The result is a
tensor of booleans, with ``True`` where the axis indices coincide, and ``False`` otherwise. See below for examples.
Parameters
----------
shape: TensorVariable
The shape of the tensor to be created. Note that `_delta` is not defined for 1d tensors, because there is no
second axis against which to compare.
axes: sequence of int
Axes whose indices should be compared. Note that `_delta` is not defined for a single axis, because there is no
second axis against which to compare.
Examples
--------
An easy case to understand is when the shape is square and the number of axes is equal to the number of dimensions.
This will result in a generalized identity tensor, with ``True`` along the main diagonal:
.. testcode::
from pytensor.tensor.einsum import _delta
print(_delta((5, 5), (0, 1)).eval())
.. testoutput::
[[ True False False False False]
[False True False False False]
[False False True False False]
[False False False True False]
[False False False False True]]
In the case where the shape is not square, the result will be a tensor with ``True`` along the main diagonal and
``False`` elsewhere:
.. testcode::
from pytensor.tensor.einsum import _delta
print(_delta((3, 2), (0, 1)).eval())
.. testoutput::
[[ True False]
[False True]
[False False]]
When there are more than two dimensions in the shape, axes can be only a subset of them, leading to different
arrangements of True and False values. For example, for a 3d batch of matrices, choosing axes (0, 2) will lead to
True values on the column corresponding to the batch index of each matrix:
.. testcode::
from pytensor.tensor.einsum import _delta
print(_delta((3, 3, 3), (0, 2)).eval())
.. testoutput::
[[[ True False False]
[ True False False]
[ True False False]]
[[False True False]
[False True False]
[False True False]]
[[False False True]
[False False True]
[False False True]]]
"""
if len(axes) == 1:
raise ValueError("Need at least two axes to create a delta tensor")
base_shape = stack([shape[axis] for axis in axes])
iotas = [_iota(base_shape, i) for i in range(len(axes))]
eyes = [eq(i1, i2) for i1, i2 in pairwise(iotas)]
result = reduce(and_, eyes)
non_axes = [i for i in range(len(tuple(shape))) if i not in axes]
return broadcast_to(expand_dims(result, non_axes), shape)
def _general_dot(
vars: tuple[TensorVariable, TensorVariable],
axes: Sequence[Sequence[int]], # Should be length 2,
batch_axes: Sequence[Sequence[int]], # Should be length 2,
) -> TensorVariable:
"""
Generalized dot product between two tensors.
Ultimately ``_general_dot`` is a call to `tensordot`, performing a multiply-and-sum ("dot") operation between two
tensors, along a requested dimension. This function further generalizes this operation by allowing arbitrary
batch dimensions to be specified for each tensor.
Parameters
----------
vars: tuple[TensorVariable, TensorVariable]
The tensors to be ``tensordot``-ed
axes: Sequence[Sequence[int]]
The axes along which to perform the dot product. Should be a sequence of two sequences, one for each tensor.
batch_axes: Sequence[Sequence[int]]
The batch axes for each tensor. Should be a sequence of two sequences, one for each tensor.
Returns
-------
TensorVariable
The result of the ``tensordot`` product.
Examples
--------
Perform a batched dot product between two 3d tensors:
.. testcode::
import pytensor.tensor as pt
from pytensor.tensor.einsum import _general_dot
import numpy as np
A = pt.tensor(shape=(3, 4, 5))
B = pt.tensor(shape=(3, 5, 2))
result = _general_dot((A, B), axes=[[2], [1]], batch_axes=[[0], [0]])
A_val = np.empty((3, 4, 5))
B_val = np.empty((3, 5, 2))
print(tuple(result.shape.eval({A:A_val, B:B_val})))
.. testoutput::
(3, 4, 2)
"""
# Shortcut for non batched case
if not batch_axes[0] and not batch_axes[1]:
return tensordot(*vars, axes=axes)
# Normalize axes, thankfully numpy helper does not sort axis!
axes = [
normalize_axis_tuple(var_axes, var.ndim)
for var, var_axes in zip(vars, axes, strict=True)
]
batch_axes = [
normalize_axis_tuple(var_axes, var.ndim)
for var, var_axes in zip(vars, batch_axes, strict=True)
]
n_batch_axes = [len(var_batch_axes) for var_batch_axes in batch_axes]
# Move batch axes to the left and recode reduction axes
new_vars = list(vars)
new_axes = list(axes)
for i, (var, var_axes, var_batch_axes, var_n_batch_axes) in enumerate(
zip(vars, axes, batch_axes, n_batch_axes, strict=True)
):
if var_batch_axes == tuple(range(var_n_batch_axes)):
# Already on left to right order
continue
new_var_batch_axes = tuple(range(var_n_batch_axes))
new_var = moveaxis(var, var_batch_axes, new_var_batch_axes)
new_var_axes = []
for var_axis in var_axes:
batch_axes_to_the_right = len(
[batch_axis for batch_axis in var_batch_axes if batch_axis > var_axis]
)
new_var_axes.append(var_axis + batch_axes_to_the_right)
new_vars[i] = new_var
new_axes[i] = new_var_axes
lhs, rhs = new_vars
lhs_axes, rhs_axes = new_axes
lhs_n_batch_axes, rhs_n_batch_axes = n_batch_axes
# Create signature of tensordot
lhs_signature = [f"l{i}" for i in range(lhs.type.ndim)]
rhs_signature = [f"r{i}" for i in range(rhs.type.ndim)]
# Aligned axes get the same dimension name
for i, (lhs_axis, rhs_axis) in enumerate(zip(lhs_axes, rhs_axes)):
lhs_signature[lhs_axis] = rhs_signature[rhs_axis] = f"a{i}"
# Trim away the batch ndims
lhs_signature = lhs_signature[lhs_n_batch_axes:]
rhs_signature = rhs_signature[rhs_n_batch_axes:]
out_signature = [
lhs_dim for lhs_dim in lhs_signature if not lhs_dim.startswith("a")
] + [rhs_dim for rhs_dim in rhs_signature if not rhs_dim.startswith("a")]
signature = f"({','.join(lhs_signature)}),({','.join(rhs_signature)})->({','.join(out_signature)})"
# Adjust axes for core case
core_lhs_axes = tuple(np.array(lhs_axes) - lhs_n_batch_axes)
core_rhs_axes = tuple(np.array(rhs_axes) - rhs_n_batch_axes)
if signature == "(),()->()":
# Just a multiplication
out = lhs * rhs
else:
out = vectorize(
partial(tensordot, axes=[core_lhs_axes, core_rhs_axes]), signature=signature
)(lhs, rhs)
return cast(TensorVariable, out)
def _contraction_list_from_path(
subscripts: str, operands: Sequence[TensorVariable], path: PATH
):
"""
Generate a list of contraction steps based on the provided einsum path.
Code adapted from opt_einsum: https://github.com/dgasmith/opt_einsum/blob/94c62a05d5ebcedd30f59c90b9926de967ed10b5/opt_einsum/contract.py#L369
When all shapes are known, the linked opt_einsum implementation is preferred. This implementation is used when
some or all shapes are not known. As a result, contraction will (always?) be done left-to-right, pushing intermediate
results to the end of the stack.
Parameters
----------
subscripts: str
Einsum signature string describing the computation to be performed.
operands: Sequence[TensorLike]
Tensors described by the subscripts.
path: tuple[tuple[int] | tuple[int, int]]
A list of tuples, where each tuple describes the indices of the operands to be contracted, sorted in the order
they should be contracted.
Returns
-------
contraction_list: list
A list of tuples, where each tuple describes a contraction step. Each tuple contains the following elements:
- contraction_inds: tuple[int]
The indices of the operands to be contracted
- idx_removed: str
The indices contracted and removed from the einsum string at this step
- einsum_str: str
The einsum string for the contraction step
- remaining: None
The remaining indices. Included to match the output of opt_einsum.contract_path, but not used.
- do_blas: None
Whether to use blas to perform this step. Included to match the output of opt_einsum.contract_path,
but not used.
"""
fake_operands = [
np.zeros([1 if dim == 1 else 0 for dim in x.type.shape]) for x in operands
]
input_subscripts, output_subscript, operands = _parse_einsum_input(
(subscripts, *fake_operands)
)
# Build a few useful list and sets
input_list = input_subscripts.split(",")
input_sets = [set(x) for x in input_list]
output_set = set(output_subscript)
# Build contraction tuple (positions, gemm, einsum_str, remaining)
contraction_list = []
for cnum, contract_inds in enumerate(path):
# Make sure we remove inds from right to left
contract_inds = cast(
tuple[int] | tuple[int, int], tuple(sorted(contract_inds, reverse=True))
)
contract_tuple = _find_contraction(contract_inds, input_sets, output_set)
out_inds, input_sets, idx_removed, idx_contract = contract_tuple
tmp_inputs = [input_list.pop(x) for x in contract_inds]
# Last contraction
if (cnum - len(path)) == -1:
idx_result = output_subscript
else:
# use tensordot order to minimize transpositions
all_input_inds = "".join(tmp_inputs)
idx_result = "".join(sorted(out_inds, key=all_input_inds.find))
input_list.append(idx_result)
einsum_str = ",".join(tmp_inputs) + "->" + idx_result
# We only need the first three inputs to build the forward graph
contraction = (contract_inds, idx_removed, einsum_str, None, None)
contraction_list.append(contraction)
return contraction_list
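# Illustrative sketch (not part of the committed file): inspecting the steps produced for
# a simple two-operand contraction with the default (1, 0) path; the exact output shown
# in the comment is indicative only.
import pytensor.tensor as pt
from pytensor.tensor.einsum import _contraction_list_from_path

A = pt.tensor("A", shape=(None, None))
B = pt.tensor("B", shape=(None, None))
for contract_inds, idx_removed, einsum_str, _, _ in _contraction_list_from_path(
    "ij,jk->ik", [A, B], path=((1, 0),)
):
    print(contract_inds, sorted(idx_removed), einsum_str)
# expected along the lines of: (1, 0) ['j'] jk,ij->ik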
def einsum(subscripts: str, *operands: "TensorLike", optimize=None) -> TensorVariable:
"""
Multiplication and summation of tensors using the Einstein summation convention.
Code adapted from JAX: https://github.com/google/jax/blob/534d32a24d7e1efdef206188bb11ae48e9097092/jax/_src/numpy/lax_numpy.py#L5283
Einsum allows the user to specify a wide range of operations on tensors using the Einstein summation convention. Using
this notation, many common linear algebraic operations can be succinctly described on higher order tensors.
Parameters
----------
subscripts: str
Einsum signature string describing the computation to be performed.
operands: sequence of TensorVariable
Tensors to be multiplied and summed.
Returns
-------
TensorVariable
The result of the einsum operation.
See Also
--------
pytensor.tensor.tensordot: Generalized dot product between two tensors
pytensor.tensor.dot: Matrix multiplication between two tensors
numpy.einsum: The numpy implementation of einsum
Examples
--------
Inputs to `pt.einsum` are a string describing the operation to be performed (the "subscripts"), and a sequence of
tensors to be operated on. The string must obey the following rules:
1. The string gives inputs and (optionally) outputs. Inputs and outputs are separated by "->".
2. The input side of the string is a comma-separated list of indices. For each comma-separated index string, there
must be a corresponding tensor in the input sequence.
3. For each index string, the number of dimensions in the corresponding tensor must match the number of characters
in the index string.
4. Indices are arbitrary strings of characters. If an index appears multiple times in the input side, it must have
the same shape in each input.
5. The indices on the output side must be a subset of the indices on the input side -- you cannot introduce new
indices in the output.
6. Ellipses ("...") can be used to elide multiple indices. This is useful when you have a large number of "batch"
dimensions that are not implicated in the operation.
Finally, two rules about these indices govern how computation is carried out:
1. Repeated indices on the input side indicate how the tensor should be "aligned" for multiplication.
2. Indices that appear on the input side but not the output side are summed over.
The operation of these rules is best understood via examples:
Example 1: Matrix multiplication
.. code-block:: python
import pytensor.tensor as pt
A = pt.matrix("A")
B = pt.matrix("B")
C = pt.einsum("ij, jk -> ik", A, B)
This computation is equivalent to :code:`C = A @ B`. Notice that the ``j`` index is repeated on the input side of the
signature, and does not appear on the output side. This indicates that the ``j`` dimension of the first tensor should be
multiplied with the ``j`` dimension of the second tensor, and the resulting tensor's ``j`` dimension should be summed
away.
Example 2: Batched matrix multiplication
.. code-block:: python
import pytensor.tensor as pt
A = pt.tensor("A", shape=(None, 4, 5))
B = pt.tensor("B", shape=(None, 5, 6))
C = pt.einsum("bij, bjk -> bik", A, B)
This computation is also equivalent to :code:`C = A @ B` because of PyTensor's built-in broadcasting rules, but
the einsum signature is more explicit about the batch dimensions. The ``b`` and ``j`` indices are repeated on the
input side. Unlike ``j``, the ``b`` index is also present on the output side, indicating that the batch dimension
should **not** be summed away. As a result, multiplication will be performed over the ``b, j`` dimensions, and then
the ``j`` dimension will be summed over. The resulting tensor will have shape ``(None, 4, 6)``.
Example 3: Batched matrix multiplication with ellipses
.. code-block:: python
import pytensor.tensor as pt
A = pt.tensor("A", shape=(4, None, None, None, 5))
B = pt.tensor("B", shape=(5, None, None, None, 6))
C = pt.einsum("i...j, j...k -> ...ik", A, B)
This case is the same as above, but inputs ``A`` and ``B`` have multiple batch dimensions. To avoid writing out all
of the batch dimensions (which we do not care about), we can use ellipses to elide over these dimensions. Notice
also that we are not required to "sort" the input dimensions in any way. In this example, we are doing a dot
between the last dimension of A and the first dimension of B, which is perfectly valid.
Example 4: Outer product
.. code-block:: python
import pytensor.tensor as pt
x = pt.tensor("x", shape=(3,))
y = pt.tensor("y", shape=(4,))
z = pt.einsum("i, j -> ij", x, y)
This computation is equivalent to :code:`pt.outer(x, y)`. Notice that no indices are repeated on the input side,
and the output side has two indices. Since there are no indices to align on, the einsum operation will simply
multiply the two tensors elementwise, broadcasting dimensions ``i`` and ``j``.
Example 5: Convolution
.. code-block:: python
import pytensor.tensor as pt
x = pt.tensor("x", shape=(None, None, None, None, None, None))
w = pt.tensor("w", shape=(None, None, None, None))
y = pt.einsum("bchwkt,fckt->bfhw", x, w)
Given a batch of image patches ``x`` with dimensions ``(batch, channel, height, width, kernel_height, kernel_width)``
and a filter bank ``w`` with dimensions ``(num_filters, channel, kernel_height, kernel_width)``, this einsum
operation computes the convolution of ``x`` with ``w``. Multiplication is aligned on the channel and the two kernel
dimensions, which are then summed away. The resulting tensor has shape ``(batch, num_filters, height, width)``,
reflecting the fact that information from each channel and kernel position has been mixed together.
"""
if optimize is not None:
raise NotImplementedError(
"Optimize kwarg is not implemented in PyTensor. "
"By default, PyTensor will always optimize the graph if the inputs have static shapes.\n"
"If you need this functionality open an issue in https://github.com/pymc-devs/pytensor/issues to let us know. "
)
# TODO: Is this doing something clever about unknown shapes?
# contract_path = _poly_einsum_handlers.get(ty, _default_poly_einsum_handler)
tensor_operands = [as_tensor(operand) for operand in operands]
shapes = [operand.type.shape for operand in tensor_operands]
path: PATH
if any(None in shape for shape in shapes):
# Case 1: At least one of the operands has an unknown shape. In this case, we can't use opt_einsum to optimize
# the contraction order, so we just use a default path of (1,0) contractions. This will work left-to-right,
# pushing intermediate results to the end of the stack.
# We use (1,0) and not (0,1) because that's what opt_einsum tends to prefer, and so the Op signatures will
# match more often
# If shapes become known later we will likely want to rebuild the Op (unless we inline it)
if len(tensor_operands) == 1:
path = ((0,),)
else:
# By default, we try right to left because we assume that most graphs
# have a lower dimensional rightmost operand
path = tuple(pairwise(reversed(range(len(tensor_operands)))))
contraction_list = _contraction_list_from_path(
subscripts, tensor_operands, path
)
# If there are only 1 or 2 operands, there is no optimization to be done?
optimized = len(tensor_operands) <= 2
else:
# Case 2: All operands have known shapes. In this case, we can use opt_einsum to compute the optimal
# contraction order.
_, contraction_list = np.einsum_path(
subscripts,
# Numpy einsum_path requires arrays even though only the shapes matter
# It's not trivial to duck-type our way around because of internal call to `asanyarray`
*[np.empty(shape) for shape in shapes],
einsum_call=True, # Not part of public API
optimize="optimal",
) # type: ignore
path = tuple(contraction[0] for contraction in contraction_list)
optimized = True
def removechars(s, chars):
return s.translate(str.maketrans(dict.fromkeys(chars)))
def sum_uniques(
operand: TensorVariable, names: str, uniques: list[str]
) -> tuple[TensorVariable, str]:
"""Reduce unique indices (those that appear only once) in a given contraction step via summing."""
if uniques:
axes = [names.index(name) for name in uniques]
operand = operand.sum(axes)
names = removechars(names, uniques)
return operand, names
def sum_repeats(
operand: TensorVariable,
names: str,
counts: collections.Counter,
keep_names: str,
) -> tuple[TensorVariable, str]:
"""Reduce repeated indices in a given contraction step via summation against an identity matrix."""
for name, count in counts.items():
if count > 1:
axes = [i for i, n in enumerate(names) if n == name]
eye = _delta(operand.shape, axes)
operand = where(eye, operand, operand.zeros_like())
if name not in keep_names:
operand = operand.sum(axes)
names = names.replace(name, "")
else:
operand = operand.sum(axes[:-1])
names = names.replace(name, "", count - 1)
return operand, names
def filter_singleton_dims(operand, names, other_operand, other_names):
op_bcast = operand.type.broadcastable
other_bcast = other_operand.type.broadcastable
keep = [
(not op_bcast[i]) or (j == -1) or other_bcast[j]
for i, j in enumerate(map(other_names.find, names))
]
keep_axes = [i for i, keep_axis in enumerate(keep) if keep_axis]
squeeze_axes = [i for i, keep_axis in enumerate(keep) if not keep_axis]
if squeeze_axes:
# TODO: We could modify the subscripts to avoid the problem?
warnings.warn(
"The same einsum subscript is used for a broadcastable and non-broadcastable dimension. "
"This can result in a suboptimal contraction path."
)
return operand.squeeze(squeeze_axes), "".join(names[i] for i in keep_axes)
einsum_operands = list(tensor_operands) # So we can pop
for operand_indices, contracted_names, einstr, _, _ in contraction_list:
contracted_names = sorted(contracted_names)
assert len(contracted_names) == len(
set(contracted_names)
), "The set was needed!"
input_str, result_names = einstr.split("->")
input_names = input_str.split(",")
# switch on the number of operands to be processed in this loop iteration.
# every case here sets 'operand' and 'names'.
if len(operand_indices) == 1:
operand = einsum_operands.pop(operand_indices[0])
(names,) = input_names
counts = collections.Counter(names)
# sum out unique contracted indices with a single reduce-sum
uniques = [name for name in contracted_names if counts[name] == 1]
operand, names = sum_uniques(operand, names, uniques)
# for every repeated index, do a contraction against an identity matrix
operand, names = sum_repeats(operand, names, counts, result_names)
elif len(operand_indices) == 2:
lhs, rhs = map(einsum_operands.pop, operand_indices)
lhs_names, rhs_names = input_names
# handle cases where one side of a contracting or batch dimension is 1
# but its counterpart is not.
lhs, lhs_names = filter_singleton_dims(lhs, lhs_names, rhs, rhs_names)
rhs, rhs_names = filter_singleton_dims(rhs, rhs_names, lhs, lhs_names)
lhs_counts = collections.Counter(lhs_names)
rhs_counts = collections.Counter(rhs_names)
# sum out unique contracted indices in lhs and rhs
lhs_uniques = [
name
for name in contracted_names
if lhs_counts[name] == 1 and rhs_counts[name] == 0
]
lhs, lhs_names = sum_uniques(lhs, lhs_names, lhs_uniques)
rhs_uniques = [
name
for name in contracted_names
if rhs_counts[name] == 1 and lhs_counts[name] == 0
]
rhs, rhs_names = sum_uniques(rhs, rhs_names, rhs_uniques)
# for every repeated index, contract against an identity matrix
lhs, lhs_names = sum_repeats(
lhs, lhs_names, lhs_counts, result_names + rhs_names
)
rhs, rhs_names = sum_repeats(
rhs, rhs_names, rhs_counts, result_names + lhs_names
)
lhs_or_rhs_names = set(lhs_names) | set(rhs_names)
contracted_names = [x for x in contracted_names if x in lhs_or_rhs_names]
lhs_and_rhs_names = set(lhs_names) & set(rhs_names)
batch_names = [x for x in result_names if x in lhs_and_rhs_names]
if batch_names:
lhs_batch, rhs_batch = tuple(
zip(*[(lhs_names.find(n), rhs_names.find(n)) for n in batch_names])
)
else:
lhs_batch = rhs_batch = ()
# contract using dot_general
batch_names_str = "".join(batch_names)
if contracted_names:
lhs_cont, rhs_cont = tuple(
zip(
*[
(lhs_names.index(n), rhs_names.index(n))
for n in contracted_names
]
)
)
else:
lhs_cont = rhs_cont = ()
deleted_names = batch_names_str + "".join(contracted_names)
remaining_lhs_names = removechars(lhs_names, deleted_names)
remaining_rhs_names = removechars(rhs_names, deleted_names)
# Try both orders of lhs and rhs, in the hope that one of them means we
# don't need an explicit transpose. opt_einsum likes to contract from
# right to left, so we expect (rhs,lhs) to have the best chance of not
# needing a transpose.
names = batch_names_str + remaining_rhs_names + remaining_lhs_names
if names == result_names:
operand = _general_dot(
(rhs, lhs), (rhs_cont, lhs_cont), (rhs_batch, lhs_batch)
)
else:
names = batch_names_str + remaining_lhs_names + remaining_rhs_names
operand = _general_dot(
(lhs, rhs),
axes=(lhs_cont, rhs_cont),
batch_axes=(lhs_batch, rhs_batch),
)
else:
raise ValueError(
f"Each step of einsum must have 1 or 2 operands, got {len(operand_indices)}"
)
# the resulting 'operand' with axis labels 'names' should be a permutation of the desired result
assert len(names) == len(result_names) == len(set(names))
assert set(names) == set(result_names)
if names != result_names:
perm = tuple(names.index(name) for name in result_names)
operand = transpose(operand, perm)
einsum_operands.append(operand) # used in next iteration
[einsum_result] = einsum_operands
out = Einsum(
subscripts=subscripts,
inputs=list(tensor_operands),
outputs=[einsum_result],
path=tuple(path),
optimized=optimized,
)(*tensor_operands)
return cast(TensorVariable, out)
from collections.abc import Callable

from pytensor.graph import vectorize_graph
from pytensor.tensor import TensorVariable
from pytensor.tensor.utils import _parse_gufunc_signature


def vectorize(func: Callable, signature: str | None = None) -> Callable:
@@ -3,10 +3,9 @@ import pytensor.tensor.rewriting.blas
import pytensor.tensor.rewriting.blas_c
import pytensor.tensor.rewriting.blas_scipy
import pytensor.tensor.rewriting.blockwise
import pytensor.tensor.rewriting.einsum
import pytensor.tensor.rewriting.elemwise
import pytensor.tensor.rewriting.extra_ops
import pytensor.tensor.rewriting.jax
import pytensor.tensor.rewriting.linalg
import pytensor.tensor.rewriting.math
@@ -52,6 +52,7 @@ from pytensor.tensor.basic import (
    TensorFromScalar,
    alloc,
    as_tensor_variable,
    atleast_Nd,
    cast,
    extract_constant,
    fill,
@@ -1219,3 +1220,123 @@ def local_merge_alloc(fgraph, node):
register_canonicalize(RemovalNodeRewriter(tensor_copy), name="remove_tensor_copy")
@register_specialize
@node_rewriter([DimShuffle])
def local_dimshuffle_alloc(fgraph, node):
"""
Lift DimShuffle through Alloc
dimshuffle{x, 0, 1}(alloc([3 4], 3, 2)) => alloc([3 4], 1, 3, 2)
"""
alloc_out = node.inputs[0]
alloc_node = alloc_out.owner
if not (alloc_node and isinstance(alloc_node.op, Alloc)):
return
ds_op = node.op
value, *alloc_shape = alloc_node.inputs
# Add implicit dimensions of value
value = atleast_Nd(value, n=len(alloc_shape))
# Dimshuffle value and alloc_shape
ds_value = value.dimshuffle(ds_op.new_order)
ds_alloc_shape = [alloc_shape[i] for i in ds_op.shuffle]
for dim in ds_op.augment:
ds_alloc_shape.insert(dim, 1)
return [alloc(ds_value, *ds_alloc_shape)]
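# Illustrative sketch (not part of the committed file): after the "specialize" rewrites,
# an expand_dims of an Alloc becomes a single Alloc with an extra length-1 dimension.
import pytensor
import pytensor.tensor as pt
from pytensor.graph import rewrite_graph

x = pt.vector("x")
out = pt.alloc(x, 3, 2).dimshuffle("x", 0, 1)
pytensor.dprint(rewrite_graph(out, include=("specialize",), clone=True))
# expected: the Alloc is now on the outside, with shape (1, 3, 2)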
@register_specialize("shape_unsafe")
@node_rewriter([Join])
def local_join_of_alloc(fgraph, node):
"""Rewrite a Join of Alloc nodes to an Alloc of the Join nodes."""
axis, *tensors = node.inputs
if len(tensors) < 2:
# Let other rewrite handle the useless Join
return
if not isinstance(axis, Constant):
return
core_tensors = []
alloc_shapes = []
for tensor in tensors:
if tensor.owner is None:
return
# tensor = expand_dims_to_alloc(tensor)
if not isinstance(tensor.owner.op, Alloc):
return
value, *shape = tensor.owner.inputs
# Introduce explicit batch dims
value = atleast_Nd(value, n=len(shape))
core_tensors.append(value)
alloc_shapes.append(shape)
# Find which allocated dimensions can be lifted
# Axis can never be lifted
# Non-axis allocated dimensions can be lifted if they are all broadcastable
[out] = node.outputs
axis = axis.data
broadcasted_dims = list(
zip(
*(
[
bef and not aft
for bef, aft in zip(
core_tensor.type.broadcastable,
tensor.type.broadcastable,
strict=True,
)
]
for core_tensor, tensor in zip(core_tensors, tensors, strict=True)
)
)
)
lifteable_alloc_dims = {
dim
for dim in range(out.type.ndim)
if dim != axis and all(broadcasted_dims[dim])
}
if not lifteable_alloc_dims:
return
# Lift the allocated dimensions
new_tensors = []
for core_tensor, alloc_shape in zip(core_tensors, alloc_shapes):
pre_join_shape = [
1 if i in lifteable_alloc_dims else alloc_dim
for i, alloc_dim in enumerate(alloc_shape)
]
new_tensor = alloc(core_tensor, *pre_join_shape)
copy_stack_trace(tensor, new_tensor)
new_tensors.append(new_tensor)
new_join = node.op(axis, *new_tensors)
copy_stack_trace(node.outputs[0], new_join)
# Reintroduce the lifted dims
post_join_shape = []
for i, alloc_dims in enumerate(zip(*alloc_shapes)):
if i == axis:
# The alloc dim along the axis is the sum of all the pre-join alloc dims
post_join_shape.append(add(*alloc_dims))
else:
# Otherwise the shapes should all match. We prioritize constants if any
for best_alloc_dim in alloc_dims:
if isinstance(best_alloc_dim, Constant):
break
post_join_shape.append(best_alloc_dim)
new_out = alloc(new_join, *post_join_shape)
copy_stack_trace(node.outputs[0], new_out)
return [new_out]
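# Illustrative sketch (not part of the committed file): joining two Allocs that both
# broadcast along a non-join dimension lets that dimension be allocated only once,
# after the Join.
import pytensor
import pytensor.tensor as pt
from pytensor.graph import rewrite_graph

x = pt.scalar("x")
y = pt.scalar("y")
out = pt.join(0, pt.alloc(x, 2, 5), pt.alloc(y, 3, 5))
pytensor.dprint(rewrite_graph(out, include=("specialize",), clone=True))
# expected: roughly alloc(join(0, alloc(x, 2, 1), alloc(y, 3, 1)), 2 + 3, 5)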
@@ -10,6 +10,7 @@ from pytensor.tensor.rewriting.basic import (
    register_specialize,
    register_stabilize,
)
from pytensor.tensor.shape import Reshape
from pytensor.tensor.subtensor import AdvancedIncSubtensor, AdvancedSubtensor, Subtensor
@@ -67,10 +68,16 @@ optdb.register(
def local_eager_useless_unbatched_blockwise(fgraph, node):
    if isinstance(
        node.op.core_op,
        Dot
        | Alloc
        | ARange
        | Subtensor
        | AdvancedSubtensor
        | AdvancedIncSubtensor
        | Reshape,
    ):
        # Many Dot-related rewrites (eg, all of BlasOpt) happen before specialize
        # These other Ops can't always be trivially vectorized at runtime,
        # since their inputs may imply non-rectangular shapes.
        return local_useless_unbatched_blockwise.fn(fgraph, node)
@@ -97,62 +104,67 @@ def local_blockwise_alloc(fgraph, node):
    BOp(matrix, alloc(vector, 10, 5)) -> BOp(matrix, vector)
    """
    op: Blockwise = node.op  # type: ignore
    batch_ndim = op.batch_ndim(node)
    if not batch_ndim:
        return None

    if not any(var.owner and isinstance(var.owner.op, Alloc) for var in node.inputs):
        return None

    new_inputs = []
    batch_shapes = []
    can_push_any_alloc = False
    for inp, inp_sig in zip(node.inputs, op.inputs_sig):
        if not all(inp.type.broadcastable[:batch_ndim]):
            if inp.owner and isinstance(inp.owner.op, Alloc):
                # Push batch dims from Alloc
                value, *shape = inp.owner.inputs

                # Check what to do with the value of the Alloc
                squeezed_value = _squeeze_left(value, batch_ndim)
                missing_ndim = len(shape) - value.type.ndim
                if (
                    (((1,) * missing_ndim + value.type.broadcastable)[batch_ndim:])
                    != inp.type.broadcastable[batch_ndim:]
                ):
                    # We still need an Alloc for the core dims
                    core_shape = shape[batch_ndim:]
                    # And the batch dims of the squeezed value
                    squeezed_value_batch_ndim = squeezed_value.type.ndim - len(
                        core_shape
                    )
                    batch_shape = [
                        1 if broadcastable else dim
                        for broadcastable, dim in zip(
                            squeezed_value.type.broadcastable[
                                :squeezed_value_batch_ndim
                            ],
                            tuple(squeezed_value.shape)[:squeezed_value_batch_ndim],
                        )
                    ]
                    squeezed_value = alloc(squeezed_value, *batch_shape, *core_shape)
                    if squeezed_value.type.broadcastable == inp.type.broadcastable:
                        # We can't change anything about this Alloc input
                        new_inputs.append(inp)
                        continue

                # We can push batch dims of this Alloc input
                batch_shapes.append(
                    tuple(
                        1 if broadcastable else dim
                        for broadcastable, dim in zip(
                            inp.type.broadcastable, shape[:batch_ndim]
                        )
                    )
                )
                new_inputs.append(squeezed_value)
                can_push_any_alloc = True
                continue

        # Nothing to do with this input other than removing dummy batch dims
        new_inputs.append(_squeeze_left(inp, batch_ndim))

    if not can_push_any_alloc:
        return None
@@ -167,17 +179,15 @@ def local_blockwise_alloc(fgraph, node):
    missing_ndim = old_out_type.ndim - new_out_type.ndim
    batch_shape = ([1] * missing_ndim + list(new_outs[0].shape))[:batch_ndim]
    for i, batch_dims in enumerate(zip(*batch_shapes)):  # Transpose shape tuples
        for batch_dim in batch_dims:
            if batch_dim == 1:
                continue
            batch_shape[i] = batch_dim
            if isinstance(batch_dim, Constant):
                # Give preference to Constants
                break

    copy_stack_trace(node.outputs, new_outs)

    new_outs = [
@@ -190,3 +200,28 @@ def local_blockwise_alloc(fgraph, node):
    ]
    copy_stack_trace(node.outputs, new_outs)
    return new_outs
@register_specialize
@node_rewriter([Blockwise])
def local_blockwise_reshape(fgraph, node):
"""Rewrite away square Blockwise reshapes.
Reshape is tricky to vectorize eagerly, because a graph like
`x.reshape([x.shape[0] * x.shape[1], -1])` has many operations
that must be vectorized before we arrive at the reshape operation.
For the square Reshape case, we must wait for all the intermediate
operations to be lifted as Allocs.
"""
if not isinstance(node.op.core_op, Reshape):
return None
x, output_shape = node.inputs
batch_ndim = node.op.batch_ndim(node)
if all(output_shape.type.broadcastable[:batch_ndim]):
batched_shape = x.shape[:batch_ndim]
core_reshape = _squeeze_left(output_shape, batch_ndim)
new_out = x.reshape([*tuple(batched_shape), *tuple(core_reshape)])
copy_stack_trace(node.outputs[0], new_out)
return [new_out]
from typing import cast
from pytensor.graph import Apply, FunctionGraph, node_rewriter
from pytensor.graph.rewriting.basic import copy_stack_trace
from pytensor.tensor.einsum import Einsum, einsum
from pytensor.tensor.rewriting.basic import register_specialize
from pytensor.tensor.rewriting.ofg import inline_ofg_node
from pytensor.tensor.variable import TensorVariable
@register_specialize
@node_rewriter([Einsum])
def optimize_einsum_inner_graph(
fgraph: FunctionGraph, node: Apply
) -> list[TensorVariable] | None:
"""Try to optimize an einsum that was not optimizable at definition time.
This can happen when users replace a graph without rebuilding it,
or when more specialized static shapes are found during the course of rewrites.
"""
op: Einsum = node.op
if op.optimized:
# Already optimized
return None
operands = node.inputs
if any(None in operand.type.shape for operand in operands):
return None
new_out = einsum(op.subscripts, *operands)
assert new_out.owner.op.optimized
copy_stack_trace(node.outputs[0], new_out)
return [new_out]
@register_specialize
@node_rewriter([Einsum])
def inline_optimized_einsum(
fgraph: FunctionGraph, node: Apply
) -> list[TensorVariable] | None:
"""Inline einsums that are already optimized.
This allows the inner graph to be optimized with the rest of the graph, now that we got the ordering right.
"""
op: Einsum = node.op
if not op.optimized:
return None
return cast(list[TensorVariable], inline_ofg_node(node))
from typing import cast

from pytensor import Variable, clone_replace
from pytensor.compile import optdb
from pytensor.compile.builders import OpFromGraph
from pytensor.graph import Apply, node_rewriter
from pytensor.graph.rewriting.basic import copy_stack_trace, in2out
from pytensor.tensor.basic import AllocDiag
from pytensor.tensor.rewriting.basic import register_specialize


def inline_ofg_node(node: Apply) -> list[Variable]:
    op = node.op
    assert isinstance(op, OpFromGraph)
    inlined_outs = clone_replace(
        op.inner_outputs, dict(zip(op.inner_inputs, node.inputs))
    )
    copy_stack_trace(op.inner_outputs, inlined_outs)
    return cast(list[Variable], inlined_outs)


@node_rewriter([OpFromGraph])
def inline_ofg_expansion(fgraph, node):
    """
@@ -18,10 +30,7 @@ def inline_ofg_expansion(fgraph, node):
    if not op.is_inline:
        return False
    return inline_ofg_node(node)


# We want to run this before the first merge optimizer
@@ -61,8 +70,4 @@ def late_inline_OpFromGraph(fgraph, node):
    -------
    """
    return inline_ofg_node(node)
@@ -749,51 +749,43 @@ pytensor.compile.mode.optdb.register(
pytensor.compile.mode.optdb.register("UnShapeOpt", UnShapeOptimizer(), position=10)


@register_canonicalize("shape_unsafe")
@register_specialize("shape_unsafe")
@node_rewriter([Reshape])
def local_reshape_chain(fgraph, node):
    """
    Reshape(Reshape(x, shape1),shape2) -> Reshape(x, shape2)

    """
    if not check_chain(node, Reshape, Reshape):
        return False

    rval = node.op(node.inputs[0].owner.inputs[0], node.inputs[1])

    # Copy over stacktrace from previous output node, as any error
    # in new computational graph would have been caused by last op
    # in the old computational graph.
    copy_stack_trace(node.outputs, rval)

    # It might happen that the desired output of this node has a
    # broadcastable pattern that does not match that of 'rval'. This is
    # when originally, we were able to figure out that one of the
    # dimensions of the reshape is one, but some other transformation
    # replaced the shape by one for which this cannot be guessed.
    # We should try to figure out why we lost the information about this
    # constant value... but in the meantime, better not apply this
    # rewrite.
    if rval.type.ndim == node.outputs[0].type.ndim and all(
        s1 == s2
        for s1, s2 in zip(rval.type.shape, node.outputs[0].type.shape)
        if s1 == 1 or s2 == 1
    ):
        return [rval]
@register_useless("shape_unsafe")
@register_canonicalize("shape_unsafe")
@register_specialize("shape_unsafe")
@node_rewriter([Reshape])
def local_useless_reshape(fgraph, node):
    """Remove two kinds of useless `Reshape`.
@@ -802,24 +794,17 @@ def local_useless_reshape(fgraph, node):
    - Remove `Reshape` when reshaping to the shape of the input.
    """
    inp, output_shape = node.inputs
    [output] = node.outputs

    if inp.type.ndim != output.type.ndim:
        return False

    # Simple case: both input and output have a single dimension.
    if (
        inp.type.ndim == 1
        and output.type.ndim == 1
        and inp.type.broadcastable == output.type.broadcastable
    ):
        return [inp]
@@ -832,8 +817,15 @@ def local_useless_reshape(fgraph, node):
    # Match Reshape(x, [x.shape[0], ..., x.shape[-1]]), accounting for
    # broadcastable and constant dimensions
    if isinstance(output_shape, Constant) or (
        output_shape.owner and isinstance(output_shape.owner.op, MakeVector)
    ):
        if isinstance(output_shape, Constant):
            output_shape_is = [
                as_tensor_variable(dim, ndim=0) for dim in output_shape.data
            ]
        else:
            output_shape_is = output_shape.owner.inputs

        shape_feature = getattr(fgraph, "shape_feature", None)
@@ -865,9 +857,9 @@ def local_useless_reshape(fgraph, node):
                shape_match[dim] = True
                continue

            # Match constant if input.type.shape[dim] == constant
            cst_outshp_i = extract_constant(outshp_i, only_process_constants=1)
            if inp.type.shape[dim] == cst_outshp_i:
                shape_match[dim] = True
                continue
@@ -881,17 +873,18 @@ def local_useless_reshape(fgraph, node):
            if shape_feature:
                inpshp_i = shape_feature.get_shape(inp, dim)
                if inpshp_i == outshp_i or (
                    extract_constant(inpshp_i, only_process_constants=True)
                    == extract_constant(outshp_i, only_process_constants=True)
                ):
                    shape_match[dim] = True
                    continue

        if nb_m1 <= 1 and all(shape_match):
            return [inp]

        if (nb_m1 == 0) and (shape_match.count(False) == output.type.ndim - 1):
            return [inp]

    return False
@@ -910,9 +903,8 @@ def local_reshape_to_dimshuffle(fgraph, node):
       -> DimShuffle{x,0,x,1,x,x}(Reshape(x, (m, n)))
    """
    op = node.op
    inp, output_shape = node.inputs
    [output] = node.outputs

    dimshuffle_new_order = []
    new_output_shape = []
@@ -944,7 +936,7 @@ def local_reshape_to_dimshuffle(fgraph, node):
@register_canonicalize
@register_specialize
@node_rewriter([Reshape])
def local_reshape_lift(fgraph, node):
    """
@@ -842,13 +842,13 @@ class Reshape(COp):
@_vectorize_node.register(Reshape)
def _vectorize_reshape(op, node, x, shape):
    from pytensor.tensor.blockwise import vectorize_node_fallback

    old_x, old_shape = node.inputs
    batched_ndims = x.type.ndim - old_x.type.ndim

    if as_tensor_variable(shape).type.ndim != 1:
        return vectorize_node_fallback(op, node, x, shape)

    if len(tuple(old_shape)) == len(tuple(shape)):
        new_shape = [*x.shape[:batched_ndims], *shape]
import numpy as np
import pytest
import pytensor
import pytensor.tensor as pt
jax = pytest.importorskip("jax")
def test_jax_einsum():
subscripts = "ij, jk, kl -> il"
x = np.random.rand(3, 5)
y = np.random.rand(5, 2)
z = np.random.rand(2, 4)
shapes = ((3, 5), (5, 2), (2, 4))
x_pt, y_pt, z_pt = (
pt.tensor(name, shape=shape) for name, shape in zip("xyz", shapes)
)
out = pt.einsum(subscripts, x_pt, y_pt, z_pt)
f = pytensor.function([x_pt, y_pt, z_pt], out, mode="JAX")
np.testing.assert_allclose(f(x, y, z), np.einsum(subscripts, x, y, z))
@pytest.mark.xfail(raises=NotImplementedError)
def test_ellipsis_einsum():
subscripts = "...i,...i->..."
x = np.random.rand(2, 5)
y = np.random.rand(2, 5)
x_pt = pt.tensor("x", shape=x.shape)
y_pt = pt.tensor("y", shape=y.shape)
out = pt.einsum(subscripts, x_pt, y_pt)
f = pytensor.function([x_pt, y_pt], out, mode="JAX")
np.testing.assert_allclose(f(x, y), np.einsum(subscripts, x, y))
from functools import partial

import numpy as np

from pytensor import Mode, config, function
from pytensor.graph import FunctionGraph, rewrite_graph, vectorize_graph
from pytensor.graph.basic import equal_computations
from pytensor.scalar import log as scalar_log
from pytensor.tensor import add, alloc, matrix, tensor, tensor3
@@ -9,6 +11,7 @@ from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.elemwise import Elemwise
from pytensor.tensor.nlinalg import MatrixPinv
from pytensor.tensor.rewriting.blockwise import local_useless_blockwise
from pytensor.tensor.shape import Reshape


def test_useless_blockwise_of_elemwise():
@@ -45,7 +48,7 @@ def test_blockwise_alloc():
    rewrite = partial(
        rewrite_graph,
        include=("ShapeOpt", "specialize"),
        exclude=("local_useless_unbatched_blockwise", "local_dimshuffle_alloc"),
    )

    vector_add = Blockwise(core_op=add, signature="(x),(x)->(x)")
@@ -104,7 +107,9 @@ def test_blockwise_alloc():
    y = tensor("y", shape=())
    out = vector_add(alloc(x, 3, 1, 5), alloc(y, 7, 5))
    expected_out = alloc(vector_add(alloc(x, 5), alloc(y, 5)), 3, 7, 5)
    assert equal(
        [rewrite(out)], [expected_out]
    ), None  # pytensor.dprint([expected_out, rewrite(out)], print_type=True)

    x = tensor("x", shape=(5,))
    y = tensor("y", shape=())
@@ -118,3 +123,27 @@ def test_blockwise_alloc():
    out = vector_add(x, alloc(y, 5))
    expected_out = out
    assert equal([rewrite(out)], [expected_out])
def test_blockwise_reshape():
x = tensor("x", shape=(None, None, None))
y = x.reshape([x.shape[0] * x.shape[1], -1])
new_x = tensor("x", shape=(None, None, None, None))
new_y = vectorize_graph(y, {x: new_x})
assert not isinstance(new_y.owner.op, Reshape)
assert isinstance(new_y.owner.op, Blockwise) and isinstance(
new_y.owner.op.core_op, Reshape
)
rewritten_y = rewrite_graph(
new_y, include=("canonicalize", "specialize"), clone=True
)
assert isinstance(rewritten_y.owner.op, Reshape)
no_rewrites = Mode(linker="py", optimizer=None)
test_x = np.arange(5 * 4 * 3 * 2).reshape(5, 4, 3, 2).astype(config.floatX)
np.testing.assert_allclose(
new_y.eval({"x": test_x}, mode=no_rewrites),
rewritten_y.eval({"x": test_x}, mode=no_rewrites),
)
from functools import partial
from pytensor.graph import ancestors, rewrite_graph
from pytensor.tensor import einsum, specify_shape, tensor
from pytensor.tensor.einsum import Einsum
specialize_rewrite = partial(rewrite_graph, include=("specialize",), clone=True)
def test_einsum_optimization():
a = tensor("a", shape=(None, None))
b = tensor("b", shape=(None, None))
c = tensor("c", shape=(None, None))
dynamic_shape_einsum = einsum("ij,ij,jk->ik", a, b, c)
assert not dynamic_shape_einsum.owner.op.optimized
rewritten_out = specialize_rewrite(dynamic_shape_einsum)
assert isinstance(rewritten_out.owner.op, Einsum)
a = specify_shape(a, (2, 3))
b = specify_shape(b, (2, 3))
c = specify_shape(c, (3, 5))
static_shape_einsum = dynamic_shape_einsum.owner.clone_with_new_inputs(
[a, b, c]
).default_output()
assert not static_shape_einsum.owner.op.optimized
rewritten_out = specialize_rewrite(static_shape_einsum)
# Einsum was inlined because it was optimized
assert not isinstance(rewritten_out.owner.op, Einsum)
# Sanity check that it's not buried in the graph
assert not any(
isinstance(var.owner.op, Einsum)
for var in ancestors([rewritten_out])
if var.owner
)
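For context on the `optimized` flag asserted above: when `pt.einsum` is given operands whose shapes are fully known statically, it can choose a contraction path at graph-construction time, so the resulting Einsum node is created already optimized and the specialize rewrite can inline it. A small sketch under that assumption (shapes and names are illustrative):

import pytensor.tensor as pt

a = pt.tensor("a", shape=(2, 3))
b = pt.tensor("b", shape=(3, 5))
c = pt.tensor("c", shape=(5, 7))

# With all operand shapes known, the contraction path is computed eagerly,
# so the Einsum op carries optimized=True from the start.
out = pt.einsum("ij,jk,kl->il", a, b, c)
assert out.owner.op.optimized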
...@@ -337,6 +337,52 @@ class TestLocalUselessReshape: ...@@ -337,6 +337,52 @@ class TestLocalUselessReshape:
topo = f2.maker.fgraph.toposort() topo = f2.maker.fgraph.toposort()
assert not any(isinstance(n.op, Reshape) for n in topo) assert not any(isinstance(n.op, Reshape) for n in topo)
def test_constant_shape(self):
# Where reshape is a constant that matches the shape
x = matrix(shape=(2, 3))
shape = pt.as_tensor(np.array([2, 3]))
out = reshape(x, shape)
new_out = rewrite_graph(out)
assert new_out is x
x = matrix(shape=(2, 3))
shape = pt.as_tensor(np.array([-1, 3]))
out = reshape(x, shape)
new_out = rewrite_graph(out)
assert new_out is x
x = matrix(shape=(None, 3))
shape = pt.as_tensor(np.array([-1, 3]))
out = reshape(x, shape)
new_out = rewrite_graph(out)
assert new_out is x
x = matrix(shape=(None, 3))
shape = pt.as_tensor(np.array([2, 3]))
out = reshape(x, shape)
new_out = rewrite_graph(out)
# This could be rewritten as a specify_shape(x, (2, 3))
assert new_out is not x
x = matrix(shape=(2, 3))
shape = pt.as_tensor(np.array([3, 2]))
out = reshape(x, shape)
new_out = rewrite_graph(out)
assert new_out is not x
def test_all_but_one_match(self):
x = matrix(shape=(None, None))
shape = [x.shape[0], 3]
out = reshape(x, shape)
new_out = rewrite_graph(out)
assert equal_computations([new_out], [specify_shape(x, (None, 3))])
# Rewrite does not apply if there's also a -1
shape = [-1, 3]
out = reshape(x, shape)
new_out = rewrite_graph(out)
assert new_out is out
class TestLocalReshapeToDimshuffle: class TestLocalReshapeToDimshuffle:
def setup_method(self): def setup_method(self):
......
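A note on the `specify_shape` target used in `test_all_but_one_match` above: `specify_shape` only records static shape information (plus a runtime check) and does not move or copy data, which is why replacing a no-op reshape with it is a strict win. A minimal illustrative sketch (the input value is arbitrary):

import numpy as np
import pytensor.tensor as pt
from pytensor.tensor import specify_shape

x = pt.matrix("x")
y = specify_shape(x, (None, 3))

# The static shape is now recorded on the variable's type.
assert y.type.shape == (None, 3)
# Evaluation passes the data through unchanged when the shape matches.
np.testing.assert_array_equal(
    y.eval({x: np.zeros((4, 3), dtype=x.dtype)}),
    np.zeros((4, 3)),
)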
...@@ -3847,8 +3847,10 @@ def test_transpose(): ...@@ -3847,8 +3847,10 @@ def test_transpose():
assert np.all(t2d == np.transpose(x2v, [0, 1])) assert np.all(t2d == np.transpose(x2v, [0, 1]))
assert np.all(t3d == np.transpose(x3v, [0, 2, 1])) assert np.all(t3d == np.transpose(x3v, [0, 2, 1]))
# Check we don't introduce useless transpose
assert ptb.transpose(x1) is x1
# Check that we create a name. # Check that we create a name.
assert ptb.transpose(x1).name == "x1.T"
assert ptb.transpose(x2).name == "x2.T" assert ptb.transpose(x2).name == "x2.T"
assert ptb.transpose(x3).name == "x3.T" assert ptb.transpose(x3).name == "x3.T"
assert ptb.transpose(dmatrix()).name is None assert ptb.transpose(dmatrix()).name is None
......
from functools import partial
from string import ascii_lowercase
import numpy as np
import pytest
import pytensor
import pytensor.tensor as pt
from pytensor import Mode, config, function
from pytensor.graph import FunctionGraph
from pytensor.graph.op import HasInnerGraph
from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.einsum import _delta, _general_dot, _iota, einsum
from pytensor.tensor.shape import Reshape
# Fail for unexpected warnings in this file
pytestmark = pytest.mark.filterwarnings("error")
floatX = pytensor.config.floatX
ATOL = RTOL = 1e-8 if floatX == "float64" else 1e-4
def assert_no_blockwise_in_graph(fgraph: FunctionGraph, core_op=None) -> None:
for node in fgraph.apply_nodes:
if isinstance(node.op, Blockwise):
if core_op is None:
raise AssertionError
assert not isinstance(node.op.core_op, core_op)
if isinstance(node.op, HasInnerGraph):
# InnerGraph Ops can be rewritten without modifying the original fgraph
if hasattr(node.op, "_fn"):
inner_fgraph = node.op._fn.maker.fgraph
else:
inner_fgraph = node.op.fgraph
assert_no_blockwise_in_graph(inner_fgraph, core_op=core_op)
def test_iota():
mode = Mode(linker="py", optimizer=None)
np.testing.assert_allclose(
_iota((4, 8), 0).eval(mode=mode),
[
[0, 0, 0, 0, 0, 0, 0, 0],
[1, 1, 1, 1, 1, 1, 1, 1],
[2, 2, 2, 2, 2, 2, 2, 2],
[3, 3, 3, 3, 3, 3, 3, 3],
],
)
np.testing.assert_allclose(
_iota((4, 8), 1).eval(mode=mode),
[
[0, 1, 2, 3, 4, 5, 6, 7],
[0, 1, 2, 3, 4, 5, 6, 7],
[0, 1, 2, 3, 4, 5, 6, 7],
[0, 1, 2, 3, 4, 5, 6, 7],
],
)
def test_delta():
mode = Mode(linker="py", optimizer=None)
np.testing.assert_allclose(
_delta((2, 2), (0, 1)).eval(mode=mode),
[[1.0, 0.0], [0.0, 1.0]],
)
np.testing.assert_allclose(
_delta((2, 2, 2), (0, 1)).eval(mode=mode),
[[[1, 1], [0, 0]], [[0, 0], [1, 1]]],
)
def test_general_dot():
rng = np.random.default_rng(45)
signature = "(l0,a0,a1,l1),(a1,r0,r1,a0)->(l0,l1,r0,r1)"
tensordot_axes = [(-3, -2), (-1, -4)]
# X has two batch dims
# Y has one batch dim
x = pt.tensor("x", shape=(5, 4, 2, 11, 13, 3))
y = pt.tensor("y", shape=(4, 13, 5, 7, 11))
out = _general_dot((x, y), tensordot_axes, [(0, 1), (0,)])
fn = pytensor.function([x, y], out)
# fn.dprint(print_type=True)
if config.mode != "FAST_COMPILE":
assert_no_blockwise_in_graph(fn.maker.fgraph, Reshape)
np_batched_tensordot = np.vectorize(
partial(np.tensordot, axes=tensordot_axes), signature=signature
)
x_test = rng.normal(size=x.type.shape).astype(floatX)
y_test = rng.normal(size=y.type.shape).astype(floatX)
np.testing.assert_allclose(
fn(x_test, y_test), np_batched_tensordot(x_test, y_test), atol=ATOL, rtol=RTOL
)
@pytest.mark.parametrize("static_shape_known", [True, False])
@pytest.mark.parametrize(
"signature",
[
"ij",
"ji",
"ii->i",
"ii",
"ij->",
"ij->j",
"ij->i",
"ij,ij->ij",
"ij,ji->ij",
"ij,ji->ji",
"ij,jk",
"kj,ji",
"ij,kj->ik",
"ik,kj->ikj",
"ij,kl->ijkl",
"ij,jk,kl->il",
"kl,ij,jk->il",
"oij,imj,mjkn,lnk,plk->op",
],
)
def test_einsum_signatures(static_shape_known, signature):
letters_to_dims = dict(zip("ijklmnop", [2, 3, 5, 7, 11, 13, 17, 19], strict=True))
inputs = signature.split("->")[0].split(",")
shapes = [tuple(letters_to_dims[letter] for letter in inp) for inp in inputs]
if static_shape_known:
static_shapes = shapes
else:
static_shapes = [[None] * len(shape) for shape in shapes]
operands = [
pt.tensor(name, shape=static_shape)
for name, static_shape in zip(ascii_lowercase, static_shapes)
]
out = pt.einsum(signature, *operands)
assert out.owner.op.optimized == static_shape_known or len(operands) <= 2
rng = np.random.default_rng(37)
test_values = [rng.normal(size=shape).astype(floatX) for shape in shapes]
np_out = np.einsum(signature, *test_values)
fn = function(operands, out)
pt_out = fn(*test_values)
# print(); fn.dprint(print_type=True)
if config.mode != "FAST_COMPILE":
assert_no_blockwise_in_graph(fn.maker.fgraph)
np.testing.assert_allclose(pt_out, np_out, atol=ATOL, rtol=RTOL)
def test_batch_dim():
shapes = (
(7, 3, 5),
(5, 2),
)
x, y = (pt.tensor(name, shape=shape) for name, shape in zip("xy", shapes))
out = pt.einsum("mij,jk->mik", x, y)
assert out.type.shape == (7, 3, 2)
def test_einsum_conv():
# Adapted example from https://medium.com/latinxinai/vectorized-convolution-operation-using-numpy-b122fd52fba3
rng = np.random.default_rng(125)
batch_size = 32
channels = 3
height = 8
width = 8
kernel_size = 2
num_filters = 15
conv_signature = "bchwkt,fckt->bfhw"
windowed_input = rng.random(
size=(batch_size, channels, height, width, kernel_size, kernel_size)
).astype(floatX)
weights = rng.random(size=(num_filters, channels, kernel_size, kernel_size)).astype(
floatX
)
result = einsum(conv_signature, windowed_input, weights).eval()
assert result.shape == (32, 15, 8, 8)
np.testing.assert_allclose(
result,
np.einsum("bchwkt,fckt->bfhw", windowed_input, weights),
atol=ATOL,
rtol=RTOL,
)
def test_ellipsis():
rng = np.random.default_rng(159)
x = pt.tensor("x", shape=(3, 5, 7, 11))
y = pt.tensor("y", shape=(3, 5, 11, 13))
x_test = rng.normal(size=x.type.shape).astype(floatX)
y_test = rng.normal(size=y.type.shape).astype(floatX)
expected_out = np.matmul(x_test, y_test)
with pytest.raises(ValueError):
pt.einsum("mp,pn->mn", x, y)
out = pt.einsum("...mp,...pn->...mn", x, y)
np.testing.assert_allclose(
out.eval({x: x_test, y: y_test}), expected_out, atol=ATOL, rtol=RTOL
)
# Put batch axes in the middle
new_x = pt.moveaxis(x, -2, 0)
new_y = pt.moveaxis(y, -2, 0)
out = pt.einsum("m...p,p...n->m...n", new_x, new_y)
np.testing.assert_allclose(
out.eval({x: x_test, y: y_test}),
expected_out.transpose(-2, 0, 1, -1),
atol=ATOL,
rtol=RTOL,
)
out = pt.einsum("m...p,p...n->mn", new_x, new_y)
np.testing.assert_allclose(
out.eval({x: x_test, y: y_test}), expected_out.sum((0, 1)), atol=ATOL, rtol=RTOL
)
def test_broadcastable_dims():
# Test that einsum handles broadcasting dims correctly. There are two points:
# 1. Numpy einsum allows the same subscript for degenerate and full dimensions
# There is some stale discussion on whether this should be a bug or not, but for now it is not:
# https://github.com/numpy/numpy/issues/11548
# 2. Using the same letter for dimensions that are and aren't broadcastable
# can lead to suboptimal paths. We check that a warning is issued for the following example:
# https://github.com/dgasmith/opt_einsum/issues/220
rng = np.random.default_rng(222)
a = pt.tensor("a", shape=(32, 32, 32))
b = pt.tensor("b", shape=(1000, 32))
c = pt.tensor("c", shape=(1, 32))
a_test = rng.normal(size=a.type.shape).astype(floatX)
b_test = rng.normal(size=b.type.shape).astype(floatX)
c_test = rng.normal(size=c.type.shape).astype(floatX)
# Note b is used for both 1 and 32
with pytest.warns(
UserWarning, match="This can result in a suboptimal contraction path"
):
suboptimal_out = pt.einsum("ijk,bj,bk->i", a, b, c)
assert not [set(p) for p in suboptimal_out.owner.op.path] == [{0, 2}, {0, 1}]
# If we use a distinct letter we get the optimal path
optimal_out = pt.einsum("ijk,bj,ck->i", a, b, c)
assert [set(p) for p in optimal_out.owner.op.path] == [{0, 2}, {0, 1}]
suboptimal_eval = suboptimal_out.eval({a: a_test, b: b_test, c: c_test})
optimal_eval = optimal_out.eval({a: a_test, b: b_test, c: c_test})
np_eval = np.einsum("ijk,bj,bk->i", a_test, b_test, c_test)
atol = 1e-12 if config.floatX == "float64" else 1e-2
np.testing.assert_allclose(suboptimal_eval, np_eval, atol=atol)
np.testing.assert_allclose(optimal_eval, np_eval, atol=atol)
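To make point 1 above concrete: NumPy accepts the same subscript for a length-1 ("degenerate") axis and a full axis and broadcasts them, which is the behaviour the test relies on when comparing against np.einsum. A tiny NumPy-only sketch (arrays are illustrative):

import numpy as np

a = np.arange(6.0).reshape(2, 3)  # axis labelled "j" has length 3
b = np.ones(1)                    # axis labelled "j" has length 1, broadcast
out = np.einsum("ij,j->i", a, b)
np.testing.assert_allclose(out, a.sum(axis=1))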
...@@ -14,7 +14,7 @@ from pytensor.graph.type import Type ...@@ -14,7 +14,7 @@ from pytensor.graph.type import Type
from pytensor.misc.safe_asarray import _asarray from pytensor.misc.safe_asarray import _asarray
from pytensor.scalar.basic import ScalarConstant from pytensor.scalar.basic import ScalarConstant
from pytensor.tensor import as_tensor_variable, broadcast_to, get_vector_length, row from pytensor.tensor import as_tensor_variable, broadcast_to, get_vector_length, row
from pytensor.tensor.basic import MakeVector, as_tensor, constant from pytensor.tensor.basic import MakeVector, constant, stack
from pytensor.tensor.elemwise import DimShuffle, Elemwise from pytensor.tensor.elemwise import DimShuffle, Elemwise
from pytensor.tensor.rewriting.shape import ShapeFeature from pytensor.tensor.rewriting.shape import ShapeFeature
from pytensor.tensor.shape import ( from pytensor.tensor.shape import (
...@@ -801,8 +801,14 @@ class TestVectorize: ...@@ -801,8 +801,14 @@ class TestVectorize:
[vect_out] = vectorize_node(node, mat, new_shape).outputs [vect_out] = vectorize_node(node, mat, new_shape).outputs
assert equal_computations([vect_out], [reshape(mat, new_shape)]) assert equal_computations([vect_out], [reshape(mat, new_shape)])
with pytest.raises(NotImplementedError): new_shape = stack([[-1, x], [x - 1, -1]], axis=0)
vectorize_node(node, vec, broadcast_to(as_tensor([5, 2, x]), (2, 3))) print(new_shape.type)
[vect_out] = vectorize_node(node, vec, new_shape).outputs
vec_test_value = np.arange(6)
np.testing.assert_allclose(
vect_out.eval({x: 3, vec: vec_test_value}),
np.broadcast_to(vec_test_value.reshape(2, 3), (2, 2, 3)),
)
with pytest.raises( with pytest.raises(
ValueError, ValueError,
......