提交 5024d54e authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Add docstrings to more XTensorVariable methods

Also remove broadcast which is not a method in Xarray
上级 fdb40877
......@@ -366,6 +366,7 @@ class XTensorVariable(Variable[_XTensorTypeType, OptionalApplyType]):
# https://docs.xarray.dev/en/latest/api.html#id1
@property
def values(self) -> TensorVariable:
"""Convert to a TensorVariable with the same data."""
return typing.cast(TensorVariable, px.basic.tensor_from_xtensor(self))
# Can't provide property data because that's already taken by Constants!
......@@ -373,14 +374,17 @@ class XTensorVariable(Variable[_XTensorTypeType, OptionalApplyType]):
@property
def coords(self):
"""Not implemented."""
raise NotImplementedError("coords not implemented for XTensorVariable")
@property
def dims(self) -> tuple[str, ...]:
"""The names of the dimensions of the variable."""
return self.type.dims
@property
def sizes(self) -> dict[str, TensorVariable]:
"""The sizes of the dimensions of the variable."""
return dict(zip(self.dims, self.shape))
@property
......@@ -392,18 +396,22 @@ class XTensorVariable(Variable[_XTensorTypeType, OptionalApplyType]):
# https://docs.xarray.dev/en/latest/api.html#ndarray-attributes
@property
def ndim(self) -> int:
"""The number of dimensions of the variable."""
return self.type.ndim
@property
def shape(self) -> tuple[TensorVariable, ...]:
"""The shape of the variable."""
return tuple(px.basic.tensor_from_xtensor(self).shape) # type: ignore
@property
def size(self) -> TensorVariable:
"""The total number of elements in the variable."""
return typing.cast(TensorVariable, variadic_mul(*self.shape))
@property
def dtype(self):
def dtype(self) -> str:
"""The data type of the variable."""
return self.type.dtype
@property
......@@ -414,6 +422,7 @@ class XTensorVariable(Variable[_XTensorTypeType, OptionalApplyType]):
# DataArray contents
# https://docs.xarray.dev/en/latest/api.html#dataarray-contents
def rename(self, new_name_or_name_dict=None, **names):
"""Rename the variable or its dimension(s)."""
if isinstance(new_name_or_name_dict, str):
new_name = new_name_or_name_dict
name_dict = None
......@@ -425,31 +434,41 @@ class XTensorVariable(Variable[_XTensorTypeType, OptionalApplyType]):
return new_out
def copy(self, name: str | None = None):
"""Create a copy of the variable.
This is just an identity operation, as XTensorVariables are immutable.
"""
out = px.math.identity(self)
out.name = name
return out
def astype(self, dtype):
"""Convert the variable to a different data type."""
return px.math.cast(self, dtype)
def item(self):
"""Not implemented."""
raise NotImplementedError("item not implemented for XTensorVariable")
# Indexing
# https://docs.xarray.dev/en/latest/api.html#id2
def __setitem__(self, idx, value):
"""Not implemented. Use `x[idx].set(value)` or `x[idx].inc(value)` instead."""
raise TypeError(
"XTensorVariable does not support item assignment. Use the output of `x[idx].set` or `x[idx].inc` instead."
)
@property
def loc(self):
"""Not implemented."""
raise NotImplementedError("loc not implemented for XTensorVariable")
def sel(self, *args, **kwargs):
"""Not implemented."""
raise NotImplementedError("sel not implemented for XTensorVariable")
def __getitem__(self, idx):
"""Index the variable positionally."""
if isinstance(idx, dict):
return self.isel(idx)
......@@ -465,6 +484,7 @@ class XTensorVariable(Variable[_XTensorTypeType, OptionalApplyType]):
missing_dims: Literal["raise", "warn", "ignore"] = "raise",
**indexers_kwargs,
):
"""Index the variable along the specified dimension(s)."""
if indexers_kwargs:
if indexers is not None:
raise ValueError(
......@@ -505,6 +525,48 @@ class XTensorVariable(Variable[_XTensorTypeType, OptionalApplyType]):
return px.indexing.index(self, *indices)
def set(self, value):
"""Return a copy of the variable indexed by self with the indexed values set to y.
The original variable is not modified.
Raises
------
ValueError
If self is not the result of an index operation
Examples
--------
.. testcode::
import pytensor.xtensor as ptx
x = ptx.as_xtensor([[0, 0], [0, 0]], dims=("a", "b"))
idx = ptx.as_xtensor([0, 1], dims=("a",))
out = x[:, idx].set(1)
print(out.eval())
.. testoutput::
[[1 0]
[0 1]]
.. testcode::
import pytensor.xtensor as ptx
x = ptx.as_xtensor([[0, 0], [0, 0]], dims=("a", "b"))
idx = ptx.as_xtensor([0, 1], dims=("a",))
out = x.isel({"b": idx}).set(-1)
print(out.eval())
.. testoutput::
[[-1 0]
[ 0 -1]]
"""
if not (
self.owner is not None and isinstance(self.owner.op, px.indexing.Index)
):
......@@ -516,6 +578,48 @@ class XTensorVariable(Variable[_XTensorTypeType, OptionalApplyType]):
return px.indexing.index_assignment(x, value, *idxs)
def inc(self, value):
"""Return a copy of the variable indexed by self with the indexed values incremented by value.
The original variable is not modified.
Raises
------
ValueError
If self is not the result of an index operation
Examples
--------
.. testcode::
import pytensor.xtensor as ptx
x = ptx.as_xtensor([[1, 1], [1, 1]], dims=("a", "b"))
idx = ptx.as_xtensor([0, 1], dims=("a",))
out = x[:, idx].inc(1)
print(out.eval())
.. testoutput::
[[2 1]
[1 2]]
.. testcode::
import pytensor.xtensor as ptx
x = ptx.as_xtensor([[1, 1], [1, 1]], dims=("a", "b"))
idx = ptx.as_xtensor([0, 1], dims=("a",))
out = x.isel({"b": idx}).inc(-1)
print(out.eval())
.. testoutput::
[[0 1]
[1 0]]
"""
if not (
self.owner is not None and isinstance(self.owner.op, px.indexing.Index)
):
......@@ -579,7 +683,7 @@ class XTensorVariable(Variable[_XTensorTypeType, OptionalApplyType]):
drop=None,
axis: int | Sequence[int] | None = None,
):
"""Remove dimensions of size 1 from an XTensorVariable.
"""Remove dimensions of size 1.
Parameters
----------
......@@ -606,24 +710,21 @@ class XTensorVariable(Variable[_XTensorTypeType, OptionalApplyType]):
axis: int | Sequence[int] | None = None,
**dim_kwargs,
):
"""Add one or more new dimensions to the tensor.
"""Add one or more new dimensions to the variable.
Parameters
----------
dim : str | Sequence[str] | dict[str, int | Sequence] | None
If str or sequence of str, new dimensions with size 1.
If dict, keys are dimension names and values are either:
- int: the new size
- sequence: coordinates (length determines size)
- int: the new size
- sequence: coordinates (length determines size)
create_index_for_new_dim : bool, default: True
Currently ignored. Reserved for future coordinate support.
In xarray, when True (default), creates a coordinate index for the new dimension
with values from 0 to size-1. When False, no coordinate index is created.
Ignored by PyTensor
axis : int | Sequence[int] | None, default: None
Not implemented yet. In xarray, specifies where to insert the new dimension(s).
By default (None), new dimensions are inserted at the beginning (axis=0).
Symbolic axis is not supported yet.
Negative values count from the end.
**dim_kwargs : int | Sequence
Alternative to `dim` dict. Only used if `dim` is None.
......@@ -643,65 +744,75 @@ class XTensorVariable(Variable[_XTensorTypeType, OptionalApplyType]):
# ndarray methods
# https://docs.xarray.dev/en/latest/api.html#id7
def clip(self, min, max):
"""Clip the values of the variable to a specified range."""
return px.math.clip(self, min, max)
def conj(self):
"""Return the complex conjugate of the variable."""
return px.math.conj(self)
@property
def imag(self):
"""Return the imaginary part of the variable."""
return px.math.imag(self)
@property
def real(self):
"""Return the real part of the variable."""
return px.math.real(self)
@property
def T(self):
"""Return the full transpose of the tensor.
"""Return the full transpose of the variable.
This is equivalent to calling transpose() with no arguments.
Returns
-------
XTensorVariable
Fully transposed tensor.
"""
return self.transpose()
# Aggregation
# https://docs.xarray.dev/en/latest/api.html#id6
def all(self, dim=None):
"""Reduce the variable by applying `all` along some dimension(s)."""
return px.reduction.all(self, dim)
def any(self, dim=None):
"""Reduce the variable by applying `any` along some dimension(s)."""
return px.reduction.any(self, dim)
def max(self, dim=None):
"""Compute the maximum along the given dimension(s)."""
return px.reduction.max(self, dim)
def min(self, dim=None):
"""Compute the minimum along the given dimension(s)."""
return px.reduction.min(self, dim)
def mean(self, dim=None):
"""Compute the mean along the given dimension(s)."""
return px.reduction.mean(self, dim)
def prod(self, dim=None):
"""Compute the product along the given dimension(s)."""
return px.reduction.prod(self, dim)
def sum(self, dim=None):
"""Compute the sum along the given dimension(s)."""
return px.reduction.sum(self, dim)
def std(self, dim=None, ddof=0):
"""Compute the standard deviation along the given dimension(s)."""
return px.reduction.std(self, dim, ddof=ddof)
def var(self, dim=None, ddof=0):
"""Compute the variance along the given dimension(s)."""
return px.reduction.var(self, dim, ddof=ddof)
def cumsum(self, dim=None):
"""Compute the cumulative sum along the given dimension(s)."""
return px.reduction.cumsum(self, dim)
def cumprod(self, dim=None):
"""Compute the cumulative product along the given dimension(s)."""
return px.reduction.cumprod(self, dim)
def diff(self, dim, n=1):
......@@ -720,7 +831,7 @@ class XTensorVariable(Variable[_XTensorTypeType, OptionalApplyType]):
*dim: str | EllipsisType,
missing_dims: Literal["raise", "warn", "ignore"] = "raise",
):
"""Transpose dimensions of the tensor.
"""Transpose the dimensions of the variable.
Parameters
----------
......@@ -729,6 +840,7 @@ class XTensorVariable(Variable[_XTensorTypeType, OptionalApplyType]):
Can use ellipsis (...) to represent remaining dimensions.
missing_dims : {"raise", "warn", "ignore"}, default="raise"
How to handle dimensions that don't exist in the tensor:
- "raise": Raise an error if any dimensions don't exist
- "warn": Warn if any dimensions don't exist
- "ignore": Silently ignore any dimensions that don't exist
......@@ -747,21 +859,38 @@ class XTensorVariable(Variable[_XTensorTypeType, OptionalApplyType]):
return px.shape.transpose(self, *dim, missing_dims=missing_dims)
def stack(self, dim, **dims):
"""Stack existing dimensions into a single new dimension."""
return px.shape.stack(self, dim, **dims)
def unstack(self, dim, **dims):
"""Unstack a dimension into multiple dimensions of a given size.
Because XTensorVariables don't have coords, this operation requires the sizes of each unstacked dimension to be specified.
Also, unstacked dims will follow a C-style order, regardless of the order of the original dimensions.
.. testcode::
import pytensor.xtensor as ptx
x = ptx.as_xtensor([[1, 2], [3, 4]], dims=("a", "b"))
stacked_cumsum = x.stack({"c": ["a", "b"]}).cumsum("c")
unstacked_cumsum = stacked_cumsum.unstack({"c": x.sizes})
print(unstacked_cumsum.eval())
.. testoutput::
[[ 1 3]
[ 6 10]]
"""
return px.shape.unstack(self, dim, **dims)
def dot(self, other, dim=None):
"""Matrix multiplication with another XTensorVariable, contracting over matching or specified dims."""
"""Generalized dot product with another XTensorVariable."""
return px.math.dot(self, other, dim=dim)
def broadcast(self, *others, exclude=None):
"""Broadcast this tensor against other XTensorVariables."""
return px.shape.broadcast(self, *others, exclude=exclude)
def broadcast_like(self, other, exclude=None):
"""Broadcast this tensor against another XTensorVariable."""
"""Broadcast against another XTensorVariable."""
_, self_bcast = px.shape.broadcast(other, self, exclude=exclude)
return self_bcast
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论