Commit be1330bf authored by ricardoV94, committed by Ricardo Vieira

Implement Elemwise and Blockwise operations for XTensorVariables

Parent cd1e5dc9
import warnings
import pytensor.xtensor.rewriting
from pytensor.xtensor import (
linalg,
)
from pytensor.xtensor.type import (
XTensorType,
as_xtensor,
......
from collections.abc import Sequence
from typing import Literal
from pytensor.tensor.slinalg import Cholesky, Solve
from pytensor.xtensor.type import as_xtensor
from pytensor.xtensor.vectorization import XBlockwise
def cholesky(
    x,
    lower: bool = True,
    *,
    check_finite: bool = False,
    overwrite_a: bool = False,
    on_error: Literal["raise", "nan"] = "raise",
    dims: Sequence[str],
):
    """Cholesky decomposition of an XTensorVariable along two named dims.

    Parameters mirror `pytensor.tensor.slinalg.Cholesky`; `dims` names the
    two core (matrix) dimensions, all other dims are treated as batch dims.
    """
    if len(dims) != 2:
        raise ValueError(f"Cholesky needs two dims, got {len(dims)}")
    # Input and output matrices share the same pair of core dims
    matrix_dims = (dims[0], dims[1])
    blockwise_op = XBlockwise(
        Cholesky(
            lower=lower,
            check_finite=check_finite,
            overwrite_a=overwrite_a,
            on_error=on_error,
        ),
        core_dims=((matrix_dims,), (matrix_dims,)),
    )
    return blockwise_op(x)
def solve(
    a,
    b,
    dims: Sequence[str],
    assume_a="gen",
    lower: bool = False,
    check_finite: bool = False,
):
    """Solve ``a @ x = b`` for XTensorVariables along named core dims.

    `dims` has length 2 (vector b: the dim missing from b is the rows dim)
    or length 3 (matrix b: the dim missing from a is b's columns dim).
    """
    a, b = as_xtensor(a), as_xtensor(b)
    input_core_dims: tuple[tuple[str, str], tuple[str] | tuple[str, str]]
    output_core_dims: tuple[tuple[str] | tuple[str, str]]
    n_core_dims = len(dims)
    if n_core_dims == 2:
        # Vector b: of the two dims, the one not present in b is m1 (rows)
        b_ndim = 1
        [m1_dim] = [dim for dim in dims if dim not in b.type.dims]
        m2_dim = dims[1] if dims[0] == m1_dim else dims[0]
        input_core_dims = ((m1_dim, m2_dim), (m2_dim,))
        # The shared dim disappears in the output
        output_core_dims = ((m1_dim,),)
    elif n_core_dims == 3:
        # Matrix b: the dim not present in a belongs to b's columns
        b_ndim = 2
        [n_dim] = [dim for dim in dims if dim not in a.type.dims]
        m1_dim, m2_dim = (dim for dim in dims if dim != n_dim)
        input_core_dims = ((m1_dim, m2_dim), (m2_dim, n_dim))
        # The shared dim disappears in the output
        output_core_dims = ((m1_dim, n_dim),)
    else:
        raise ValueError("Solve dims must have length 2 or 3")

    blockwise_op = XBlockwise(
        Solve(b_ndim=b_ndim, assume_a=assume_a, lower=lower, check_finite=check_finite),
        core_dims=(input_core_dims, output_core_dims),
    )
    return blockwise_op(a, b)
import sys
import pytensor.scalar as ps
from pytensor.scalar import ScalarOp
from pytensor.xtensor.vectorization import XElemwise
# Handle to this module object; presumably used elsewhere to look up the
# wrapped ops dynamically — not referenced in the visible code (TODO confirm).
this_module = sys.modules[__name__]
def _as_xelemwise(core_op: ScalarOp) -> XElemwise:
    """Wrap a scalar op as an elemwise op on XTensorVariables."""
    wrapped = XElemwise(core_op)
    wrapped.__doc__ = f"Ufunc version of {core_op} for XTensorVariables"
    return wrapped
# One XElemwise wrapper per pytensor.scalar op. Chained assignments
# (e.g. ``logical_and = bitwise_and = and_``) bind numpy/xarray-style
# aliases to the same Op instance. Several names deliberately shadow
# builtins (abs, complex, pow, round, min/max variants), matching numpy's
# module-level API.
abs = _as_xelemwise(ps.abs)
add = _as_xelemwise(ps.add)
logical_and = bitwise_and = and_ = _as_xelemwise(ps.and_)
angle = _as_xelemwise(ps.angle)
arccos = _as_xelemwise(ps.arccos)
arccosh = _as_xelemwise(ps.arccosh)
arcsin = _as_xelemwise(ps.arcsin)
arcsinh = _as_xelemwise(ps.arcsinh)
arctan = _as_xelemwise(ps.arctan)
arctan2 = _as_xelemwise(ps.arctan2)
arctanh = _as_xelemwise(ps.arctanh)
betainc = _as_xelemwise(ps.betainc)
betaincinv = _as_xelemwise(ps.betaincinv)
ceil = _as_xelemwise(ps.ceil)
clip = _as_xelemwise(ps.clip)
complex = _as_xelemwise(ps.complex)
conjugate = conj = _as_xelemwise(ps.conj)
cos = _as_xelemwise(ps.cos)
cosh = _as_xelemwise(ps.cosh)
deg2rad = _as_xelemwise(ps.deg2rad)
equal = eq = _as_xelemwise(ps.eq)
# Error functions and other special functions
erf = _as_xelemwise(ps.erf)
erfc = _as_xelemwise(ps.erfc)
erfcinv = _as_xelemwise(ps.erfcinv)
erfcx = _as_xelemwise(ps.erfcx)
erfinv = _as_xelemwise(ps.erfinv)
exp = _as_xelemwise(ps.exp)
exp2 = _as_xelemwise(ps.exp2)
expm1 = _as_xelemwise(ps.expm1)
floor = _as_xelemwise(ps.floor)
floor_divide = floor_div = int_div = _as_xelemwise(ps.int_div)
gamma = _as_xelemwise(ps.gamma)
gammainc = _as_xelemwise(ps.gammainc)
gammaincc = _as_xelemwise(ps.gammaincc)
gammainccinv = _as_xelemwise(ps.gammainccinv)
gammaincinv = _as_xelemwise(ps.gammaincinv)
gammal = _as_xelemwise(ps.gammal)
gammaln = _as_xelemwise(ps.gammaln)
gammau = _as_xelemwise(ps.gammau)
greater_equal = ge = _as_xelemwise(ps.ge)
greater = gt = _as_xelemwise(ps.gt)
hyp2f1 = _as_xelemwise(ps.hyp2f1)
# Bessel functions
i0 = _as_xelemwise(ps.i0)
i1 = _as_xelemwise(ps.i1)
identity = _as_xelemwise(ps.identity)
imag = _as_xelemwise(ps.imag)
logical_not = bitwise_invert = bitwise_not = invert = _as_xelemwise(ps.invert)
isinf = _as_xelemwise(ps.isinf)
isnan = _as_xelemwise(ps.isnan)
iv = _as_xelemwise(ps.iv)
ive = _as_xelemwise(ps.ive)
j0 = _as_xelemwise(ps.j0)
j1 = _as_xelemwise(ps.j1)
jv = _as_xelemwise(ps.jv)
kve = _as_xelemwise(ps.kve)
less_equal = le = _as_xelemwise(ps.le)
log = _as_xelemwise(ps.log)
log10 = _as_xelemwise(ps.log10)
log1mexp = _as_xelemwise(ps.log1mexp)
log1p = _as_xelemwise(ps.log1p)
log2 = _as_xelemwise(ps.log2)
less = lt = _as_xelemwise(ps.lt)
mod = _as_xelemwise(ps.mod)
multiply = mul = _as_xelemwise(ps.mul)
negative = neg = _as_xelemwise(ps.neg)
not_equal = neq = _as_xelemwise(ps.neq)
logical_or = bitwise_or = or_ = _as_xelemwise(ps.or_)
owens_t = _as_xelemwise(ps.owens_t)
polygamma = _as_xelemwise(ps.polygamma)
power = pow = _as_xelemwise(ps.pow)
psi = _as_xelemwise(ps.psi)
rad2deg = _as_xelemwise(ps.rad2deg)
real = _as_xelemwise(ps.real)
reciprocal = _as_xelemwise(ps.reciprocal)
# numpy-compatible names for scalar ops whose pytensor names differ
round = _as_xelemwise(ps.round_half_to_even)
maximum = _as_xelemwise(ps.scalar_maximum)
minimum = _as_xelemwise(ps.scalar_minimum)
second = _as_xelemwise(ps.second)
sigmoid = _as_xelemwise(ps.sigmoid)
sign = _as_xelemwise(ps.sign)
sin = _as_xelemwise(ps.sin)
sinh = _as_xelemwise(ps.sinh)
softplus = _as_xelemwise(ps.softplus)
square = sqr = _as_xelemwise(ps.sqr)
sqrt = _as_xelemwise(ps.sqrt)
subtract = sub = _as_xelemwise(ps.sub)
where = switch = _as_xelemwise(ps.switch)
tan = _as_xelemwise(ps.tan)
tanh = _as_xelemwise(ps.tanh)
tri_gamma = _as_xelemwise(ps.tri_gamma)
true_divide = true_div = _as_xelemwise(ps.true_div)
trunc = _as_xelemwise(ps.trunc)
logical_xor = bitwise_xor = xor = _as_xelemwise(ps.xor)
import pytensor.xtensor.rewriting.basic
import pytensor.xtensor.rewriting.shape
import pytensor.xtensor.rewriting.vectorization
from pytensor.graph import node_rewriter
from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.elemwise import Elemwise
from pytensor.xtensor.basic import tensor_from_xtensor, xtensor_from_tensor
from pytensor.xtensor.rewriting.utils import register_lower_xtensor
from pytensor.xtensor.vectorization import XBlockwise, XElemwise
@register_lower_xtensor
@node_rewriter(tracks=[XElemwise])
def lower_elemwise(fgraph, node):
    """Lower an XElemwise node to a plain Elemwise on unwrapped tensors."""
    out_dims = node.outputs[0].type.dims

    # Unwrap each XTensor input and dimshuffle it so its axes line up with
    # the output dims ("x" inserts a broadcastable axis for missing dims).
    tensor_inputs = []
    for inp in node.inputs:
        inp_dims = inp.type.dims
        order = [
            inp_dims.index(d) if d in inp_dims else "x" for d in out_dims
        ]
        tensor_inputs.append(tensor_from_xtensor(inp).dimshuffle(order))

    tensor_outs = Elemwise(scalar_op=node.op.scalar_op)(
        *tensor_inputs, return_list=True
    )

    # Re-wrap each output tensor as an XTensor with the output dims
    return [xtensor_from_tensor(t, dims=out_dims) for t in tensor_outs]
@register_lower_xtensor
@node_rewriter(tracks=[XBlockwise])
def lower_blockwise(fgraph, node):
    """Lower an XBlockwise node to a plain Blockwise on unwrapped tensors."""
    op: XBlockwise = node.op
    inputs_core_dims, outputs_core_dims = op.core_dims
    batch_ndim = node.outputs[0].type.ndim - len(outputs_core_dims[0])
    batch_dims = node.outputs[0].type.dims[:batch_ndim]

    # Unwrap inputs: align batch dims first ("x" inserts a broadcastable
    # axis for missing batch dims), then move core dims to the right.
    tensor_inputs = []
    for inp, core_dims in zip(node.inputs, inputs_core_dims):
        inp_dims = inp.type.dims
        batch_order = [
            inp_dims.index(d) if d in inp_dims else "x" for d in batch_dims
        ]
        core_order = [inp_dims.index(d) for d in core_dims]
        tensor_inputs.append(
            tensor_from_xtensor(inp).dimshuffle(batch_order + core_order)
        )

    signature = op.signature or getattr(op.core_op, "gufunc_signature", None)
    if signature is None:
        # Build a gufunc signature from the core dims. It could be stricter
        # (core dims are never repeated), but nothing depends greatly on it.
        in_sig = ",".join(f"({', '.join(dims)})" for dims in inputs_core_dims)
        out_sig = ",".join(f"({', '.join(dims)})" for dims in outputs_core_dims)
        signature = f"{in_sig}->{out_sig}"

    tensor_outs = Blockwise(core_op=op.core_op, signature=signature)(
        *tensor_inputs, return_list=True
    )

    # Re-wrap outputs as XTensors, restoring each original output's dims
    return [
        xtensor_from_tensor(t, dims=old.type.dims)
        for t, old in zip(tensor_outs, node.outputs, strict=True)
    ]
......@@ -231,6 +231,109 @@ class XTensorVariable(Variable[_XTensorTypeType, OptionalApplyType]):
"Call `.astype(complex)` for the symbolic equivalent."
)
# Python valid overloads
    def __abs__(self):
        # abs(x)
        return px.math.abs(self)

    def __neg__(self):
        # -x
        return px.math.neg(self)

    def __lt__(self, other):
        # x < other, elementwise
        return px.math.lt(self, other)

    def __le__(self, other):
        # x <= other, elementwise
        return px.math.le(self, other)

    def __gt__(self, other):
        # x > other, elementwise
        return px.math.gt(self, other)

    def __ge__(self, other):
        # x >= other, elementwise
        return px.math.ge(self, other)

    def __invert__(self):
        # ~x (logical/bitwise not)
        return px.math.invert(self)

    def __and__(self, other):
        # x & other
        return px.math.and_(self, other)

    def __or__(self, other):
        # x | other
        return px.math.or_(self, other)

    def __xor__(self, other):
        # x ^ other
        return px.math.xor(self, other)

    def __rand__(self, other):
        # Reflected &
        return px.math.and_(other, self)

    def __ror__(self, other):
        # Reflected |
        return px.math.or_(other, self)

    def __rxor__(self, other):
        # Reflected ^
        return px.math.xor(other, self)
    def __add__(self, other):
        # x + other
        return px.math.add(self, other)

    def __sub__(self, other):
        # x - other
        return px.math.sub(self, other)

    def __mul__(self, other):
        # x * other
        return px.math.mul(self, other)
def __div__(self, other):
return px.math.div(self, other)
    def __pow__(self, other):
        # x ** other
        return px.math.pow(self, other)

    def __mod__(self, other):
        # x % other
        return px.math.mod(self, other)
def __divmod__(self, other):
return px.math.divmod(self, other)
    def __truediv__(self, other):
        # x / other
        return px.math.true_div(self, other)

    def __floordiv__(self, other):
        # x // other
        return px.math.floor_div(self, other)

    def __rtruediv__(self, other):
        # Reflected /
        return px.math.true_div(other, self)

    def __rfloordiv__(self, other):
        # Reflected //
        return px.math.floor_div(other, self)

    def __radd__(self, other):
        # Reflected +
        return px.math.add(other, self)

    def __rsub__(self, other):
        # Reflected -
        return px.math.sub(other, self)

    def __rmul__(self, other):
        # Reflected *
        return px.math.mul(other, self)
def __rdiv__(self, other):
return px.math.div_proxy(other, self)
    def __rmod__(self, other):
        # Reflected %: other % x
        return px.math.mod(other, self)
def __rdivmod__(self, other):
return px.math.divmod(other, self)
    def __rpow__(self, other):
        # Reflected **: other ** x
        return px.math.pow(other, self)

    def __ceil__(self):
        # math.ceil(x)
        return px.math.ceil(self)

    def __floor__(self):
        # math.floor(x)
        return px.math.floor(self)

    def __trunc__(self):
        # math.trunc(x)
        return px.math.trunc(self)
# DataArray-like attributes
# https://docs.xarray.dev/en/latest/api.html#id1
@property
......@@ -293,6 +396,11 @@ class XTensorVariable(Variable[_XTensorTypeType, OptionalApplyType]):
new_out.name = new_name
return new_out
def copy(self, name: str | None = None):
out = px.math.identity(self)
out.name = name # type: ignore
return out
    def item(self):
        # Symbolic variables have no concrete data to extract a scalar from
        raise NotImplementedError("item not implemented for XTensorVariable")
......@@ -311,6 +419,22 @@ class XTensorVariable(Variable[_XTensorTypeType, OptionalApplyType]):
def __getitem__(self, idx):
raise NotImplementedError("Indexing not yet implemnented")
    # ndarray methods
    # https://docs.xarray.dev/en/latest/api.html#id7
    def clip(self, min, max):
        # Elementwise clamp to [min, max] (parameter names shadow builtins,
        # matching the numpy/xarray API)
        return px.math.clip(self, min, max)

    def conj(self):
        # Complex conjugate
        return px.math.conj(self)

    @property
    def imag(self):
        # Imaginary part
        return px.math.imag(self)

    @property
    def real(self):
        # Real part
        return px.math.real(self)
# Reshaping and reorganizing
# https://docs.xarray.dev/en/latest/api.html#id8
def stack(self, dim, **dims):
......
from itertools import chain
from pytensor import scalar as ps
from pytensor.graph import Apply, Op
from pytensor.tensor import tensor
from pytensor.xtensor.basic import XOp
from pytensor.xtensor.type import as_xtensor, xtensor
def combine_dims_and_shape(inputs):
    """Merge the dims and static shapes of several xtensor inputs.

    Returns a dict mapping each dim name to its static length, or None when
    the length is unknown in every input. Raises ValueError if two inputs
    disagree on a known static length for the same dim.
    """
    merged: dict[str, int | None] = {}
    for inp in inputs:
        for dim, length in zip(inp.type.dims, inp.type.shape):
            if dim in merged:
                if length is None:
                    # Nothing new learned about this dim; keep what we have
                    continue
                prior = merged[dim]
                if prior is not None and prior != length:
                    raise ValueError(f"Dimension {dim} has conflicting shapes")
            # First sighting of the dim, or a concrete length to record
            merged[dim] = length
    return merged
class XElemwise(XOp):
    """Elemwise operation on XTensorVariables, aligning inputs by dim name."""

    __props__ = ("scalar_op",)

    def __init__(self, scalar_op):
        super().__init__()
        self.scalar_op = scalar_op

    def make_node(self, *inputs):
        inputs = [as_xtensor(inp) for inp in inputs]
        if (self.scalar_op.nin != -1) and (len(inputs) != self.scalar_op.nin):
            raise ValueError(
                f"Wrong number of inputs, expected {self.scalar_op.nin}, got {len(inputs)}"
            )

        # Output dims are the union of all input dims, in first-seen order
        merged = combine_dims_and_shape(inputs)
        out_dims = tuple(merged.keys())
        out_shape = tuple(merged.values())

        # Infer output dtypes by running the scalar op on dummy scalars
        dummy_scalars = [ps.get_scalar_type(inp.type.dtype)() for inp in inputs]
        scalar_node = self.scalar_op.make_node(*dummy_scalars)
        outputs = [
            xtensor(dtype=out.type.dtype, dims=out_dims, shape=out_shape)
            for out in scalar_node.outputs
        ]
        return Apply(self, inputs, outputs)
class XBlockwise(XOp):
    """Blockwise (gufunc-like) operation on XTensorVariables.

    `core_dims` is a pair `(inputs_core_dims, outputs_core_dims)` naming the
    core dimensions of each input and output; all remaining dims are treated
    as batch dims and broadcast by name.
    """

    __props__ = ("core_op", "core_dims")

    def __init__(
        self,
        core_op: Op,
        core_dims: tuple[tuple[tuple[str, ...], ...], tuple[tuple[str, ...], ...]],
        signature: str | None = None,
    ):
        super().__init__()
        self.core_op = core_op
        self.core_dims = core_dims
        self.signature = signature  # Only used for lowering, not for validation

    def make_node(self, *inputs):
        inputs = [as_xtensor(i) for i in inputs]
        if len(inputs) != len(self.core_dims[0]):
            raise ValueError(
                f"Wrong number of inputs, expected {len(self.core_dims[0])}, got {len(inputs)}"
            )

        dims_and_shape = combine_dims_and_shape(inputs)

        core_inputs_dims, core_outputs_dims = self.core_dims
        core_input_dims_set = set(chain.from_iterable(core_inputs_dims))
        # Fix: the original unconditionally unpacked zip(*...), which raises
        # when there are no batch dims (all dims are core dims)
        batch_items = [
            (k, v) for k, v in dims_and_shape.items() if k not in core_input_dims_set
        ]
        if batch_items:
            batch_dims, batch_shape = zip(*batch_items)
        else:
            batch_dims, batch_shape = (), ()

        # Build dummy core inputs to infer output dtypes/shapes from core_op
        dummy_core_inputs = []
        for inp, core_inp_dims in zip(inputs, core_inputs_dims):
            try:
                core_static_shape = [
                    inp.type.shape[inp.type.dims.index(d)] for d in core_inp_dims
                ]
            # Fix: tuple.index raises ValueError (not IndexError) when the
            # dim is absent, so the original `except IndexError` never fired
            except ValueError:
                raise ValueError(
                    f"At least one core dim={core_inp_dims} missing from input {inp} with dims={inp.type.dims}"
                ) from None
            dummy_core_inputs.append(
                tensor(dtype=inp.type.dtype, shape=core_static_shape)
            )
        core_node = self.core_op.make_node(*dummy_core_inputs)

        outputs = [
            xtensor(
                dtype=core_out.type.dtype,
                shape=batch_shape + core_out.type.shape,
                dims=batch_dims + core_out_dims,
            )
            for core_out, core_out_dims in zip(core_node.outputs, core_outputs_dims)
        ]
        return Apply(self, inputs, outputs)
# ruff: noqa: E402
import pytest
pytest.importorskip("xarray")
pytest.importorskip("xarray_einstats")
import numpy as np
from xarray import DataArray
from xarray_einstats.linalg import (
cholesky as xr_cholesky,
)
from xarray_einstats.linalg import (
solve as xr_solve,
)
from pytensor.xtensor.linalg import cholesky, solve
from pytensor.xtensor.type import xtensor
from tests.xtensor.util import xr_assert_allclose, xr_function
def test_cholesky():
    """cholesky on an xtensor matches xarray_einstats' cholesky."""
    x = xtensor("x", dims=("a", "batch", "b"), shape=(4, 3, 4))
    y = cholesky(x, dims=["b", "a"])
    assert y.type.dims == ("batch", "b", "a")
    assert y.type.shape == (3, 4, 4)

    fn = xr_function([x], y)

    rng = np.random.default_rng(25)
    mats = rng.random(size=(3, 4, 4))
    mats = mats @ mats.mT  # make symmetric positive definite
    x_val = DataArray(mats.transpose(1, 0, 2), dims=x.type.dims)
    xr_assert_allclose(
        fn(x_val),
        xr_cholesky(x_val, dims=["b", "a"]),
    )
def test_solve_vector_b():
    """solve with a vector-valued b (two core dims)."""
    a = xtensor("a", dims=("city", "country", "galaxy"), shape=(None, 4, 1))
    b = xtensor("b", dims=("city", "planet"), shape=(None, 2))
    out = solve(a, b, dims=["country", "city"])
    assert out.type.dims == ("galaxy", "planet", "country")
    # Core Solve doesn't make use of the fact A must be square in the static shape
    assert out.type.shape == (1, 2, None)

    fn = xr_function([a, b], out)

    rng = np.random.default_rng(25)
    a_val = DataArray(rng.random(size=(4, 4, 1)), dims=a.type.dims)
    b_val = DataArray(rng.random(size=(4, 2)), dims=b.type.dims)
    xr_assert_allclose(
        fn(a_val, b_val),
        xr_solve(a_val, b_val, dims=["country", "city"]),
    )
def test_solve_matrix_b():
    """solve with a matrix-valued b (three core dims)."""
    a = xtensor("a", dims=("city", "country", "galaxy"), shape=(None, 4, 1))
    b = xtensor("b", dims=("district", "city", "planet"), shape=(5, None, 2))
    out = solve(a, b, dims=["country", "city", "district"])
    assert out.type.dims == ("galaxy", "planet", "country", "district")
    # Core Solve doesn't make use of the fact A must be square in the static shape
    assert out.type.shape == (1, 2, None, 5)

    fn = xr_function([a, b], out)

    rng = np.random.default_rng(25)
    a_val = DataArray(rng.random(size=(4, 4, 1)), dims=a.type.dims)
    b_val = DataArray(rng.random(size=(5, 4, 2)), dims=b.type.dims)
    xr_assert_allclose(
        fn(a_val, b_val),
        xr_solve(a_val, b_val, dims=["country", "city", "district"]),
    )
# ruff: noqa: E402
import pytest
pytest.importorskip("xarray")
import inspect
import numpy as np
from xarray import DataArray
import pytensor.scalar as ps
import pytensor.xtensor.math as pxm
from pytensor import function
from pytensor.scalar import ScalarOp
from pytensor.xtensor.basic import rename
from pytensor.xtensor.math import add, exp
from pytensor.xtensor.type import xtensor
from tests.xtensor.util import xr_assert_allclose, xr_function
def test_all_scalar_ops_are_wrapped():
    """Every ScalarOp exposed by pytensor.scalar must have an xtensor.math wrapper."""
    # This ignores wrapper functions
    wrapped_names = {name for name, _ in inspect.getmembers(pxm)}
    # These are not regular numpy functions or are unusual aliases
    skipped = {
        "complex_from_polar",
        "inclosedrange",
        "inopenrange",
        "round_half_away_from_zero",
        "round_half_to_even",
        "scalar_abs",
        "scalar_maximum",
        "scalar_minimum",
    }
    for name, member in inspect.getmembers(ps):
        if name in skipped or name.startswith("convert_to_"):
            continue
        if isinstance(member, ScalarOp) and name not in wrapped_names:
            raise NotImplementedError(f"ScalarOp {name} not wrapped in xtensor.math")
def test_scalar_case():
    """Addition of two 0-d xtensors behaves like plain scalar addition."""
    lhs = xtensor("x", dims=(), shape=())
    rhs = xtensor("y", dims=(), shape=())
    fn = function([lhs, rhs], add(lhs, rhs))

    lhs_val = DataArray(2.0, dims=())
    rhs_val = DataArray(3.0, dims=())
    np.testing.assert_allclose(fn(lhs_val.values, rhs_val.values), 5.0)
def test_dimension_alignment():
    """Inputs with differently ordered/missing dims are aligned by name."""
    x = xtensor("x", dims=("city", "country", "planet"), shape=(2, 3, 4))
    y = xtensor(
        "y",
        dims=("galaxy", "country", "city"),
        shape=(5, 3, 2),
    )
    z = xtensor("z", dims=("universe",), shape=(1,))
    out = add(x, y, z)
    assert out.type.dims == ("city", "country", "planet", "galaxy", "universe")

    fn = function([x, y, z], out)

    rng = np.random.default_rng(41)
    test_x, test_y, test_z = [
        DataArray(rng.normal(size=inp.type.shape), dims=inp.type.dims)
        for inp in (x, y, z)
    ]
    np.testing.assert_allclose(
        fn(test_x.values, test_y.values, test_z.values),
        (test_x + test_y + test_z).values,
    )
def test_renamed_dimension_alignment():
    """Alignment keys on the (possibly renamed) dim names, not on position."""
    x = xtensor("x", dims=("a", "b1", "b2"), shape=(2, 3, 3))
    y = rename(x, b1="b2", b2="b1")
    z = rename(x, b2="b3")
    assert y.type.dims == ("a", "b2", "b1")
    assert z.type.dims == ("a", "b1", "b3")

    out1 = add(x, x)  # self addition
    assert out1.type.dims == ("a", "b1", "b2")
    out2 = add(x, y)  # transposed addition
    assert out2.type.dims == ("a", "b1", "b2")
    out3 = add(x, z)  # outer addition
    assert out3.type.dims == ("a", "b1", "b2", "b3")

    fn = xr_function([x], [out1, out2, out3])

    x_val = DataArray(
        np.arange(np.prod(x.type.shape), dtype=x.type.dtype).reshape(x.type.shape),
        dims=x.type.dims,
    )
    expected = [
        x_val + x_val,
        x_val + x_val.rename(b1="b2", b2="b1"),
        x_val + x_val.rename(b2="b3"),
    ]
    for got, want in zip(fn(x_val), expected):
        xr_assert_allclose(got, want)
def test_chained_operations():
    """Composed elemwise ops propagate dims and static shapes correctly."""
    x = xtensor("x", dims=("city",), shape=(None,))
    y = xtensor("y", dims=("country",), shape=(4,))
    z = add(exp(x), exp(y))
    assert z.type.dims == ("city", "country")
    assert z.type.shape == (None, 4)

    fn = function([x, y], z)

    x_val = DataArray(np.zeros(3), dims="city")
    y_val = DataArray(np.ones(4), dims="country")
    np.testing.assert_allclose(
        fn(x_val.values, y_val.values),
        (np.exp(x_val) + np.exp(y_val)).values,
    )
def test_multiple_constant():
    """Python scalars mixed into an xtensor expression are lifted to constants."""
    x = xtensor("x", dims=("a", "b"), shape=(2, 3))
    out = exp(x * 2) + 2

    fn = function([x], out)

    x_val = np.zeros((2, 3), dtype=x.type.dtype)
    np.testing.assert_allclose(fn(x_val), np.exp(x_val * 2) + 2)
Markdown format
0%
You are adding 0 people to this discussion. Please proceed with caution.
Please finish editing this comment first!
Register or sign in to comment