Commit 4ce092fe authored by Ricardo Vieira, committed by Ricardo Vieira

Add xtensor docs

Parent ebfac59e
...@@ -25,6 +25,7 @@ Modules
    sparse/index
    tensor/index
    typed_list
    xtensor/index

.. module:: pytensor
   :platform: Unix, Windows
......
(libdoc_xtensor)=
# `xtensor` -- XTensor operations
This module implements an abstraction layer on top of regular tensor operations that behaves like xarray.

A new type, {class}`pytensor.xtensor.type.XTensorType`, generalizes {class}`pytensor.tensor.TensorType`
with the addition of a `dims` attribute that labels the dimensions of the tensor.
Variables of XTensorType (i.e., {class}`pytensor.xtensor.type.XTensorVariable`s) are the symbolic counterpart
to xarray DataArray objects.

The module implements several PyTensor operations, {class}`pytensor.xtensor.basic.XOp`s, whose signatures mimic
those of xarray (and xarray_einstats) DataArray operations. Unlike most regular PyTensor operations, these cannot
be evaluated directly; they require a rewrite (lowering) into a regular tensor graph, which can then be evaluated as usual.

As with regular PyTensor, we don't need an Op for every method or function in the public API of xarray:
if the existing XOps can be composed to produce the desired result, we can use them directly.
## Coordinates
For now, there is no analogue of xarray coordinates, so coordinate-based operations like `.sel` are not available.
The graphs produced by an xarray program without coords are much more amenable to PyTensor's NumPy-like backend;
coords involve aspects of Pandas-style querying and joining that are not trivially expressible in PyTensor.
## Example
```{testcode}
import pytensor.tensor as pt
import pytensor.xtensor as ptx
a = pt.tensor("a", shape=(3,))
b = pt.tensor("b", shape=(4,))
ax = ptx.as_xtensor(a, dims=["x"])
bx = ptx.as_xtensor(b, dims=["y"])
zx = ax + bx
assert zx.type == ptx.type.XTensorType("float64", dims=["x", "y"], shape=(3, 4))
z = zx.values
z.dprint()
```
```{testoutput}
TensorFromXTensor [id A]
└─ XElemwise{scalar_op=Add()} [id B]
├─ XTensorFromTensor{dims=('x',)} [id C]
│ └─ a [id D]
└─ XTensorFromTensor{dims=('y',)} [id E]
└─ b [id F]
```
Once we compile the graph, no XOps are left.
```{testcode}
import pytensor
with pytensor.config.change_flags(optimizer_verbose=True):
    fn = pytensor.function([a, b], z)
```
```{testoutput}
rewriting: rewrite lower_elemwise replaces XElemwise{scalar_op=Add()}.0 of XElemwise{scalar_op=Add()}(XTensorFromTensor{dims=('x',)}.0, XTensorFromTensor{dims=('y',)}.0) with XTensorFromTensor{dims=('x', 'y')}.0 of XTensorFromTensor{dims=('x', 'y')}(Add.0)
rewriting: rewrite useless_tensor_from_xtensor replaces TensorFromXTensor.0 of TensorFromXTensor(XTensorFromTensor{dims=('x',)}.0) with a of None
rewriting: rewrite useless_tensor_from_xtensor replaces TensorFromXTensor.0 of TensorFromXTensor(XTensorFromTensor{dims=('y',)}.0) with b of None
rewriting: rewrite useless_tensor_from_xtensor replaces TensorFromXTensor.0 of TensorFromXTensor(XTensorFromTensor{dims=('x', 'y')}.0) with Add.0 of Add(ExpandDims{axis=1}.0, ExpandDims{axis=0}.0)
```
```{testcode}
fn.dprint()
```
```{testoutput}
Add [id A] 2
├─ ExpandDims{axis=1} [id B] 1
│ └─ a [id C]
└─ ExpandDims{axis=0} [id D] 0
└─ b [id E]
```
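The compiled graph above corresponds to plain NumPy broadcasting. As an illustrative sketch (NumPy is used here only to show the equivalence, not as part of the API):

```python
import numpy as np

# NumPy equivalent of the compiled graph:
# ExpandDims{axis=1}(a) + ExpandDims{axis=0}(b) is ordinary broadcasting.
a = np.arange(3.0)           # dims ("x",), shape (3,)
b = np.arange(4.0)           # dims ("y",), shape (4,)
z = a[:, None] + b[None, :]  # dims ("x", "y"), shape (3, 4)
```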
## Index
:::{toctree}
:maxdepth: 1
module_functions
math
linalg
random
type
:::
\ No newline at end of file
(libdoc_xtensor_linalg)=
# `xtensor.linalg` -- Linear algebra operations
```{eval-rst}
.. automodule:: pytensor.xtensor.linalg
:members:
```
(libdoc_xtensor_math)=
# `xtensor.math` -- Mathematical operations
```{eval-rst}
.. automodule:: pytensor.xtensor.math
:members:
:exclude-members: XDot, dot
```
\ No newline at end of file
(libdoc_xtensor_module_function)=
# `xtensor` -- Module level operations
```{eval-rst}
.. automodule:: pytensor.xtensor
:members: broadcast, concat, dot, full_like, ones_like, zeros_like
```
(libdoc_xtensor_random)=
# `xtensor.random` -- Random number generator operations
```{eval-rst}
.. automodule:: pytensor.xtensor.random
:members:
```
(libdoc_xtenor_type)=
# `xtensor.type` -- Types and Variables
## XTensorVariable creation functions
```{eval-rst}
.. automodule:: pytensor.xtensor.type
:members: xtensor, xtensor_constant, as_xtensor
```
## XTensor Type and Variable classes
```{eval-rst}
.. automodule:: pytensor.xtensor.type
:noindex:
:members: XTensorType, XTensorVariable, XTensorConstant
```
import warnings

import pytensor.xtensor.rewriting
from pytensor.xtensor import linalg, math, random
from pytensor.xtensor.math import dot
from pytensor.xtensor.shape import broadcast, concat, full_like, ones_like, zeros_like
from pytensor.xtensor.type import (
......
...@@ -11,17 +11,31 @@ def cholesky(
    lower: bool = True,
    *,
    check_finite: bool = False,
    overwrite_a: bool = False,
    on_error: Literal["raise", "nan"] = "raise",
    dims: Sequence[str],
):
    """Compute the Cholesky decomposition of an XTensorVariable.

    Parameters
    ----------
    x : XTensorVariable
        The input variable to decompose.
    lower : bool, optional
        Whether to return the lower triangular matrix. Default is True.
    check_finite : bool, optional
        Whether to check that the input is finite. Default is False.
    overwrite_a : bool, optional
        Whether the underlying data may be overwritten (may improve performance). Default is False.
    on_error : {'raise', 'nan'}, optional
        What to do if the input is not positive definite. If 'raise', an error is raised.
        If 'nan', the output will contain NaNs. Default is 'raise'.
    dims : Sequence[str]
        The two core dimensions of the input variable, over which the Cholesky decomposition is computed.
    """
    if len(dims) != 2:
        raise ValueError(f"Cholesky needs two dims, got {len(dims)}")

    core_op = Cholesky(
        lower=lower,
        check_finite=check_finite,
        overwrite_a=overwrite_a,
        on_error=on_error,
    )
    core_dims = (
...@@ -40,6 +54,30 @@ def solve(
    lower: bool = False,
    check_finite: bool = False,
):
    """Solve a system of linear equations using XTensorVariables.

    Parameters
    ----------
    a : XTensorVariable
        The left-hand side xtensor.
    b : XTensorVariable
        The right-hand side xtensor.
    dims : Sequence[str]
        The core dimensions over which to solve the linear equations.
        If length is 2, we are solving a matrix-vector equation,
        and the two dimensions should be present in `a`, but only one in `b`.
        If length is 3, we are solving a matrix-matrix equation,
        and two dimensions should be present in `a`, two in `b`, and only one should be shared.
        In both cases the shared dimension will not appear in the output.
    assume_a : str, optional
        The type of matrix `a` is assumed to be. Default is 'gen' (general).
        Options are ["gen", "sym", "her", "pos", "tridiagonal", "banded"].
        Long-form options can also be used ["general", "symmetric", "hermitian", "positive_definite"].
    lower : bool, optional
        Whether `a` is lower triangular. Default is False. Only relevant if `assume_a` is "sym", "her", or "pos".
    check_finite : bool, optional
        Whether to check that the input is finite. Default is False.
    """
    a, b = as_xtensor(a), as_xtensor(b)
    input_core_dims: tuple[tuple[str, str], tuple[str] | tuple[str, str]]
    output_core_dims: tuple[tuple[str] | tuple[str, str]]
......
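As a hedged illustration of the matrix-vector case described in the `solve` docstring, here is the NumPy analogue; the dim names appear only in comments, since plain NumPy has no labeled dimensions:

```python
import numpy as np

# `a` has core dims ("u", "v"); `b` has the shared core dim ("u",).
# The shared dim "u" is solved away and does not appear in the output.
a = np.array([[2.0, 0.0], [0.0, 4.0]])  # dims ("u", "v")
b = np.array([2.0, 8.0])                # dims ("u",)
x = np.linalg.solve(a, b)               # dims ("v",)
```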
Diff collapsed.
...@@ -10,7 +10,7 @@ from pytensor.xtensor.type import as_xtensor
from pytensor.xtensor.vectorization import XRV

def as_xrv(
    core_op: RandomVariable,
    core_inps_dims_map: Sequence[Sequence[int]] | None = None,
    core_out_dims_map: Sequence[int] | None = None,
...@@ -52,7 +52,6 @@ def as_xrv(
        max((entry + 1 for entry in core_out_dims_map), default=0),
    )

    def xrv_constructor(
        *params,
        core_dims: Sequence[str] | str | None = None,
...@@ -93,38 +92,151 @@ def as_xrv(
    return xrv_constructor
def _as_xrv(core_op: RandomVariable, name: str | None = None):
    """A decorator to create a new XRV and document it in sphinx."""
    xrv_constructor = as_xrv(core_op, name=name)

    def decorator(func):
        @wraps(as_xrv)
        def wrapper(*args, **kwargs):
            return xrv_constructor(*args, **kwargs)

        wrapper.__doc__ = f"XRV version of {core_op.name} for XTensorVariables"

        return wrapper

    return decorator


@_as_xrv(ptr.bernoulli)
def bernoulli(): ...

@_as_xrv(ptr.beta)
def beta(): ...

@_as_xrv(ptr.betabinom)
def betabinom(): ...

@_as_xrv(ptr.binomial)
def binomial(): ...

@_as_xrv(ptr.categorical)
def categorical(): ...

@_as_xrv(ptr.cauchy)
def cauchy(): ...

@_as_xrv(ptr.dirichlet)
def dirichlet(): ...

@_as_xrv(ptr.exponential)
def exponential(): ...

@_as_xrv(ptr._gamma)
def gamma(): ...

@_as_xrv(ptr.gengamma)
def gengamma(): ...

@_as_xrv(ptr.geometric)
def geometric(): ...

@_as_xrv(ptr.gumbel)
def gumbel(): ...

@_as_xrv(ptr.halfcauchy)
def halfcauchy(): ...

@_as_xrv(ptr.halfnormal)
def halfnormal(): ...

@_as_xrv(ptr.hypergeometric)
def hypergeometric(): ...

@_as_xrv(ptr.integers)
def integers(): ...

@_as_xrv(ptr.invgamma)
def invgamma(): ...

@_as_xrv(ptr.laplace)
def laplace(): ...

@_as_xrv(ptr.logistic)
def logistic(): ...

@_as_xrv(ptr.lognormal)
def lognormal(): ...

@_as_xrv(ptr.multinomial)
def multinomial(): ...

@_as_xrv(ptr.negative_binomial)
def negative_binomial(): ...

nbinom = negative_binomial

@_as_xrv(ptr.normal)
def normal(): ...

@_as_xrv(ptr.pareto)
def pareto(): ...

@_as_xrv(ptr.poisson)
def poisson(): ...

@_as_xrv(ptr.t)
def t(): ...

@_as_xrv(ptr.triangular)
def triangular(): ...

@_as_xrv(ptr.truncexpon)
def truncexpon(): ...

@_as_xrv(ptr.uniform)
def uniform(): ...

@_as_xrv(ptr.vonmises)
def vonmises(): ...

@_as_xrv(ptr.wald)
def wald(): ...

@_as_xrv(ptr.weibull)
def weibull(): ...
def multivariate_normal(
...@@ -136,6 +248,7 @@ def multivariate_normal(
    rng=None,
    method: Literal["cholesky", "svd", "eigh"] = "cholesky",
):
    """Multivariate normal random variable."""
    mean = as_xtensor(mean)
    if len(core_dims) != 2:
        raise ValueError(
...@@ -147,7 +260,7 @@ def multivariate_normal(
    if core_dims[0] not in mean.type.dims:
        core_dims = core_dims[::-1]

    xop = as_xrv(ptr.MvNormalRV(method=method))
    return xop(mean, cov, core_dims=core_dims, extra_dims=extra_dims, rng=rng)
......
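The dim-ordering rule in `multivariate_normal` (flip `core_dims` so the dim that is present in `mean` comes first) can be checked in isolation; the dim names below are hypothetical:

```python
# cov carries both core dims; mean carries exactly one of them.
# If the first core dim is missing from mean's dims, the pair is reversed,
# so that the dim shared with mean ends up first.
core_dims = ("obs", "core")
mean_dims = ("batch", "core")
if core_dims[0] not in mean_dims:
    core_dims = core_dims[::-1]
```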
# XTensor Module
This module implements an abstraction layer on top of regular tensor operations that behaves like xarray.

A new type, `XTensorType`, generalizes `TensorType` with the addition of a `dims` attribute
that labels the dimensions of the tensor.
Variables of `XTensorType` (i.e., `XTensorVariable`s) are the symbolic counterpart to xarray DataArray objects.

The module implements several PyTensor operations, `XOp`s, whose signatures mimic those of xarray (and xarray_einstats) DataArray operations.
Unlike most regular PyTensor operations, these cannot be evaluated directly; they require a rewrite (lowering) into
a regular tensor graph, which can then be evaluated as usual.

As with regular PyTensor, we don't need an Op for every method or function in the public API of xarray:
if the existing XOps can be composed to produce the desired result, we can use them directly.

## Coordinates

For now, there is no analogue of xarray coordinates, so coordinate-based operations like `.sel` are not available.
The graphs produced by an xarray program without coords are much more amenable to PyTensor's NumPy-like backend;
coords involve aspects of Pandas-style querying and joining that are not trivially expressible in PyTensor.
## Example
```python
import pytensor.tensor as pt
import pytensor.xtensor as px
a = pt.tensor("a", shape=(3,))
b = pt.tensor("b", shape=(4,))
ax = px.as_xtensor(a, dims=["x"])
bx = px.as_xtensor(b, dims=["y"])
zx = ax + bx
assert zx.type == px.type.XTensorType("float64", dims=["x", "y"], shape=(3, 4))
z = zx.values
z.dprint()
# TensorFromXTensor [id A]
# └─ XElemwise{scalar_op=Add()} [id B]
# ├─ XTensorFromTensor{dims=('x',)} [id C]
# │ └─ a [id D]
# └─ XTensorFromTensor{dims=('y',)} [id E]
# └─ b [id F]
```
Once we compile the graph, no `XOp`s are left.
```python
import pytensor
with pytensor.config.change_flags(optimizer_verbose=True):
    fn = pytensor.function([a, b], z)
# rewriting: rewrite lower_elemwise replaces XElemwise{scalar_op=Add()}.0 of XElemwise{scalar_op=Add()}(XTensorFromTensor{dims=('x',)}.0, XTensorFromTensor{dims=('y',)}.0) with XTensorFromTensor{dims=('x', 'y')}.0 of XTensorFromTensor{dims=('x', 'y')}(Add.0)
# rewriting: rewrite useless_tensor_from_xtensor replaces TensorFromXTensor.0 of TensorFromXTensor(XTensorFromTensor{dims=('x',)}.0) with a of None
# rewriting: rewrite useless_tensor_from_xtensor replaces TensorFromXTensor.0 of TensorFromXTensor(XTensorFromTensor{dims=('y',)}.0) with b of None
# rewriting: rewrite useless_tensor_from_xtensor replaces TensorFromXTensor.0 of TensorFromXTensor(XTensorFromTensor{dims=('x', 'y')}.0) with Add.0 of Add(ExpandDims{axis=1}.0, ExpandDims{axis=0}.0)
fn.dprint()
# Add [id A] 2
# ├─ ExpandDims{axis=1} [id B] 1
# │ └─ a [id C]
# └─ ExpandDims{axis=0} [id D] 0
# └─ b [id E]
```
...@@ -303,6 +303,35 @@ class Concat(XOp):

def concat(xtensors, dim: str):
    """Concatenate a sequence of XTensorVariables along a specified dimension.

    Parameters
    ----------
    xtensors : Sequence of XTensorVariable
        The tensors to concatenate.
    dim : str
        The dimension along which to concatenate the tensors.

    Returns
    -------
    XTensorVariable

    Examples
    --------
    .. testcode::

        from pytensor.xtensor import as_xtensor, xtensor, concat

        x = xtensor("x", shape=(2, 3), dims=("a", "b"))
        zero = as_xtensor([0], dims=("a",))
        out = concat([zero, x, zero], dim="a")
        assert out.type.dims == ("a", "b")
        assert out.type.shape == (4, 3)
    """
    return Concat(dim=dim)(*xtensors)
......
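The `concat` docstring example maps onto plain NumPy concatenation; a sketch of the equivalence (NumPy used for illustration only, with the broadcast of the length-1 "a" slice written out explicitly):

```python
import numpy as np

# concat along dim "a" == np.concatenate along the axis labeled "a" (axis 0 here)
x = np.ones((2, 3))                    # dims ("a", "b")
zero = np.zeros((1, 3))                # dims ("a", "b"), length 1 along "a"
out = np.concatenate([zero, x, zero])  # dims ("a", "b"), shape (4, 3)
```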
...@@ -201,6 +201,24 @@ def xtensor(
    shape: Sequence[int | None] | None = None,
    dtype: str | np.dtype = "floatX",
):
    """Create an XTensorVariable.

    Parameters
    ----------
    name : str or None, optional
        The name of the variable.
    dims : Sequence[str]
        The names of the dimensions of the tensor.
    shape : Sequence[int | None] or None, optional
        The shape of the tensor. If None, defaults to a shape with None for each dimension.
    dtype : str or np.dtype, optional
        The data type of the tensor. Defaults to 'floatX' (config.floatX).

    Returns
    -------
    XTensorVariable
        A new XTensorVariable with the specified name, dims, shape, and dtype.
    """
    return XTensorType(dtype=dtype, dims=dims, shape=shape)(name=name)
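The `shape=None` default described in the docstring amounts to one unknown entry per named dimension; a one-line sketch (the dim names are hypothetical):

```python
# shape=None -> a shape with None (unknown length) for each named dimension
dims = ("time", "channel")
shape = None
if shape is None:
    shape = (None,) * len(dims)
```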
...@@ -208,6 +226,8 @@ _XTensorTypeType = TypeVar("_XTensorTypeType", bound=XTensorType)

class XTensorVariable(Variable[_XTensorTypeType, OptionalApplyType]):
    """Variable of XTensorType."""

    # These can't work because Python requires native output types
    def __bool__(self):
        raise TypeError(
...@@ -406,7 +426,7 @@ class XTensorVariable(Variable[_XTensorTypeType, OptionalApplyType]):

    def copy(self, name: str | None = None):
        out = px.math.identity(self)
        out.name = name
        return out

    def astype(self, dtype):
...@@ -751,6 +771,8 @@ class XTensorConstantSignature(TensorConstantSignature):

class XTensorConstant(XTensorVariable, Constant[_XTensorTypeType]):
    """Constant of XTensorType."""

    def __init__(self, type: _XTensorTypeType, data, name=None):
        data_shape = np.shape(data)
...@@ -776,6 +798,8 @@ XTensorType.constant_type = XTensorConstant  # type: ignore

def xtensor_constant(x, name=None, dims: None | Sequence[str] = None):
    """Convert a constant value to an XTensorConstant."""
    x_dims: tuple[str, ...]
    if XARRAY_AVAILABLE and isinstance(x, xr.DataArray):
        xarray_dims = x.dims
...@@ -819,7 +843,20 @@ if XARRAY_AVAILABLE:
        return xtensor_constant(x, **kwargs)

def as_xtensor(x, dims: Sequence[str] | None = None, *, name: str | None = None):
    """Convert a variable or data to an XTensorVariable.

    Parameters
    ----------
    x : Variable or data
        The object to convert.
    dims : Sequence[str] or None, optional
        If dims are provided, a TensorVariable (or data) will be converted to an XTensorVariable with those dims.
        An XTensorVariable will be returned as is if its dims match; otherwise, a ValueError is raised.
        If dims are not provided and the input is not a scalar, an XTensorVariable, or an xarray.DataArray, an error is raised.
    name : str or None, optional
        Name of the resulting XTensorVariable.
    """
    if isinstance(x, Apply):
        if len(x.outputs) != 1:
            raise ValueError(
......