提交 5024d54e authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Add docstrings to more XTensorVariable methods

Also remove broadcast which is not a method in Xarray
上级 fdb40877
...@@ -366,6 +366,7 @@ class XTensorVariable(Variable[_XTensorTypeType, OptionalApplyType]): ...@@ -366,6 +366,7 @@ class XTensorVariable(Variable[_XTensorTypeType, OptionalApplyType]):
# https://docs.xarray.dev/en/latest/api.html#id1 # https://docs.xarray.dev/en/latest/api.html#id1
@property @property
def values(self) -> TensorVariable: def values(self) -> TensorVariable:
"""Convert to a TensorVariable with the same data."""
return typing.cast(TensorVariable, px.basic.tensor_from_xtensor(self)) return typing.cast(TensorVariable, px.basic.tensor_from_xtensor(self))
# Can't provide property data because that's already taken by Constants! # Can't provide property data because that's already taken by Constants!
...@@ -373,14 +374,17 @@ class XTensorVariable(Variable[_XTensorTypeType, OptionalApplyType]): ...@@ -373,14 +374,17 @@ class XTensorVariable(Variable[_XTensorTypeType, OptionalApplyType]):
@property @property
def coords(self): def coords(self):
"""Not implemented."""
raise NotImplementedError("coords not implemented for XTensorVariable") raise NotImplementedError("coords not implemented for XTensorVariable")
@property @property
def dims(self) -> tuple[str, ...]: def dims(self) -> tuple[str, ...]:
"""The names of the dimensions of the variable."""
return self.type.dims return self.type.dims
@property @property
def sizes(self) -> dict[str, TensorVariable]: def sizes(self) -> dict[str, TensorVariable]:
"""The sizes of the dimensions of the variable."""
return dict(zip(self.dims, self.shape)) return dict(zip(self.dims, self.shape))
@property @property
...@@ -392,18 +396,22 @@ class XTensorVariable(Variable[_XTensorTypeType, OptionalApplyType]): ...@@ -392,18 +396,22 @@ class XTensorVariable(Variable[_XTensorTypeType, OptionalApplyType]):
# https://docs.xarray.dev/en/latest/api.html#ndarray-attributes # https://docs.xarray.dev/en/latest/api.html#ndarray-attributes
@property @property
def ndim(self) -> int: def ndim(self) -> int:
"""The number of dimensions of the variable."""
return self.type.ndim return self.type.ndim
@property @property
def shape(self) -> tuple[TensorVariable, ...]: def shape(self) -> tuple[TensorVariable, ...]:
"""The shape of the variable."""
return tuple(px.basic.tensor_from_xtensor(self).shape) # type: ignore return tuple(px.basic.tensor_from_xtensor(self).shape) # type: ignore
@property @property
def size(self) -> TensorVariable: def size(self) -> TensorVariable:
"""The total number of elements in the variable."""
return typing.cast(TensorVariable, variadic_mul(*self.shape)) return typing.cast(TensorVariable, variadic_mul(*self.shape))
@property @property
def dtype(self): def dtype(self) -> str:
"""The data type of the variable."""
return self.type.dtype return self.type.dtype
@property @property
...@@ -414,6 +422,7 @@ class XTensorVariable(Variable[_XTensorTypeType, OptionalApplyType]): ...@@ -414,6 +422,7 @@ class XTensorVariable(Variable[_XTensorTypeType, OptionalApplyType]):
# DataArray contents # DataArray contents
# https://docs.xarray.dev/en/latest/api.html#dataarray-contents # https://docs.xarray.dev/en/latest/api.html#dataarray-contents
def rename(self, new_name_or_name_dict=None, **names): def rename(self, new_name_or_name_dict=None, **names):
"""Rename the variable or its dimension(s)."""
if isinstance(new_name_or_name_dict, str): if isinstance(new_name_or_name_dict, str):
new_name = new_name_or_name_dict new_name = new_name_or_name_dict
name_dict = None name_dict = None
...@@ -425,31 +434,41 @@ class XTensorVariable(Variable[_XTensorTypeType, OptionalApplyType]): ...@@ -425,31 +434,41 @@ class XTensorVariable(Variable[_XTensorTypeType, OptionalApplyType]):
return new_out return new_out
def copy(self, name: str | None = None): def copy(self, name: str | None = None):
"""Create a copy of the variable.
This is just an identity operation, as XTensorVariables are immutable.
"""
out = px.math.identity(self) out = px.math.identity(self)
out.name = name out.name = name
return out return out
def astype(self, dtype): def astype(self, dtype):
"""Convert the variable to a different data type."""
return px.math.cast(self, dtype) return px.math.cast(self, dtype)
def item(self): def item(self):
"""Not implemented."""
raise NotImplementedError("item not implemented for XTensorVariable") raise NotImplementedError("item not implemented for XTensorVariable")
# Indexing # Indexing
# https://docs.xarray.dev/en/latest/api.html#id2 # https://docs.xarray.dev/en/latest/api.html#id2
def __setitem__(self, idx, value): def __setitem__(self, idx, value):
"""Not implemented. Use `x[idx].set(value)` or `x[idx].inc(value)` instead."""
raise TypeError( raise TypeError(
"XTensorVariable does not support item assignment. Use the output of `x[idx].set` or `x[idx].inc` instead." "XTensorVariable does not support item assignment. Use the output of `x[idx].set` or `x[idx].inc` instead."
) )
@property @property
def loc(self): def loc(self):
"""Not implemented."""
raise NotImplementedError("loc not implemented for XTensorVariable") raise NotImplementedError("loc not implemented for XTensorVariable")
def sel(self, *args, **kwargs): def sel(self, *args, **kwargs):
"""Not implemented."""
raise NotImplementedError("sel not implemented for XTensorVariable") raise NotImplementedError("sel not implemented for XTensorVariable")
def __getitem__(self, idx): def __getitem__(self, idx):
"""Index the variable positionally."""
if isinstance(idx, dict): if isinstance(idx, dict):
return self.isel(idx) return self.isel(idx)
...@@ -465,6 +484,7 @@ class XTensorVariable(Variable[_XTensorTypeType, OptionalApplyType]): ...@@ -465,6 +484,7 @@ class XTensorVariable(Variable[_XTensorTypeType, OptionalApplyType]):
missing_dims: Literal["raise", "warn", "ignore"] = "raise", missing_dims: Literal["raise", "warn", "ignore"] = "raise",
**indexers_kwargs, **indexers_kwargs,
): ):
"""Index the variable along the specified dimension(s)."""
if indexers_kwargs: if indexers_kwargs:
if indexers is not None: if indexers is not None:
raise ValueError( raise ValueError(
...@@ -505,6 +525,48 @@ class XTensorVariable(Variable[_XTensorTypeType, OptionalApplyType]): ...@@ -505,6 +525,48 @@ class XTensorVariable(Variable[_XTensorTypeType, OptionalApplyType]):
return px.indexing.index(self, *indices) return px.indexing.index(self, *indices)
def set(self, value): def set(self, value):
"""Return a copy of the variable indexed by self with the indexed values set to y.
The original variable is not modified.
Raises
------
ValueError
If self is not the result of an index operation
Examples
--------
.. testcode::
import pytensor.xtensor as ptx
x = ptx.as_xtensor([[0, 0], [0, 0]], dims=("a", "b"))
idx = ptx.as_xtensor([0, 1], dims=("a",))
out = x[:, idx].set(1)
print(out.eval())
.. testoutput::
[[1 0]
[0 1]]
.. testcode::
import pytensor.xtensor as ptx
x = ptx.as_xtensor([[0, 0], [0, 0]], dims=("a", "b"))
idx = ptx.as_xtensor([0, 1], dims=("a",))
out = x.isel({"b": idx}).set(-1)
print(out.eval())
.. testoutput::
[[-1 0]
[ 0 -1]]
"""
if not ( if not (
self.owner is not None and isinstance(self.owner.op, px.indexing.Index) self.owner is not None and isinstance(self.owner.op, px.indexing.Index)
): ):
...@@ -516,6 +578,48 @@ class XTensorVariable(Variable[_XTensorTypeType, OptionalApplyType]): ...@@ -516,6 +578,48 @@ class XTensorVariable(Variable[_XTensorTypeType, OptionalApplyType]):
return px.indexing.index_assignment(x, value, *idxs) return px.indexing.index_assignment(x, value, *idxs)
def inc(self, value): def inc(self, value):
"""Return a copy of the variable indexed by self with the indexed values incremented by value.
The original variable is not modified.
Raises
------
ValueError
If self is not the result of an index operation
Examples
--------
.. testcode::
import pytensor.xtensor as ptx
x = ptx.as_xtensor([[1, 1], [1, 1]], dims=("a", "b"))
idx = ptx.as_xtensor([0, 1], dims=("a",))
out = x[:, idx].inc(1)
print(out.eval())
.. testoutput::
[[2 1]
[1 2]]
.. testcode::
import pytensor.xtensor as ptx
x = ptx.as_xtensor([[1, 1], [1, 1]], dims=("a", "b"))
idx = ptx.as_xtensor([0, 1], dims=("a",))
out = x.isel({"b": idx}).inc(-1)
print(out.eval())
.. testoutput::
[[0 1]
[1 0]]
"""
if not ( if not (
self.owner is not None and isinstance(self.owner.op, px.indexing.Index) self.owner is not None and isinstance(self.owner.op, px.indexing.Index)
): ):
...@@ -579,7 +683,7 @@ class XTensorVariable(Variable[_XTensorTypeType, OptionalApplyType]): ...@@ -579,7 +683,7 @@ class XTensorVariable(Variable[_XTensorTypeType, OptionalApplyType]):
drop=None, drop=None,
axis: int | Sequence[int] | None = None, axis: int | Sequence[int] | None = None,
): ):
"""Remove dimensions of size 1 from an XTensorVariable. """Remove dimensions of size 1.
Parameters Parameters
---------- ----------
...@@ -606,24 +710,21 @@ class XTensorVariable(Variable[_XTensorTypeType, OptionalApplyType]): ...@@ -606,24 +710,21 @@ class XTensorVariable(Variable[_XTensorTypeType, OptionalApplyType]):
axis: int | Sequence[int] | None = None, axis: int | Sequence[int] | None = None,
**dim_kwargs, **dim_kwargs,
): ):
"""Add one or more new dimensions to the tensor. """Add one or more new dimensions to the variable.
Parameters Parameters
---------- ----------
dim : str | Sequence[str] | dict[str, int | Sequence] | None dim : str | Sequence[str] | dict[str, int | Sequence] | None
If str or sequence of str, new dimensions with size 1. If str or sequence of str, new dimensions with size 1.
If dict, keys are dimension names and values are either: If dict, keys are dimension names and values are either:
- int: the new size
- sequence: coordinates (length determines size) - int: the new size
- sequence: coordinates (length determines size)
create_index_for_new_dim : bool, default: True create_index_for_new_dim : bool, default: True
Currently ignored. Reserved for future coordinate support. Ignored by PyTensor
In xarray, when True (default), creates a coordinate index for the new dimension
with values from 0 to size-1. When False, no coordinate index is created.
axis : int | Sequence[int] | None, default: None axis : int | Sequence[int] | None, default: None
Not implemented yet. In xarray, specifies where to insert the new dimension(s). Not implemented yet. In xarray, specifies where to insert the new dimension(s).
By default (None), new dimensions are inserted at the beginning (axis=0). By default (None), new dimensions are inserted at the beginning (axis=0).
Symbolic axis is not supported yet.
Negative values count from the end.
**dim_kwargs : int | Sequence **dim_kwargs : int | Sequence
Alternative to `dim` dict. Only used if `dim` is None. Alternative to `dim` dict. Only used if `dim` is None.
...@@ -643,65 +744,75 @@ class XTensorVariable(Variable[_XTensorTypeType, OptionalApplyType]): ...@@ -643,65 +744,75 @@ class XTensorVariable(Variable[_XTensorTypeType, OptionalApplyType]):
# ndarray methods # ndarray methods
# https://docs.xarray.dev/en/latest/api.html#id7 # https://docs.xarray.dev/en/latest/api.html#id7
def clip(self, min, max): def clip(self, min, max):
"""Clip the values of the variable to a specified range."""
return px.math.clip(self, min, max) return px.math.clip(self, min, max)
def conj(self): def conj(self):
"""Return the complex conjugate of the variable."""
return px.math.conj(self) return px.math.conj(self)
@property @property
def imag(self): def imag(self):
"""Return the imaginary part of the variable."""
return px.math.imag(self) return px.math.imag(self)
@property @property
def real(self): def real(self):
"""Return the real part of the variable."""
return px.math.real(self) return px.math.real(self)
@property @property
def T(self): def T(self):
"""Return the full transpose of the tensor. """Return the full transpose of the variable.
This is equivalent to calling transpose() with no arguments. This is equivalent to calling transpose() with no arguments.
Returns
-------
XTensorVariable
Fully transposed tensor.
""" """
return self.transpose() return self.transpose()
# Aggregation # Aggregation
# https://docs.xarray.dev/en/latest/api.html#id6 # https://docs.xarray.dev/en/latest/api.html#id6
def all(self, dim=None): def all(self, dim=None):
"""Reduce the variable by applying `all` along some dimension(s)."""
return px.reduction.all(self, dim) return px.reduction.all(self, dim)
def any(self, dim=None): def any(self, dim=None):
"""Reduce the variable by applying `any` along some dimension(s)."""
return px.reduction.any(self, dim) return px.reduction.any(self, dim)
def max(self, dim=None): def max(self, dim=None):
"""Compute the maximum along the given dimension(s)."""
return px.reduction.max(self, dim) return px.reduction.max(self, dim)
def min(self, dim=None): def min(self, dim=None):
"""Compute the minimum along the given dimension(s)."""
return px.reduction.min(self, dim) return px.reduction.min(self, dim)
def mean(self, dim=None): def mean(self, dim=None):
"""Compute the mean along the given dimension(s)."""
return px.reduction.mean(self, dim) return px.reduction.mean(self, dim)
def prod(self, dim=None): def prod(self, dim=None):
"""Compute the product along the given dimension(s)."""
return px.reduction.prod(self, dim) return px.reduction.prod(self, dim)
def sum(self, dim=None): def sum(self, dim=None):
"""Compute the sum along the given dimension(s)."""
return px.reduction.sum(self, dim) return px.reduction.sum(self, dim)
def std(self, dim=None, ddof=0): def std(self, dim=None, ddof=0):
"""Compute the standard deviation along the given dimension(s)."""
return px.reduction.std(self, dim, ddof=ddof) return px.reduction.std(self, dim, ddof=ddof)
def var(self, dim=None, ddof=0): def var(self, dim=None, ddof=0):
"""Compute the variance along the given dimension(s)."""
return px.reduction.var(self, dim, ddof=ddof) return px.reduction.var(self, dim, ddof=ddof)
def cumsum(self, dim=None): def cumsum(self, dim=None):
"""Compute the cumulative sum along the given dimension(s)."""
return px.reduction.cumsum(self, dim) return px.reduction.cumsum(self, dim)
def cumprod(self, dim=None): def cumprod(self, dim=None):
"""Compute the cumulative product along the given dimension(s)."""
return px.reduction.cumprod(self, dim) return px.reduction.cumprod(self, dim)
def diff(self, dim, n=1): def diff(self, dim, n=1):
...@@ -720,7 +831,7 @@ class XTensorVariable(Variable[_XTensorTypeType, OptionalApplyType]): ...@@ -720,7 +831,7 @@ class XTensorVariable(Variable[_XTensorTypeType, OptionalApplyType]):
*dim: str | EllipsisType, *dim: str | EllipsisType,
missing_dims: Literal["raise", "warn", "ignore"] = "raise", missing_dims: Literal["raise", "warn", "ignore"] = "raise",
): ):
"""Transpose dimensions of the tensor. """Transpose the dimensions of the variable.
Parameters Parameters
---------- ----------
...@@ -729,6 +840,7 @@ class XTensorVariable(Variable[_XTensorTypeType, OptionalApplyType]): ...@@ -729,6 +840,7 @@ class XTensorVariable(Variable[_XTensorTypeType, OptionalApplyType]):
Can use ellipsis (...) to represent remaining dimensions. Can use ellipsis (...) to represent remaining dimensions.
missing_dims : {"raise", "warn", "ignore"}, default="raise" missing_dims : {"raise", "warn", "ignore"}, default="raise"
How to handle dimensions that don't exist in the tensor: How to handle dimensions that don't exist in the tensor:
- "raise": Raise an error if any dimensions don't exist - "raise": Raise an error if any dimensions don't exist
- "warn": Warn if any dimensions don't exist - "warn": Warn if any dimensions don't exist
- "ignore": Silently ignore any dimensions that don't exist - "ignore": Silently ignore any dimensions that don't exist
...@@ -747,21 +859,38 @@ class XTensorVariable(Variable[_XTensorTypeType, OptionalApplyType]): ...@@ -747,21 +859,38 @@ class XTensorVariable(Variable[_XTensorTypeType, OptionalApplyType]):
return px.shape.transpose(self, *dim, missing_dims=missing_dims) return px.shape.transpose(self, *dim, missing_dims=missing_dims)
def stack(self, dim, **dims): def stack(self, dim, **dims):
"""Stack existing dimensions into a single new dimension."""
return px.shape.stack(self, dim, **dims) return px.shape.stack(self, dim, **dims)
def unstack(self, dim, **dims): def unstack(self, dim, **dims):
"""Unstack a dimension into multiple dimensions of a given size.
Because XTensorVariables don't have coords, this operation requires the sizes of each unstacked dimension to be specified.
Also, unstacked dims will follow a C-style order, regardless of the order of the original dimensions.
.. testcode::
import pytensor.xtensor as ptx
x = ptx.as_xtensor([[1, 2], [3, 4]], dims=("a", "b"))
stacked_cumsum = x.stack({"c": ["a", "b"]}).cumsum("c")
unstacked_cumsum = stacked_cumsum.unstack({"c": x.sizes})
print(unstacked_cumsum.eval())
.. testoutput::
[[ 1 3]
[ 6 10]]
"""
return px.shape.unstack(self, dim, **dims) return px.shape.unstack(self, dim, **dims)
def dot(self, other, dim=None): def dot(self, other, dim=None):
"""Matrix multiplication with another XTensorVariable, contracting over matching or specified dims.""" """Generalized dot product with another XTensorVariable."""
return px.math.dot(self, other, dim=dim) return px.math.dot(self, other, dim=dim)
def broadcast(self, *others, exclude=None):
"""Broadcast this tensor against other XTensorVariables."""
return px.shape.broadcast(self, *others, exclude=exclude)
def broadcast_like(self, other, exclude=None): def broadcast_like(self, other, exclude=None):
"""Broadcast this tensor against another XTensorVariable.""" """Broadcast against another XTensorVariable."""
_, self_bcast = px.shape.broadcast(other, self, exclude=exclude) _, self_bcast = px.shape.broadcast(other, self, exclude=exclude)
return self_bcast return self_bcast
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论