Commit 4ce092fe · Author: Ricardo Vieira · Committer: Ricardo Vieira

Add xtensor docs

Parent ebfac59e
@@ -25,6 +25,7 @@ Modules
sparse/index
tensor/index
typed_list
xtensor/index
.. module:: pytensor
:platform: Unix, Windows
......
(libdoc_xtensor)=
# `xtensor` -- XTensor operations
This module implements an abstraction layer over regular tensor operations that behaves like xarray.
A new type, {class}`pytensor.xtensor.type.XTensorType`, generalizes {class}`pytensor.tensor.TensorType`
with the addition of a `dims` attribute that labels the dimensions of the tensor.
Variables of XTensorType (i.e., {class}`pytensor.xtensor.type.XTensorVariable`s) are the symbolic counterpart
to xarray DataArray objects.
The module implements several PyTensor operations {class}`pytensor.xtensor.basic.XOp`s, whose signature mimics that of
xarray (and xarray_einstats) DataArray operations. These operations, unlike most regular PyTensor operations, cannot
be directly evaluated, but require a rewrite (lowering) into a regular tensor graph that can itself be evaluated as usual.
Like regular PyTensor, we don't need an Op for every possible method or function in the public API of xarray.
If the existing XOps can be composed to produce the desired result, then we can use them directly.
## Coordinates
For now, there is no analogue of xarray coordinates, so you won't be able to do coordinate operations like `.sel`.
The graphs produced by an xarray program without coords are much more amenable to the numpy-like backend of PyTensor.
Coords involve aspects of Pandas/database query and joining that are not trivially expressible in PyTensor.
## Example
```{testcode}
import pytensor.tensor as pt
import pytensor.xtensor as ptx
a = pt.tensor("a", shape=(3,))
b = pt.tensor("b", shape=(4,))
ax = ptx.as_xtensor(a, dims=["x"])
bx = ptx.as_xtensor(b, dims=["y"])
zx = ax + bx
assert zx.type == ptx.type.XTensorType("float64", dims=["x", "y"], shape=(3, 4))
z = zx.values
z.dprint()
```
```{testoutput}
TensorFromXTensor [id A]
└─ XElemwise{scalar_op=Add()} [id B]
├─ XTensorFromTensor{dims=('x',)} [id C]
│ └─ a [id D]
└─ XTensorFromTensor{dims=('y',)} [id E]
└─ b [id F]
```
Once we compile the graph, no XOps are left.
```{testcode}
import pytensor
with pytensor.config.change_flags(optimizer_verbose=True):
fn = pytensor.function([a, b], z)
```
```{testoutput}
rewriting: rewrite lower_elemwise replaces XElemwise{scalar_op=Add()}.0 of XElemwise{scalar_op=Add()}(XTensorFromTensor{dims=('x',)}.0, XTensorFromTensor{dims=('y',)}.0) with XTensorFromTensor{dims=('x', 'y')}.0 of XTensorFromTensor{dims=('x', 'y')}(Add.0)
rewriting: rewrite useless_tensor_from_xtensor replaces TensorFromXTensor.0 of TensorFromXTensor(XTensorFromTensor{dims=('x',)}.0) with a of None
rewriting: rewrite useless_tensor_from_xtensor replaces TensorFromXTensor.0 of TensorFromXTensor(XTensorFromTensor{dims=('y',)}.0) with b of None
rewriting: rewrite useless_tensor_from_xtensor replaces TensorFromXTensor.0 of TensorFromXTensor(XTensorFromTensor{dims=('x', 'y')}.0) with Add.0 of Add(ExpandDims{axis=1}.0, ExpandDims{axis=0}.0)
```
```{testcode}
fn.dprint()
```
```{testoutput}
Add [id A] 2
├─ ExpandDims{axis=1} [id B] 1
│ └─ a [id C]
└─ ExpandDims{axis=0} [id D] 0
└─ b [id E]
```
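The lowering above boils down to aligning labeled dimensions and inserting broadcastable axes where an input lacks a dimension. A minimal pure-Python sketch of that alignment step follows. The helper name `lower_binary_elemwise` is hypothetical, the output ordering (first input's dims, then the second input's new dims) is an assumption, and transposition of already-present dims is ignored, so this is not PyTensor's actual rewrite.

```python
def lower_binary_elemwise(x_dims, y_dims):
    """Return (out_dims, x_expand_axes, y_expand_axes) for a labeled elementwise op.

    Hypothetical helper: it mimics only the dimension alignment, not the real rewrite.
    """
    # Assumed ordering: dims of x first, then dims that only y has.
    out_dims = tuple(x_dims) + tuple(d for d in y_dims if d not in x_dims)

    def expand_axes(dims):
        # Positions in the output where this input needs a new length-1 axis.
        return tuple(i for i, d in enumerate(out_dims) if d not in dims)

    return out_dims, expand_axes(x_dims), expand_axes(y_dims)


out_dims, a_axes, b_axes = lower_binary_elemwise(("x",), ("y",))
print(out_dims, a_axes, b_axes)
```

For `("x",)` and `("y",)` the sketch yields axis 1 for `a` and axis 0 for `b`, which lines up with the `ExpandDims{axis=1}` and `ExpandDims{axis=0}` nodes in the compiled graph.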
## Index
:::{toctree}
:maxdepth: 1
module_functions
math
linalg
random
type
:::
\ No newline at end of file
(libdoc_xtensor_linalg)=
# `xtensor.linalg` -- Linear algebra operations
```{eval-rst}
.. automodule:: pytensor.xtensor.linalg
:members:
```
(libdoc_xtensor_math)=
# `xtensor.math` -- Mathematical operations
```{eval-rst}
.. automodule:: pytensor.xtensor.math
:members:
:exclude-members: XDot, dot
```
\ No newline at end of file
(libdoc_xtensor_module_function)=
# `xtensor` -- Module level operations
```{eval-rst}
.. automodule:: pytensor.xtensor
:members: broadcast, concat, dot, full_like, ones_like, zeros_like
```
(libdoc_xtensor_random)=
# `xtensor.random` -- Random number generator operations
```{eval-rst}
.. automodule:: pytensor.xtensor.random
:members:
```
(libdoc_xtensor_type)=
# `xtensor.type` -- Types and Variables
## XTensorVariable creation functions
```{eval-rst}
.. automodule:: pytensor.xtensor.type
:members: xtensor, xtensor_constant, as_xtensor
```
## XTensor Type and Variable classes
```{eval-rst}
.. automodule:: pytensor.xtensor.type
:noindex:
:members: XTensorType, XTensorVariable, XTensorConstant
```
import warnings
import pytensor.xtensor.rewriting
-from pytensor.xtensor import linalg, random
+from pytensor.xtensor import linalg, math, random
from pytensor.xtensor.math import dot
from pytensor.xtensor.shape import broadcast, concat, full_like, ones_like, zeros_like
from pytensor.xtensor.type import (
......
...@@ -11,17 +11,31 @@ def cholesky( ...@@ -11,17 +11,31 @@ def cholesky(
lower: bool = True, lower: bool = True,
*, *,
check_finite: bool = False, check_finite: bool = False,
overwrite_a: bool = False,
on_error: Literal["raise", "nan"] = "raise", on_error: Literal["raise", "nan"] = "raise",
dims: Sequence[str], dims: Sequence[str],
): ):
"""Compute the Cholesky decomposition of an XTensorVariable.
Parameters
----------
x : XTensorVariable
The input variable to decompose.
lower : bool, optional
Whether to return the lower triangular matrix. Default is True.
check_finite : bool, optional
Whether to check that the input is finite. Default is False.
on_error : {'raise', 'nan'}, optional
What to do if the input is not positive definite. If 'raise', an error is raised.
If 'nan', the output will contain NaNs. Default is 'raise'.
dims : Sequence[str]
The two core dimensions of the input variable, over which the Cholesky decomposition is computed.
"""
if len(dims) != 2: if len(dims) != 2:
raise ValueError(f"Cholesky needs two dims, got {len(dims)}") raise ValueError(f"Cholesky needs two dims, got {len(dims)}")
core_op = Cholesky( core_op = Cholesky(
lower=lower, lower=lower,
check_finite=check_finite, check_finite=check_finite,
overwrite_a=overwrite_a,
on_error=on_error, on_error=on_error,
) )
core_dims = ( core_dims = (
...@@ -40,6 +54,30 @@ def solve( ...@@ -40,6 +54,30 @@ def solve(
lower: bool = False, lower: bool = False,
check_finite: bool = False, check_finite: bool = False,
): ):
"""Solve a system of linear equations using XTensorVariables.
Parameters
----------
a : XTensorVariable
The left-hand side xtensor.
b : XTensorVariable
The right-hand side xtensor.
dims : Sequence[str]
The core dimensions over which to solve the linear equations.
If length is 2, we are solving a matrix-vector equation,
and the two dimensions should be present in `a`, but only one in `b`.
If length is 3, we are solving a matrix-matrix equation,
and two dimensions should be present in `a`, two in `b`, and only one should be shared.
In both cases the shared dimension will not appear in the output.
assume_a : str, optional
The type of matrix `a` is assumed to be. Default is 'gen' (general).
Options are ["gen", "sym", "her", "pos", "tridiagonal", "banded"].
Long form options can also be used ["general", "symmetric", "hermitian", "positive_definite"].
lower : bool, optional
Whether `a` is lower triangular. Default is False. Only relevant if `assume_a` is "sym", "her", or "pos".
check_finite : bool, optional
Whether to check that the input is finite. Default is False.
"""
a, b = as_xtensor(a), as_xtensor(b) a, b = as_xtensor(a), as_xtensor(b)
input_core_dims: tuple[tuple[str, str], tuple[str] | tuple[str, str]] input_core_dims: tuple[tuple[str, str], tuple[str] | tuple[str, str]]
output_core_dims: tuple[tuple[str] | tuple[str, str]] output_core_dims: tuple[tuple[str] | tuple[str, str]]
......
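The `dims` rule in the `solve` docstring above can be restated as a small validity check. The sketch below is pure Python and only encodes what the docstring says: two core dims in `a`, one or two in `b`, exactly one shared, and the shared one dropped from the output. The function name `solve_output_core_dims` is hypothetical, and the order of the returned dims is an assumption, not something the diff confirms.

```python
def solve_output_core_dims(a_dims, b_dims, dims):
    """Infer the core dims of the solution from the rule in the docstring.

    Hypothetical helper for illustration; not part of pytensor.xtensor.linalg.
    """
    dims = tuple(dims)
    if len(dims) not in (2, 3):
        raise ValueError(f"Solve needs two or three dims, got {len(dims)}")

    a_core = [d for d in dims if d in a_dims]
    b_core = [d for d in dims if d in b_dims]
    shared = [d for d in a_core if d in b_core]

    # Matrix-vector: b has one core dim. Matrix-matrix: b has two.
    expected_b = len(dims) - 1
    if len(a_core) != 2 or len(b_core) != expected_b or len(shared) != 1:
        raise ValueError("dims do not match the matrix-vector or matrix-matrix pattern")

    # The shared dimension is contracted away (assumed output order: as listed in dims).
    return tuple(d for d in dims if d != shared[0])


print(solve_output_core_dims(("batch", "i", "j"), ("batch", "i"), ("i", "j")))
print(solve_output_core_dims(("i", "j"), ("i", "k"), ("i", "j", "k")))
```

Batch dimensions such as `"batch"` are simply not listed in `dims`, so they take no part in the check.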
import sys import sys
from collections.abc import Iterable from collections.abc import Iterable, Sequence
from types import EllipsisType from types import EllipsisType
import numpy as np import numpy as np
...@@ -7,7 +7,6 @@ import numpy as np ...@@ -7,7 +7,6 @@ import numpy as np
import pytensor.scalar as ps import pytensor.scalar as ps
from pytensor import config from pytensor import config
from pytensor.graph.basic import Apply from pytensor.graph.basic import Apply
from pytensor.scalar import ScalarOp
from pytensor.scalar.basic import _cast_mapping, upcast from pytensor.scalar.basic import _cast_mapping, upcast
from pytensor.xtensor.basic import XOp, as_xtensor from pytensor.xtensor.basic import XOp, as_xtensor
from pytensor.xtensor.type import xtensor from pytensor.xtensor.type import xtensor
...@@ -17,110 +16,477 @@ from pytensor.xtensor.vectorization import XElemwise ...@@ -17,110 +16,477 @@ from pytensor.xtensor.vectorization import XElemwise
this_module = sys.modules[__name__] this_module = sys.modules[__name__]
def _as_xelemwise(core_op: ScalarOp) -> XElemwise: def _as_xelemwise(core_op):
out = XElemwise(core_op) x_op = XElemwise(core_op)
out.__doc__ = f"Ufunc version of {core_op} for XTensorVariables"
return out def decorator(func):
def wrapper(*args, **kwargs):
return x_op(*args, **kwargs)
abs = _as_xelemwise(ps.abs)
add = _as_xelemwise(ps.add) wrapper.__doc__ = f"Ufunc version of {core_op} for XTensorVariables"
logical_and = bitwise_and = and_ = _as_xelemwise(ps.and_) return wrapper
angle = _as_xelemwise(ps.angle)
arccos = _as_xelemwise(ps.arccos) return decorator
arccosh = _as_xelemwise(ps.arccosh)
arcsin = _as_xelemwise(ps.arcsin)
arcsinh = _as_xelemwise(ps.arcsinh) @_as_xelemwise(ps.abs)
arctan = _as_xelemwise(ps.arctan) def abs(): ...
arctan2 = _as_xelemwise(ps.arctan2)
arctanh = _as_xelemwise(ps.arctanh)
betainc = _as_xelemwise(ps.betainc) @_as_xelemwise(ps.add)
betaincinv = _as_xelemwise(ps.betaincinv) def add(): ...
ceil = _as_xelemwise(ps.ceil)
clip = _as_xelemwise(ps.clip)
complex = _as_xelemwise(ps.complex) @_as_xelemwise(ps.and_)
conjugate = conj = _as_xelemwise(ps.conj) def logical_and(): ...
cos = _as_xelemwise(ps.cos)
cosh = _as_xelemwise(ps.cosh)
deg2rad = _as_xelemwise(ps.deg2rad) @_as_xelemwise(ps.and_)
equal = eq = _as_xelemwise(ps.eq) def bitwise_and(): ...
erf = _as_xelemwise(ps.erf)
erfc = _as_xelemwise(ps.erfc)
erfcinv = _as_xelemwise(ps.erfcinv) and_ = logical_and
erfcx = _as_xelemwise(ps.erfcx)
erfinv = _as_xelemwise(ps.erfinv)
exp = _as_xelemwise(ps.exp) @_as_xelemwise(ps.angle)
exp2 = _as_xelemwise(ps.exp2) def angle(): ...
expm1 = _as_xelemwise(ps.expm1)
floor = _as_xelemwise(ps.floor)
floor_divide = floor_div = int_div = _as_xelemwise(ps.int_div) @_as_xelemwise(ps.arccos)
gamma = _as_xelemwise(ps.gamma) def arccos(): ...
gammainc = _as_xelemwise(ps.gammainc)
gammaincc = _as_xelemwise(ps.gammaincc)
gammainccinv = _as_xelemwise(ps.gammainccinv) @_as_xelemwise(ps.arccosh)
gammaincinv = _as_xelemwise(ps.gammaincinv) def arccosh(): ...
gammal = _as_xelemwise(ps.gammal)
gammaln = _as_xelemwise(ps.gammaln)
gammau = _as_xelemwise(ps.gammau) @_as_xelemwise(ps.arcsin)
greater_equal = ge = _as_xelemwise(ps.ge) def arcsin(): ...
greater = gt = _as_xelemwise(ps.gt)
hyp2f1 = _as_xelemwise(ps.hyp2f1)
i0 = _as_xelemwise(ps.i0) @_as_xelemwise(ps.arcsinh)
i1 = _as_xelemwise(ps.i1) def arcsinh(): ...
identity = _as_xelemwise(ps.identity)
imag = _as_xelemwise(ps.imag)
logical_not = bitwise_invert = bitwise_not = invert = _as_xelemwise(ps.invert) @_as_xelemwise(ps.arctan)
isinf = _as_xelemwise(ps.isinf) def arctan(): ...
isnan = _as_xelemwise(ps.isnan)
iv = _as_xelemwise(ps.iv)
ive = _as_xelemwise(ps.ive) @_as_xelemwise(ps.arctan2)
j0 = _as_xelemwise(ps.j0) def arctan2(): ...
j1 = _as_xelemwise(ps.j1)
jv = _as_xelemwise(ps.jv)
kve = _as_xelemwise(ps.kve) @_as_xelemwise(ps.arctanh)
less_equal = le = _as_xelemwise(ps.le) def arctanh(): ...
log = _as_xelemwise(ps.log)
log10 = _as_xelemwise(ps.log10)
log1mexp = _as_xelemwise(ps.log1mexp) @_as_xelemwise(ps.betainc)
log1p = _as_xelemwise(ps.log1p) def betainc(): ...
log2 = _as_xelemwise(ps.log2)
less = lt = _as_xelemwise(ps.lt)
mod = _as_xelemwise(ps.mod) @_as_xelemwise(ps.betaincinv)
multiply = mul = _as_xelemwise(ps.mul) def betaincinv(): ...
negative = neg = _as_xelemwise(ps.neg)
not_equal = neq = _as_xelemwise(ps.neq)
logical_or = bitwise_or = or_ = _as_xelemwise(ps.or_) @_as_xelemwise(ps.ceil)
owens_t = _as_xelemwise(ps.owens_t) def ceil(): ...
polygamma = _as_xelemwise(ps.polygamma)
power = pow = _as_xelemwise(ps.pow)
psi = _as_xelemwise(ps.psi) @_as_xelemwise(ps.clip)
rad2deg = _as_xelemwise(ps.rad2deg) def clip(): ...
real = _as_xelemwise(ps.real)
reciprocal = _as_xelemwise(ps.reciprocal)
round = _as_xelemwise(ps.round_half_to_even) @_as_xelemwise(ps.complex)
maximum = _as_xelemwise(ps.scalar_maximum) def complex(): ...
minimum = _as_xelemwise(ps.scalar_minimum)
second = _as_xelemwise(ps.second)
sigmoid = expit = _as_xelemwise(ps.sigmoid) @_as_xelemwise(ps.conj)
sign = _as_xelemwise(ps.sign) def conjugate(): ...
sin = _as_xelemwise(ps.sin)
sinh = _as_xelemwise(ps.sinh)
softplus = _as_xelemwise(ps.softplus) conj = conjugate
square = sqr = _as_xelemwise(ps.sqr)
sqrt = _as_xelemwise(ps.sqrt)
subtract = sub = _as_xelemwise(ps.sub) @_as_xelemwise(ps.cos)
where = switch = _as_xelemwise(ps.switch) def cos(): ...
tan = _as_xelemwise(ps.tan)
tanh = _as_xelemwise(ps.tanh)
tri_gamma = _as_xelemwise(ps.tri_gamma) @_as_xelemwise(ps.cosh)
true_divide = true_div = _as_xelemwise(ps.true_div) def cosh(): ...
trunc = _as_xelemwise(ps.trunc)
logical_xor = bitwise_xor = xor = _as_xelemwise(ps.xor)
@_as_xelemwise(ps.deg2rad)
def deg2rad(): ...
@_as_xelemwise(ps.eq)
def equal(): ...
eq = equal
@_as_xelemwise(ps.erf)
def erf(): ...
@_as_xelemwise(ps.erfc)
def erfc(): ...
@_as_xelemwise(ps.erfcinv)
def erfcinv(): ...
@_as_xelemwise(ps.erfcx)
def erfcx(): ...
@_as_xelemwise(ps.erfinv)
def erfinv(): ...
@_as_xelemwise(ps.exp)
def exp(): ...
@_as_xelemwise(ps.exp2)
def exp2(): ...
@_as_xelemwise(ps.expm1)
def expm1(): ...
@_as_xelemwise(ps.floor)
def floor(): ...
@_as_xelemwise(ps.int_div)
def floor_divide(): ...
floor_div = int_div = floor_divide
@_as_xelemwise(ps.gamma)
def gamma(): ...
@_as_xelemwise(ps.gammainc)
def gammainc(): ...
@_as_xelemwise(ps.gammaincc)
def gammaincc(): ...
@_as_xelemwise(ps.gammainccinv)
def gammainccinv(): ...
@_as_xelemwise(ps.gammaincinv)
def gammaincinv(): ...
@_as_xelemwise(ps.gammal)
def gammal(): ...
@_as_xelemwise(ps.gammaln)
def gammaln(): ...
@_as_xelemwise(ps.gammau)
def gammau(): ...
@_as_xelemwise(ps.ge)
def greater_equal(): ...
ge = greater_equal
@_as_xelemwise(ps.gt)
def greater(): ...
gt = greater
@_as_xelemwise(ps.hyp2f1)
def hyp2f1(): ...
@_as_xelemwise(ps.i0)
def i0(): ...
@_as_xelemwise(ps.i1)
def i1(): ...
@_as_xelemwise(ps.identity)
def identity(): ...
@_as_xelemwise(ps.imag)
def imag(): ...
@_as_xelemwise(ps.invert)
def logical_not(): ...
@_as_xelemwise(ps.invert)
def bitwise_not(): ...
@_as_xelemwise(ps.invert)
def bitwise_invert(): ...
@_as_xelemwise(ps.invert)
def invert(): ...
@_as_xelemwise(ps.isinf)
def isinf(): ...
@_as_xelemwise(ps.isnan)
def isnan(): ...
@_as_xelemwise(ps.iv)
def iv(): ...
@_as_xelemwise(ps.ive)
def ive(): ...
@_as_xelemwise(ps.j0)
def j0(): ...
@_as_xelemwise(ps.j1)
def j1(): ...
@_as_xelemwise(ps.jv)
def jv(): ...
@_as_xelemwise(ps.kve)
def kve(): ...
@_as_xelemwise(ps.le)
def less_equal(): ...
le = less_equal
@_as_xelemwise(ps.log)
def log(): ...
@_as_xelemwise(ps.log10)
def log10(): ...
@_as_xelemwise(ps.log1mexp)
def log1mexp(): ...
@_as_xelemwise(ps.log1p)
def log1p(): ...
@_as_xelemwise(ps.log2)
def log2(): ...
@_as_xelemwise(ps.lt)
def less(): ...
lt = less
@_as_xelemwise(ps.mod)
def mod(): ...
@_as_xelemwise(ps.mul)
def multiply(): ...
mul = multiply
@_as_xelemwise(ps.neg)
def negative(): ...
neg = negative
@_as_xelemwise(ps.neq)
def not_equal(): ...
neq = not_equal
@_as_xelemwise(ps.or_)
def logical_or(): ...
@_as_xelemwise(ps.or_)
def bitwise_or(): ...
or_ = logical_or
@_as_xelemwise(ps.owens_t)
def owens_t(): ...
@_as_xelemwise(ps.polygamma)
def polygamma(): ...
@_as_xelemwise(ps.pow)
def power(): ...
pow = power
@_as_xelemwise(ps.psi)
def psi(): ...
@_as_xelemwise(ps.rad2deg)
def rad2deg(): ...
@_as_xelemwise(ps.real)
def real(): ...
@_as_xelemwise(ps.reciprocal)
def reciprocal(): ...
@_as_xelemwise(ps.round_half_to_even)
def round(): ...
@_as_xelemwise(ps.scalar_maximum)
def maximum(): ...
@_as_xelemwise(ps.scalar_minimum)
def minimum(): ...
@_as_xelemwise(ps.second)
def second(): ...
@_as_xelemwise(ps.sigmoid)
def sigmoid(): ...
expit = sigmoid
@_as_xelemwise(ps.sign)
def sign(): ...
@_as_xelemwise(ps.sin)
def sin(): ...
@_as_xelemwise(ps.sinh)
def sinh(): ...
@_as_xelemwise(ps.softplus)
def softplus(): ...
@_as_xelemwise(ps.sqr)
def square(): ...
sqr = square
@_as_xelemwise(ps.sqrt)
def sqrt(): ...
@_as_xelemwise(ps.sub)
def subtract(): ...
sub = subtract
@_as_xelemwise(ps.switch)
def where(): ...
switch = where
@_as_xelemwise(ps.tan)
def tan(): ...
@_as_xelemwise(ps.tanh)
def tanh(): ...
@_as_xelemwise(ps.tri_gamma)
def tri_gamma(): ...
@_as_xelemwise(ps.true_div)
def true_divide(): ...
true_div = true_divide
@_as_xelemwise(ps.trunc)
def trunc(): ...
@_as_xelemwise(ps.xor)
def logical_xor(): ...
@_as_xelemwise(ps.xor)
def bitwise_xor(): ...
xor = logical_xor
_xelemwise_cast_op: dict[str, XElemwise] = {} _xelemwise_cast_op: dict[str, XElemwise] = {}
def cast(x, dtype): def cast(x, dtype):
"""Cast an XTensorVariable to a different dtype."""
if dtype == "floatX": if dtype == "floatX":
dtype = config.floatX dtype = config.floatX
else: else:
...@@ -141,6 +507,7 @@ def cast(x, dtype): ...@@ -141,6 +507,7 @@ def cast(x, dtype):
def softmax(x, dim=None): def softmax(x, dim=None):
"""Compute the softmax of an XTensorVariable along a specified dimension."""
exp_x = exp(x) exp_x = exp(x)
return exp_x / exp_x.sum(dim=dim) return exp_x / exp_x.sum(dim=dim)
...@@ -195,11 +562,11 @@ class Dot(XOp): ...@@ -195,11 +562,11 @@ class Dot(XOp):
return Apply(self, [x, y], [out]) return Apply(self, [x, y], [out])
def dot(x, y, dim: str | Iterable[str] | EllipsisType | None = None): def dot(x, y, dim: str | Sequence[str] | EllipsisType | None = None):
"""Matrix multiplication between two XTensorVariables. """Generalized dot product for XTensorVariables.
This operation performs matrix multiplication between two tensors, automatically This operation performs multiplication followed by summation for shared dimensions
aligning and contracting dimensions. The behavior matches xarray's dot operation. or simply summation for non-shared dimensions.
Parameters Parameters
---------- ----------
...@@ -207,21 +574,29 @@ def dot(x, y, dim: str | Iterable[str] | EllipsisType | None = None): ...@@ -207,21 +574,29 @@ def dot(x, y, dim: str | Iterable[str] | EllipsisType | None = None):
First input tensor First input tensor
y : XTensorVariable y : XTensorVariable
Second input tensor Second input tensor
dim : str, Iterable[Hashable], EllipsisType, or None, optional dim : str, Sequence[str], Ellipsis (...), or None, optional
The dimensions to contract over. If None, will contract over all matching dimensions. The dimensions to contract over. If None, will contract over all matching dimensions.
If Ellipsis (...), will contract over all dimensions. If Ellipsis (...), will contract over all dimensions.
Returns Returns
------- -------
XTensorVariable XTensorVariable
The result of the matrix multiplication.
Examples Examples
-------- --------
>>> x = xtensor(dtype="float64", dims=("a", "b"), shape=(2, 3))
>>> y = xtensor(dtype="float64", dims=("b", "c"), shape=(3, 4)) .. testcode::
>>> z = dot(x, y) # Result has dimensions ("a", "c")
>>> z = dot(x, y, dim=...) # Contract over all dimensions from pytensor.xtensor import xtensor, dot
x = xtensor("x", dims=("a", "b"))
y = xtensor("y", dims=("b", "c"))
assert dot(x, y).dims == ("a", "c") # Contract over shared `b` dimension
assert dot(x, y, dim=("a", "b")).dims == ("c",) # Contract over 'a' and 'b'
assert dot(x, y, dim=...).dims == () # Contract over all dimensions
""" """
x = as_xtensor(x) x = as_xtensor(x)
y = as_xtensor(y) y = as_xtensor(y)
......
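The three cases in the new `dot` docstring (`dim=None`, an explicit sequence, and `...`) differ only in which dimensions get contracted. A pure-Python sketch of that bookkeeping is below. `dot_output_dims` is a hypothetical helper, the output ordering is an assumption, and the `ValueError` for unknown dims is illustrative rather than the library's documented behavior.

```python
def dot_output_dims(x_dims, y_dims, dim=None):
    """Return the dims left after a generalized dot product.

    Hypothetical helper: it tracks labels only and performs no arithmetic.
    """
    # Assumed ordering: dims of x first, then dims that only y has.
    all_dims = tuple(x_dims) + tuple(d for d in y_dims if d not in x_dims)

    if dim is None:
        # Contract over every dimension the two inputs share.
        contracted = {d for d in x_dims if d in y_dims}
    elif dim is ...:
        # Contract over everything, leaving a scalar.
        contracted = set(all_dims)
    elif isinstance(dim, str):
        contracted = {dim}
    else:
        contracted = set(dim)

    missing = contracted - set(all_dims)
    if missing:
        raise ValueError(f"Unknown dims: {sorted(missing)}")

    return tuple(d for d in all_dims if d not in contracted)


print(dot_output_dims(("a", "b"), ("b", "c")))
print(dot_output_dims(("a", "b"), ("b", "c"), dim=("a", "b")))
print(dot_output_dims(("a", "b"), ("b", "c"), dim=...))
```

The three calls reproduce the dims asserted in the docstring example: `("a", "c")`, `("c",)`, and `()`.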
...@@ -10,7 +10,7 @@ from pytensor.xtensor.type import as_xtensor ...@@ -10,7 +10,7 @@ from pytensor.xtensor.type import as_xtensor
from pytensor.xtensor.vectorization import XRV from pytensor.xtensor.vectorization import XRV
def _as_xrv( def as_xrv(
core_op: RandomVariable, core_op: RandomVariable,
core_inps_dims_map: Sequence[Sequence[int]] | None = None, core_inps_dims_map: Sequence[Sequence[int]] | None = None,
core_out_dims_map: Sequence[int] | None = None, core_out_dims_map: Sequence[int] | None = None,
...@@ -52,7 +52,6 @@ def _as_xrv( ...@@ -52,7 +52,6 @@ def _as_xrv(
max((entry + 1 for entry in core_out_dims_map), default=0), max((entry + 1 for entry in core_out_dims_map), default=0),
) )
@wraps(core_op)
def xrv_constructor( def xrv_constructor(
*params, *params,
core_dims: Sequence[str] | str | None = None, core_dims: Sequence[str] | str | None = None,
...@@ -93,38 +92,151 @@ def _as_xrv( ...@@ -93,38 +92,151 @@ def _as_xrv(
return xrv_constructor return xrv_constructor
bernoulli = _as_xrv(ptr.bernoulli) def _as_xrv(core_op: RandomVariable, name: str | None = None):
beta = _as_xrv(ptr.beta) """A decorator to create a new XRV and document it in sphinx."""
betabinom = _as_xrv(ptr.betabinom) xrv_constructor = as_xrv(core_op, name=name)
binomial = _as_xrv(ptr.binomial)
categorical = _as_xrv(ptr.categorical) def decorator(func):
cauchy = _as_xrv(ptr.cauchy) @wraps(as_xrv)
dirichlet = _as_xrv(ptr.dirichlet) def wrapper(*args, **kwargs):
exponential = _as_xrv(ptr.exponential) return xrv_constructor(*args, **kwargs)
gamma = _as_xrv(ptr._gamma)
gengamma = _as_xrv(ptr.gengamma) wrapper.__doc__ = f"XRV version of {core_op.name} for XTensorVariables"
geometric = _as_xrv(ptr.geometric)
gumbel = _as_xrv(ptr.gumbel) return wrapper
halfcauchy = _as_xrv(ptr.halfcauchy)
halfnormal = _as_xrv(ptr.halfnormal) return decorator
hypergeometric = _as_xrv(ptr.hypergeometric)
integers = _as_xrv(ptr.integers)
invgamma = _as_xrv(ptr.invgamma) @_as_xrv(ptr.bernoulli)
laplace = _as_xrv(ptr.laplace) def bernoulli(): ...
logistic = _as_xrv(ptr.logistic)
lognormal = _as_xrv(ptr.lognormal)
multinomial = _as_xrv(ptr.multinomial) @_as_xrv(ptr.beta)
nbinom = negative_binomial = _as_xrv(ptr.negative_binomial) def beta(): ...
normal = _as_xrv(ptr.normal)
pareto = _as_xrv(ptr.pareto)
poisson = _as_xrv(ptr.poisson) @_as_xrv(ptr.betabinom)
t = _as_xrv(ptr.t) def betabinom(): ...
triangular = _as_xrv(ptr.triangular)
truncexpon = _as_xrv(ptr.truncexpon)
uniform = _as_xrv(ptr.uniform) @_as_xrv(ptr.binomial)
vonmises = _as_xrv(ptr.vonmises) def binomial(): ...
wald = _as_xrv(ptr.wald)
weibull = _as_xrv(ptr.weibull)
@_as_xrv(ptr.categorical)
def categorical(): ...
@_as_xrv(ptr.cauchy)
def cauchy(): ...
@_as_xrv(ptr.dirichlet)
def dirichlet(): ...
@_as_xrv(ptr.exponential)
def exponential(): ...
@_as_xrv(ptr._gamma)
def gamma(): ...
@_as_xrv(ptr.gengamma)
def gengamma(): ...
@_as_xrv(ptr.geometric)
def geometric(): ...
@_as_xrv(ptr.gumbel)
def gumbel(): ...
@_as_xrv(ptr.halfcauchy)
def halfcauchy(): ...
@_as_xrv(ptr.halfnormal)
def halfnormal(): ...
@_as_xrv(ptr.hypergeometric)
def hypergeometric(): ...
@_as_xrv(ptr.integers)
def integers(): ...
@_as_xrv(ptr.invgamma)
def invgamma(): ...
@_as_xrv(ptr.laplace)
def laplace(): ...
@_as_xrv(ptr.logistic)
def logistic(): ...
@_as_xrv(ptr.lognormal)
def lognormal(): ...
@_as_xrv(ptr.multinomial)
def multinomial(): ...
@_as_xrv(ptr.negative_binomial)
def negative_binomial(): ...
nbinom = negative_binomial
@_as_xrv(ptr.normal)
def normal(): ...
@_as_xrv(ptr.pareto)
def pareto(): ...
@_as_xrv(ptr.poisson)
def poisson(): ...
@_as_xrv(ptr.t)
def t(): ...
@_as_xrv(ptr.triangular)
def triangular(): ...
@_as_xrv(ptr.truncexpon)
def truncexpon(): ...
@_as_xrv(ptr.uniform)
def uniform(): ...
@_as_xrv(ptr.vonmises)
def vonmises(): ...
@_as_xrv(ptr.wald)
def wald(): ...
@_as_xrv(ptr.weibull)
def weibull(): ...
def multivariate_normal( def multivariate_normal(
...@@ -136,6 +248,7 @@ def multivariate_normal( ...@@ -136,6 +248,7 @@ def multivariate_normal(
rng=None, rng=None,
method: Literal["cholesky", "svd", "eigh"] = "cholesky", method: Literal["cholesky", "svd", "eigh"] = "cholesky",
): ):
"""Multivariate normal random variable."""
mean = as_xtensor(mean) mean = as_xtensor(mean)
if len(core_dims) != 2: if len(core_dims) != 2:
raise ValueError( raise ValueError(
...@@ -147,7 +260,7 @@ def multivariate_normal( ...@@ -147,7 +260,7 @@ def multivariate_normal(
if core_dims[0] not in mean.type.dims: if core_dims[0] not in mean.type.dims:
core_dims = core_dims[::-1] core_dims = core_dims[::-1]
xop = _as_xrv(ptr.MvNormalRV(method=method)) xop = as_xrv(ptr.MvNormalRV(method=method))
return xop(mean, cov, core_dims=core_dims, extra_dims=extra_dims, rng=rng) return xop(mean, cov, core_dims=core_dims, extra_dims=extra_dims, rng=rng)
......
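Both `math.py` and `random.py` in this commit replace plain assignments such as `exp = _as_xelemwise(ps.exp)` with decorated stub functions, so each public name becomes a real function carrying its own docstring. The standalone sketch below illustrates the pattern using only the standard library. `as_documented` is a hypothetical stand-in for `_as_xelemwise` / `_as_xrv`, and copying `__name__`, `__qualname__`, and `__module__` from the stub is an addition of this sketch that the diff itself does not show.

```python
import math


def as_documented(core_fn, kind):
    """Hypothetical stand-in for the decorator factories in this commit."""

    def decorator(stub):
        def wrapper(*args, **kwargs):
            # The stub body is never executed; only its identity is reused.
            return core_fn(*args, **kwargs)

        # Extra step in this sketch: make the wrapper look like the stub.
        wrapper.__name__ = stub.__name__
        wrapper.__qualname__ = stub.__qualname__
        wrapper.__module__ = stub.__module__
        wrapper.__doc__ = f"{kind} version of {core_fn.__name__}"
        return wrapper

    return decorator


@as_documented(math.exp, "Toy")
def exp(): ...


print(exp(0.0))
print(exp.__name__, "-", exp.__doc__)
```

A plausible reason for the change, suggested by the `_as_xrv` docstring ("create a new XRV and document it in sphinx"), is that documentation tooling handles module-level functions more readily than arbitrary callable instances.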
# XTensor Module
This module implements an abstraction layer over regular tensor operations that behaves like xarray.
A new type, `XTensorType`, generalizes `TensorType` with the addition of a `dims` attribute
that labels the dimensions of the tensor.
Variables of `XTensorType` (i.e., `XTensorVariable`s) are the symbolic counterpart to xarray DataArray objects.
The module implements several PyTensor operations, `XOp`s, whose signatures mimic those of xarray (and xarray_einstats) DataArray operations.
These operations, unlike most regular PyTensor operations, cannot be directly evaluated, but require a rewrite (lowering) into
a regular tensor graph that can itself be evaluated as usual.
Like regular PyTensor, we don't need an Op for every possible method or function in the public API of xarray.
If the existing XOps can be composed to produce the desired result, then we can use them directly.
## Coordinates
For now, there is no analogue of xarray coordinates, so you won't be able to do coordinate operations like `.sel`.
The graphs produced by an xarray program without coords are much more amenable to the numpy-like backend of PyTensor.
Coords involve aspects of Pandas/database query and joining that are not trivially expressible in PyTensor.
## Example
```python
import pytensor.tensor as pt
import pytensor.xtensor as px
a = pt.tensor("a", shape=(3,))
b = pt.tensor("b", shape=(4,))
ax = px.as_xtensor(a, dims=["x"])
bx = px.as_xtensor(b, dims=["y"])
zx = ax + bx
assert zx.type == px.type.XTensorType("float64", dims=["x", "y"], shape=(3, 4))
z = zx.values
z.dprint()
# TensorFromXTensor [id A]
# └─ XElemwise{scalar_op=Add()} [id B]
# ├─ XTensorFromTensor{dims=('x',)} [id C]
# │ └─ a [id D]
# └─ XTensorFromTensor{dims=('y',)} [id E]
# └─ b [id F]
```
Once we compile the graph, no `XOp`s are left.
```python
import pytensor
with pytensor.config.change_flags(optimizer_verbose=True):
fn = pytensor.function([a, b], z)
# rewriting: rewrite lower_elemwise replaces XElemwise{scalar_op=Add()}.0 of XElemwise{scalar_op=Add()}(XTensorFromTensor{dims=('x',)}.0, XTensorFromTensor{dims=('y',)}.0) with XTensorFromTensor{dims=('x', 'y')}.0 of XTensorFromTensor{dims=('x', 'y')}(Add.0)
# rewriting: rewrite useless_tensor_from_xtensor replaces TensorFromXTensor.0 of TensorFromXTensor(XTensorFromTensor{dims=('x',)}.0) with a of None
# rewriting: rewrite useless_tensor_from_xtensor replaces TensorFromXTensor.0 of TensorFromXTensor(XTensorFromTensor{dims=('y',)}.0) with b of None
# rewriting: rewrite useless_tensor_from_xtensor replaces TensorFromXTensor.0 of TensorFromXTensor(XTensorFromTensor{dims=('x', 'y')}.0) with Add.0 of Add(ExpandDims{axis=1}.0, ExpandDims{axis=0}.0)
fn.dprint()
# Add [id A] 2
# ├─ ExpandDims{axis=1} [id B] 1
# │ └─ a [id C]
# └─ ExpandDims{axis=0} [id D] 0
# └─ b [id E]
```
...@@ -303,6 +303,35 @@ class Concat(XOp): ...@@ -303,6 +303,35 @@ class Concat(XOp):
def concat(xtensors, dim: str): def concat(xtensors, dim: str):
"""Concatenate a sequence of XTensorVariables along a specified dimension.
Parameters
----------
xtensors : Sequence of XTensorVariable
The tensors to concatenate.
dim : str
The dimension along which to concatenate the tensors.
Returns
-------
XTensorVariable
Example
-------
.. testcode::
from pytensor.xtensor import as_xtensor, xtensor, concat
x = xtensor("x", shape=(2, 3), dims=("a", "b"))
zero = as_xtensor([0], dims=("a",))
out = concat([zero, x, zero], dim="a")
assert out.type.dims == ("a", "b")
assert out.type.shape == (4, 3)
"""
return Concat(dim=dim)(*xtensors) return Concat(dim=dim)(*xtensors)
......
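The `concat` docstring example relies on inputs that lack a non-concatenated dimension being broadcast along it: `zero` has only dim `a`, yet the result has shape `(4, 3)`. The pure-Python sketch below reproduces just the dims and shape bookkeeping under stated assumptions: `concat_dims_and_shape` is a hypothetical helper, dims are ordered by first appearance, and every input is required to carry the concatenation dim. None of this is confirmed as the behavior of the real `Concat` op.

```python
def concat_dims_and_shape(inputs, dim):
    """Compute output dims and shape for a labeled concatenation.

    `inputs` is a sequence of (dims, shape) pairs. Hypothetical helper for illustration.
    """
    # Assumed ordering: dims in order of first appearance across the inputs.
    out_dims = []
    for dims, _ in inputs:
        for d in dims:
            if d not in out_dims:
                out_dims.append(d)

    sizes = {}
    concat_size = 0
    for dims, shape in inputs:
        if dim not in dims:
            raise ValueError(f"Input with dims {dims} lacks concat dim {dim!r}")
        for d, n in zip(dims, shape):
            if d == dim:
                # Lengths add up along the concatenation dim.
                concat_size += n
            elif sizes.setdefault(d, n) != n:
                raise ValueError(f"Mismatched length for dim {d!r}")
    sizes[dim] = concat_size

    return tuple(out_dims), tuple(sizes[d] for d in out_dims)


zero = (("a",), (1,))
x = (("a", "b"), (2, 3))
print(concat_dims_and_shape([zero, x, zero], dim="a"))
```

With the inputs from the docstring example, the sketch returns dims `("a", "b")` and shape `(4, 3)`, matching the assertions there.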
...@@ -201,6 +201,24 @@ def xtensor( ...@@ -201,6 +201,24 @@ def xtensor(
shape: Sequence[int | None] | None = None, shape: Sequence[int | None] | None = None,
dtype: str | np.dtype = "floatX", dtype: str | np.dtype = "floatX",
): ):
"""Create an XTensorVariable.
Parameters
----------
name : str or None, optional
The name of the variable
dims : Sequence[str]
The names of the dimensions of the tensor
shape : Sequence[int | None] or None, optional
The shape of the tensor. If None, defaults to a shape with None for each dimension.
dtype : str or np.dtype, optional
The data type of the tensor. Defaults to 'floatX' (config.floatX).
Returns
-------
XTensorVariable
A new XTensorVariable with the specified name, dims, shape, and dtype.
"""
return XTensorType(dtype=dtype, dims=dims, shape=shape)(name=name) return XTensorType(dtype=dtype, dims=dims, shape=shape)(name=name)
...@@ -208,6 +226,8 @@ _XTensorTypeType = TypeVar("_XTensorTypeType", bound=XTensorType) ...@@ -208,6 +226,8 @@ _XTensorTypeType = TypeVar("_XTensorTypeType", bound=XTensorType)
class XTensorVariable(Variable[_XTensorTypeType, OptionalApplyType]): class XTensorVariable(Variable[_XTensorTypeType, OptionalApplyType]):
"""Variable of XTensorType."""
# These can't work because Python requires native output types # These can't work because Python requires native output types
def __bool__(self): def __bool__(self):
raise TypeError( raise TypeError(
...@@ -406,7 +426,7 @@ class XTensorVariable(Variable[_XTensorTypeType, OptionalApplyType]): ...@@ -406,7 +426,7 @@ class XTensorVariable(Variable[_XTensorTypeType, OptionalApplyType]):
def copy(self, name: str | None = None): def copy(self, name: str | None = None):
out = px.math.identity(self) out = px.math.identity(self)
out.name = name # type: ignore out.name = name
return out return out
def astype(self, dtype): def astype(self, dtype):
...@@ -751,6 +771,8 @@ class XTensorConstantSignature(TensorConstantSignature): ...@@ -751,6 +771,8 @@ class XTensorConstantSignature(TensorConstantSignature):
class XTensorConstant(XTensorVariable, Constant[_XTensorTypeType]): class XTensorConstant(XTensorVariable, Constant[_XTensorTypeType]):
"""Constant of XTensorType."""
def __init__(self, type: _XTensorTypeType, data, name=None): def __init__(self, type: _XTensorTypeType, data, name=None):
data_shape = np.shape(data) data_shape = np.shape(data)
...@@ -776,6 +798,8 @@ XTensorType.constant_type = XTensorConstant # type: ignore ...@@ -776,6 +798,8 @@ XTensorType.constant_type = XTensorConstant # type: ignore
def xtensor_constant(x, name=None, dims: None | Sequence[str] = None): def xtensor_constant(x, name=None, dims: None | Sequence[str] = None):
"""Convert a constant value to an XTensorConstant."""
x_dims: tuple[str, ...] x_dims: tuple[str, ...]
if XARRAY_AVAILABLE and isinstance(x, xr.DataArray): if XARRAY_AVAILABLE and isinstance(x, xr.DataArray):
xarray_dims = x.dims xarray_dims = x.dims
...@@ -819,7 +843,20 @@ if XARRAY_AVAILABLE: ...@@ -819,7 +843,20 @@ if XARRAY_AVAILABLE:
return xtensor_constant(x, **kwargs) return xtensor_constant(x, **kwargs)
def as_xtensor(x, name=None, dims: Sequence[str] | None = None): def as_xtensor(x, dims: Sequence[str] | None = None, *, name: str | None = None):
"""Convert a variable or data to an XTensorVariable.
Parameters
----------
x : Variable or data
dims: Sequence[str] or None, optional
If dims are provided, TensorVariable (or data) will be converted to an XTensorVariable with those dims.
XTensorVariables will be returned as is, if the dims match. Otherwise, a ValueError is raised.
If dims are not provided, and the data is not a scalar, an XTensorVariable or xarray.DataArray, an error is raised.
name: str or None, optional
Name of the resulting XTensorVariable.
"""
if isinstance(x, Apply): if isinstance(x, Apply):
if len(x.outputs) != 1: if len(x.outputs) != 1:
raise ValueError( raise ValueError(
......