提交 155db9f3 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Implement basic labeled tensor functionality

上级 f7cf2734
...@@ -82,11 +82,12 @@ jobs: ...@@ -82,11 +82,12 @@ jobs:
install-numba: [0] install-numba: [0]
install-jax: [0] install-jax: [0]
install-torch: [0] install-torch: [0]
install-xarray: [0]
part: part:
- "tests --ignore=tests/tensor --ignore=tests/scan --ignore=tests/sparse" - "tests --ignore=tests/tensor --ignore=tests/scan --ignore=tests/sparse --ignore=tests/xtensor"
- "tests/scan" - "tests/scan"
- "tests/sparse" - "tests/sparse"
- "tests/tensor --ignore=tests/tensor/conv --ignore=tests/tensor/rewriting --ignore=tests/tensor/test_math.py --ignore=tests/tensor/test_basic.py --ignore=tests/tensor/test_blas.py --ignore=tests/tensor/test_math_scipy.py --ignore=tests/tensor/test_inplace.py --ignore=tests/tensor/test_elemwise.py" - "tests/tensor --ignore=tests/tensor/conv --ignore=tests/tensor/rewriting --ignore=tests/tensor/test_math.py --ignore=tests/tensor/test_basic.py --ignore=tests/tensor/test_inplace.py --ignore=tests/tensor/test_blas.py --ignore=tests/tensor/test_elemwise.py --ignore=tests/tensor/test_math_scipy.py"
- "tests/tensor/conv" - "tests/tensor/conv"
- "tests/tensor/rewriting" - "tests/tensor/rewriting"
- "tests/tensor/test_math.py" - "tests/tensor/test_math.py"
...@@ -115,6 +116,7 @@ jobs: ...@@ -115,6 +116,7 @@ jobs:
install-numba: 0 install-numba: 0
install-jax: 0 install-jax: 0
install-torch: 0 install-torch: 0
install-xarray: 0
- install-numba: 1 - install-numba: 1
os: "ubuntu-latest" os: "ubuntu-latest"
python-version: "3.10" python-version: "3.10"
...@@ -150,6 +152,13 @@ jobs: ...@@ -150,6 +152,13 @@ jobs:
fast-compile: 0 fast-compile: 0
float32: 0 float32: 0
part: "tests/link/pytorch" part: "tests/link/pytorch"
- install-xarray: 1
os: "ubuntu-latest"
python-version: "3.13"
numpy-version: ">=2.0"
fast-compile: 0
float32: 0
part: "tests/xtensor"
- os: macos-15 - os: macos-15
python-version: "3.13" python-version: "3.13"
numpy-version: ">=2.0" numpy-version: ">=2.0"
...@@ -196,6 +205,7 @@ jobs: ...@@ -196,6 +205,7 @@ jobs:
if [[ $INSTALL_NUMBA == "1" ]]; then micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}" "numba>=0.57"; fi if [[ $INSTALL_NUMBA == "1" ]]; then micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}" "numba>=0.57"; fi
if [[ $INSTALL_JAX == "1" ]]; then micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}" jax jaxlib numpyro && pip install tensorflow-probability; fi if [[ $INSTALL_JAX == "1" ]]; then micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}" jax jaxlib numpyro && pip install tensorflow-probability; fi
if [[ $INSTALL_TORCH == "1" ]]; then micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}" pytorch pytorch-cuda=12.1 "mkl<=2024.0" -c pytorch -c nvidia; fi if [[ $INSTALL_TORCH == "1" ]]; then micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}" pytorch pytorch-cuda=12.1 "mkl<=2024.0" -c pytorch -c nvidia; fi
if [[ $INSTALL_XARRAY == "1" ]]; then micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}" xarray xarray-einstats; fi
pip install pytest-sphinx pip install pytest-sphinx
pip install -e ./ pip install -e ./
...@@ -212,6 +222,7 @@ jobs: ...@@ -212,6 +222,7 @@ jobs:
INSTALL_NUMBA: ${{ matrix.install-numba }} INSTALL_NUMBA: ${{ matrix.install-numba }}
INSTALL_JAX: ${{ matrix.install-jax }} INSTALL_JAX: ${{ matrix.install-jax }}
INSTALL_TORCH: ${{ matrix.install-torch}} INSTALL_TORCH: ${{ matrix.install-torch}}
INSTALL_XARRAY: ${{ matrix.install-xarray }}
OS: ${{ matrix.os}} OS: ${{ matrix.os}}
- name: Run tests - name: Run tests
......
...@@ -67,6 +67,8 @@ exclude = [] ...@@ -67,6 +67,8 @@ exclude = []
if not config.cxx: if not config.cxx:
exclude = ["cxx_only"] exclude = ["cxx_only"]
OPT_NONE = RewriteDatabaseQuery(include=[], exclude=exclude) OPT_NONE = RewriteDatabaseQuery(include=[], exclude=exclude)
# Minimum set of rewrites needed to evaluate a function. This is needed for graphs with "dummy" Operations
OPT_MINIMUM = RewriteDatabaseQuery(include=["minimum_compile"], exclude=exclude)
# Even if multiple merge optimizer call will be there, this shouldn't # Even if multiple merge optimizer call will be there, this shouldn't
# impact performance. # impact performance.
OPT_MERGE = RewriteDatabaseQuery(include=["merge"], exclude=exclude) OPT_MERGE = RewriteDatabaseQuery(include=["merge"], exclude=exclude)
...@@ -77,6 +79,7 @@ OPT_FAST_COMPILE = RewriteDatabaseQuery(include=["fast_compile"], exclude=exclud ...@@ -77,6 +79,7 @@ OPT_FAST_COMPILE = RewriteDatabaseQuery(include=["fast_compile"], exclude=exclud
OPT_STABILIZE = RewriteDatabaseQuery(include=["fast_run"], exclude=exclude) OPT_STABILIZE = RewriteDatabaseQuery(include=["fast_run"], exclude=exclude)
OPT_STABILIZE.position_cutoff = 1.5000001 OPT_STABILIZE.position_cutoff = 1.5000001
OPT_NONE.name = "OPT_NONE" OPT_NONE.name = "OPT_NONE"
OPT_MINIMUM.name = "OPT_MINIMUM"
OPT_MERGE.name = "OPT_MERGE" OPT_MERGE.name = "OPT_MERGE"
OPT_FAST_RUN.name = "OPT_FAST_RUN" OPT_FAST_RUN.name = "OPT_FAST_RUN"
OPT_FAST_RUN_STABLE.name = "OPT_FAST_RUN_STABLE" OPT_FAST_RUN_STABLE.name = "OPT_FAST_RUN_STABLE"
...@@ -95,6 +98,7 @@ predefined_optimizers = { ...@@ -95,6 +98,7 @@ predefined_optimizers = {
None: OPT_NONE, None: OPT_NONE,
"None": OPT_NONE, "None": OPT_NONE,
"merge": OPT_MERGE, "merge": OPT_MERGE,
"minimum_compile": OPT_MINIMUM,
"o4": OPT_FAST_RUN, "o4": OPT_FAST_RUN,
"o3": OPT_O3, "o3": OPT_O3,
"o2": OPT_O2, "o2": OPT_O2,
...@@ -191,6 +195,7 @@ optdb.register( ...@@ -191,6 +195,7 @@ optdb.register(
"merge1", MergeOptimizer(), "fast_run", "fast_compile", "merge", position=0 "merge1", MergeOptimizer(), "fast_run", "fast_compile", "merge", position=0
) )
# After scan1 opt at 0.5 and before ShapeOpt at 1 # After scan1 opt at 0.5 and before ShapeOpt at 1
# This should only remove nodes. # This should only remove nodes.
# The opt should not do anything that need shape inference. # The opt should not do anything that need shape inference.
......
import warnings
import pytensor.xtensor.rewriting
from pytensor.xtensor.type import (
XTensorType,
as_xtensor,
xtensor,
xtensor_constant,
)
warnings.warn("xtensor module is experimental and full of bugs")
from collections.abc import Sequence
from pytensor.compile.ops import TypeCastingOp
from pytensor.graph import Apply, Op
from pytensor.tensor.type import TensorType
from pytensor.xtensor.type import XTensorType, as_xtensor, xtensor
class XOp(Op):
"""A base class for XOps that shouldn't be materialized"""
def perform(self, node, inputs, outputs):
raise NotImplementedError(
f"xtensor operation {self} must be lowered to equivalent tensor operations"
)
class XTypeCastOp(TypeCastingOp):
"""Base class for Ops that type cast between TensorType and XTensorType.
This is like a `ViewOp` but without the expectation the input and output have identical types.
"""
class TensorFromXTensor(XTypeCastOp):
__props__ = ()
def make_node(self, x):
if not isinstance(x.type, XTensorType):
raise TypeError(f"x must be have an XTensorType, got {type(x.type)}")
output = TensorType(x.type.dtype, shape=x.type.shape)()
return Apply(self, [x], [output])
def L_op(self, inputs, outs, g_outs):
[x] = inputs
[g_out] = g_outs
return [xtensor_from_tensor(g_out, dims=x.type.dims)]
tensor_from_xtensor = TensorFromXTensor()
class XTensorFromTensor(XTypeCastOp):
__props__ = ("dims",)
def __init__(self, dims: Sequence[str]):
super().__init__()
self.dims = tuple(dims)
def make_node(self, x):
if not isinstance(x.type, TensorType):
raise TypeError(f"x must be an TensorType type, got {type(x.type)}")
output = xtensor(dtype=x.type.dtype, dims=self.dims, shape=x.type.shape)
return Apply(self, [x], [output])
def L_op(self, inputs, outs, g_outs):
[g_out] = g_outs
return [tensor_from_xtensor(g_out)]
def xtensor_from_tensor(x, dims, name=None):
return XTensorFromTensor(dims=dims)(x, name=name)
class Rename(XTypeCastOp):
__props__ = ("new_dims",)
def __init__(self, new_dims: tuple[str, ...]):
super().__init__()
self.new_dims = new_dims
def make_node(self, x):
x = as_xtensor(x)
output = x.type.clone(dims=self.new_dims)()
return Apply(self, [x], [output])
def L_op(self, inputs, outs, g_outs):
[x] = inputs
[g_out] = g_outs
return [rename(g_out, dims=x.type.dims)]
def rename(x, name_dict: dict[str, str] | None = None, **names: str):
if name_dict is not None:
if names:
raise ValueError("Cannot use both positional and keyword names in rename")
names = name_dict
x = as_xtensor(x)
old_names = x.type.dims
new_names = list(old_names)
for old_name, new_name in names.items():
try:
new_names[old_names.index(old_name)] = new_name
except ValueError:
raise ValueError(
f"Cannot rename {old_name} to {new_name}: {old_name} not in {old_names}"
)
return Rename(tuple(new_names))(x)
# XTensor Module
This module implements as abstraction layer on regular tensor operations, that behaves like Xarray.
A new type `XTensorType`, generalizes the `TensorType` with the addition of a `dims` attribute,
that labels the dimensions of the tensor.
Variables of `XTensorType` (i.e., `XTensorVariable`s) are the symbolic counterpart to xarray DataArray objects.
The module implements several PyTensor operations `XOp`s, whose signature mimics that of xarray (and xarray_einstants) DataArray operations.
These operations, unlike most regular PyTensor operations, cannot be directly evaluated, but require a rewrite (lowering) into
a regular tensor graph that can itself be evaluated as usual.
Like regular PyTensor, we don't need an Op for every possible method or function in the public API of xarray.
If the existing XOps can be composed to produce the desired result, then we can use them directly.
## Coordinates
For now, there's no analogous of xarray coordinates, so you won't be able to do coordinate operations like `.sel`.
The graphs produced by an xarray program without coords are much more amenable to the numpy-like backend of PyTensor.
Coords involve aspects of Pandas/database query and joining that are not trivially expressible in PyTensor.
## Example
```python
import pytensor.tensor as pt
import pytensor.xtensor as px
a = pt.tensor("a", shape=(3,))
b = pt.tensor("b", shape=(4,))
ax = px.as_xtensor(a, dims=["x"])
bx = px.as_xtensor(b, dims=["y"])
zx = ax + bx
assert zx.type == px.type.XTensorType("float64", dims=["x", "y"], shape=(3, 4))
z = zx.values
z.dprint()
# TensorFromXTensor [id A]
# └─ XElemwise{scalar_op=Add()} [id B]
# ├─ XTensorFromTensor{dims=('x',)} [id C]
# │ └─ a [id D]
# └─ XTensorFromTensor{dims=('y',)} [id E]
# └─ b [id F]
```
Once we compile the graph, no `XOp`s are left.
```python
import pytensor
with pytensor.config.change_flags(optimizer_verbose=True):
fn = pytensor.function([a, b], z)
# rewriting: rewrite lower_elemwise replaces XElemwise{scalar_op=Add()}.0 of XElemwise{scalar_op=Add()}(XTensorFromTensor{dims=('x',)}.0, XTensorFromTensor{dims=('y',)}.0) with XTensorFromTensor{dims=('x', 'y')}.0 of XTensorFromTensor{dims=('x', 'y')}(Add.0)
# rewriting: rewrite useless_tensor_from_xtensor replaces TensorFromXTensor.0 of TensorFromXTensor(XTensorFromTensor{dims=('x',)}.0) with a of None
# rewriting: rewrite useless_tensor_from_xtensor replaces TensorFromXTensor.0 of TensorFromXTensor(XTensorFromTensor{dims=('y',)}.0) with b of None
# rewriting: rewrite useless_tensor_from_xtensor replaces TensorFromXTensor.0 of TensorFromXTensor(XTensorFromTensor{dims=('x', 'y')}.0) with Add.0 of Add(ExpandDims{axis=1}.0, ExpandDims{axis=0}.0)
fn.dprint()
# Add [id A] 2
# ├─ ExpandDims{axis=1} [id B] 1
# │ └─ a [id C]
# └─ ExpandDims{axis=0} [id D] 0
# └─ b [id E]
```
import pytensor.xtensor.rewriting.basic
from pytensor.graph import node_rewriter
from pytensor.tensor.basic import register_infer_shape
from pytensor.tensor.rewriting.basic import register_canonicalize, register_useless
from pytensor.xtensor.basic import (
Rename,
TensorFromXTensor,
XTensorFromTensor,
xtensor_from_tensor,
)
from pytensor.xtensor.rewriting.utils import register_lower_xtensor
@register_infer_shape
@register_useless
@register_canonicalize
@register_lower_xtensor
@node_rewriter(tracks=[TensorFromXTensor])
def useless_tensor_from_xtensor(fgraph, node):
"""TensorFromXTensor(XTensorFromTensor(x)) -> x"""
[x] = node.inputs
if x.owner and isinstance(x.owner.op, XTensorFromTensor):
return [x.owner.inputs[0]]
@register_infer_shape
@register_useless
@register_canonicalize
@register_lower_xtensor
@node_rewriter(tracks=[XTensorFromTensor])
def useless_xtensor_from_tensor(fgraph, node):
"""XTensorFromTensor(TensorFromXTensor(x)) -> x"""
[x] = node.inputs
if x.owner and isinstance(x.owner.op, TensorFromXTensor):
return [x.owner.inputs[0]]
@register_lower_xtensor
@node_rewriter(tracks=[TensorFromXTensor])
def useless_tensor_from_xtensor_of_rename(fgraph, node):
"""TensorFromXTensor(Rename(x)) -> TensorFromXTensor(x)"""
[renamed_x] = node.inputs
if renamed_x.owner and isinstance(renamed_x.owner.op, Rename):
[x] = renamed_x.owner.inputs
return node.op(x, return_list=True)
@register_lower_xtensor
@node_rewriter(tracks=[Rename])
def useless_rename(fgraph, node):
"""
Rename(Rename(x, inner_dims), outer_dims) -> Rename(x, outer_dims)
Rename(X, XTensorFromTensor(x, inner_dims), outer_dims) -> XTensorFrom_tensor(x, outer_dims)
"""
[renamed_x] = node.inputs
if renamed_x.owner:
if isinstance(renamed_x.owner.op, Rename):
[x] = renamed_x.owner.inputs
return [node.op(x)]
elif isinstance(renamed_x.owner.op, TensorFromXTensor):
[x] = renamed_x.owner.inputs
return [xtensor_from_tensor(x, dims=node.op.new_dims)]
from pytensor.compile import optdb
from pytensor.graph.rewriting.basic import NodeRewriter
from pytensor.graph.rewriting.db import EquilibriumDB, RewriteDatabase
lower_xtensor_db = EquilibriumDB(ignore_newtrees=False)
optdb.register(
"lower_xtensor",
lower_xtensor_db,
"fast_run",
"fast_compile",
"minimum_compile",
position=0.1,
)
def register_lower_xtensor(
node_rewriter: RewriteDatabase | NodeRewriter | str, *tags: str, **kwargs
):
if isinstance(node_rewriter, str):
def register(inner_rewriter: RewriteDatabase | NodeRewriter):
return register_lower_xtensor(
inner_rewriter, node_rewriter, *tags, **kwargs
)
return register
else:
name = kwargs.pop("name", None) or node_rewriter.__name__ # type: ignore
lower_xtensor_db.register(
name,
node_rewriter,
"fast_run",
"fast_compile",
"minimum_compile",
*tags,
**kwargs,
)
return node_rewriter
import typing
from pytensor.compile import (
DeepCopyOp,
ViewOp,
register_deep_copy_op_c_code,
register_view_op_c_code,
)
from pytensor.tensor import (
TensorType,
_as_tensor_variable,
as_tensor_variable,
specify_shape,
)
from pytensor.tensor.math import variadic_mul
try:
import xarray as xr
XARRAY_AVAILABLE = True
except ModuleNotFoundError:
XARRAY_AVAILABLE = False
from collections.abc import Sequence
from typing import TypeVar
import numpy as np
import pytensor.xtensor as px
from pytensor import _as_symbolic, config
from pytensor.graph import Apply, Constant
from pytensor.graph.basic import OptionalApplyType, Variable
from pytensor.graph.type import HasDataType, HasShape, Type
from pytensor.tensor.basic import constant as tensor_constant
from pytensor.tensor.variable import TensorConstantSignature, TensorVariable
class XTensorType(Type, HasDataType, HasShape):
"""A `Type` for Xtensors (Xarray-like tensors with dims)."""
__props__ = ("dtype", "shape", "dims")
def __init__(
self,
dtype: str | np.dtype,
*,
dims: Sequence[str],
shape: Sequence[int | None] | None = None,
name: str | None = None,
):
if dtype == "floatX":
self.dtype = config.floatX
else:
self.dtype = np.dtype(dtype).name
self.dims = tuple(dims)
if len(set(dims)) < len(dims):
raise ValueError(f"Dimensions must be unique. Found duplicates in {dims}: ")
if shape is None:
self.shape = (None,) * len(self.dims)
else:
self.shape = tuple(shape)
if len(self.shape) != len(self.dims):
raise ValueError(
f"Shape {self.shape} must have the same length as dims {self.dims}"
)
self.ndim = len(self.dims)
self.name = name
self.numpy_dtype = np.dtype(self.dtype)
self.filter_checks_isfinite = False
def clone(
self,
dtype=None,
dims=None,
shape=None,
**kwargs,
):
if dtype is None:
dtype = self.dtype
if dims is None:
dims = self.dims
if shape is None:
shape = self.shape
return type(self)(dtype=dtype, shape=shape, dims=dims, **kwargs)
def filter(self, value, strict=False, allow_downcast=None):
# XTensorType behaves like TensorType at runtime, so we filter the same way.
return TensorType.filter(
self, value, strict=strict, allow_downcast=allow_downcast
)
def filter_variable(self, other, allow_convert=True):
if not isinstance(other, Variable):
# The value is not a Variable: we cast it into
# a Constant of the appropriate Type.
other = xtensor_constant(other)
if self.is_super(other.type):
return other
if allow_convert:
other2 = self.convert_variable(other)
if other2 is not None:
return other2
raise TypeError(
f"Cannot convert Type {other.type} (of Variable {other}) into Type {self}."
f"You can try to manually convert {other} into a {self}. "
)
def convert_variable(self, var):
var_type = var.type
if self.is_super(var_type):
return var
if isinstance(var_type, XTensorType):
if (
self.ndim != var_type.ndim
or self.dtype != var_type.dtype
or set(self.dims) != set(var_type.dims)
):
return None
if self.dims != var_type.dims:
var = var.transpose(*self.dims)
var_type = var.type
if self.is_super(var_type):
return var
if any(
s_length is not None
and var_length is not None
and s_length != var_length
for s_length, var_length in zip(self.shape, var_type.shape)
):
# Incompatible static shapes
return None
# Needs a specify_shape
return as_xtensor(specify_shape(var.values, self.shape), dims=self.dims)
if isinstance(var_type, TensorType):
if (
self.ndim != var_type.ndim
or self.dtype != var_type.dtype
or any(
s_length is not None
and var_length is not None
and s_length != var_length
for s_length, var_length in zip(self.shape, var_type.shape)
)
):
return None
else:
return as_xtensor(specify_shape(var, self.shape), dims=self.dims)
return None
def __repr__(self):
return f"XTensorType({self.dtype}, {self.dims}, {self.shape})"
def __hash__(self):
return hash((type(self), self.dtype, self.shape, self.dims))
def __eq__(self, other):
return (
type(self) is type(other)
and self.dtype == other.dtype
and self.dims == other.dims
and self.shape == other.shape
)
def is_super(self, otype):
if type(self) is not type(otype):
return False
if self.dtype != otype.dtype:
return False
if self.dims != otype.dims:
return False
if any(
s_dim_length is not None and s_dim_length != o_dim_length
for s_dim_length, o_dim_length in zip(self.shape, otype.shape)
):
return False
return True
def xtensor(
name: str | None = None,
*,
dims: Sequence[str],
shape: Sequence[int | None] | None = None,
dtype: str | np.dtype = "floatX",
):
return XTensorType(dtype=dtype, dims=dims, shape=shape)(name=name)
_XTensorTypeType = TypeVar("_XTensorTypeType", bound=XTensorType)
class XTensorVariable(Variable[_XTensorTypeType, OptionalApplyType]):
# These can't work because Python requires native output types
def __bool__(self):
raise TypeError(
"XTensorVariable cannot be converted to Python boolean. "
"Call `.astype(bool)` for the symbolic equivalent."
)
def __index__(self):
raise TypeError(
"XTensorVariable cannot be converted to Python integer. "
"Call `.astype(int)` for the symbolic equivalent."
)
def __int__(self):
raise TypeError(
"XTensorVariable cannot be converted to Python integer. "
"Call `.astype(int)` for the symbolic equivalent."
)
def __float__(self):
raise TypeError(
"XTensorVariables cannot be converted to Python float. "
"Call `.astype(float)` for the symbolic equivalent."
)
def __complex__(self):
raise TypeError(
"XTensorVariables cannot be converted to Python complex number. "
"Call `.astype(complex)` for the symbolic equivalent."
)
# DataArray-like attributes
# https://docs.xarray.dev/en/latest/api.html#id1
@property
def values(self) -> TensorVariable:
return typing.cast(TensorVariable, px.basic.tensor_from_xtensor(self))
# Can't provide property data because that's already taken by Constants!
# data = values
@property
def coords(self):
raise NotImplementedError("coords not implemented for XTensorVariable")
@property
def dims(self) -> tuple[str, ...]:
return self.type.dims
@property
def sizes(self) -> dict[str, TensorVariable]:
return dict(zip(self.dims, self.shape))
@property
def as_numpy(self):
# No-op, since the underlying data is always a numpy array
return self
# ndarray attributes
# https://docs.xarray.dev/en/latest/api.html#ndarray-attributes
@property
def ndim(self) -> int:
return self.type.ndim
@property
def shape(self) -> tuple[TensorVariable, ...]:
return tuple(px.basic.tensor_from_xtensor(self).shape) # type: ignore
@property
def size(self) -> TensorVariable:
return typing.cast(TensorVariable, variadic_mul(*self.shape))
@property
def dtype(self):
return self.type.dtype
@property
def broadcastable(self):
# The concept of broadcastable is not revelant for XTensorVariables, but part of the codebase may request it
return self.type.broadcastable
# DataArray contents
# https://docs.xarray.dev/en/latest/api.html#dataarray-contents
def rename(self, new_name_or_name_dict=None, **names):
if isinstance(new_name_or_name_dict, str):
new_name = new_name_or_name_dict
name_dict = None
else:
new_name = None
name_dict = new_name_or_name_dict
new_out = px.basic.rename(self, name_dict, **names)
new_out.name = new_name
return new_out
def item(self):
raise NotImplementedError("item not implemented for XTensorVariable")
# Indexing
# https://docs.xarray.dev/en/latest/api.html#id2
def __setitem__(self, key, value):
raise TypeError("XTensorVariable does not support item assignment.")
@property
def loc(self):
raise NotImplementedError("loc not implemented for XTensorVariable")
def sel(self, *args, **kwargs):
raise NotImplementedError("sel not implemented for XTensorVariable")
def __getitem__(self, idx):
raise NotImplementedError("Indexing not yet implemnented")
class XTensorConstantSignature(TensorConstantSignature):
pass
class XTensorConstant(XTensorVariable, Constant[_XTensorTypeType]):
def __init__(self, type: _XTensorTypeType, data, name=None):
data_shape = np.shape(data)
if len(data_shape) != type.ndim or any(
ds != ts for ds, ts in zip(np.shape(data), type.shape) if ts is not None
):
raise ValueError(
f"Shape of data ({data_shape}) does not match shape of type ({type.shape})"
)
# We want all the shape information from `data`
if any(s is None for s in type.shape):
type = type.clone(shape=data_shape)
Constant.__init__(self, type, data, name)
def signature(self):
return XTensorConstantSignature((self.type, self.data))
XTensorType.variable_type = XTensorVariable # type: ignore
XTensorType.constant_type = XTensorConstant # type: ignore
def xtensor_constant(x, name=None, dims: None | Sequence[str] = None):
x_dims: tuple[str, ...]
if XARRAY_AVAILABLE and isinstance(x, xr.DataArray):
xarray_dims = x.dims
if not all(isinstance(dim, str) for dim in xarray_dims):
raise NotImplementedError(
"DataArray can only be converted to xtensor_constant if all dims are of string type"
)
x_dims = tuple(typing.cast(typing.Iterable[str], xarray_dims))
x_data = x.values
if dims is not None and dims != x_dims:
raise ValueError(
f"xr.DataArray dims {x_dims} don't match requested specified {dims}. "
"Use transpose or rename"
)
else:
x_data = tensor_constant(x).data
if dims is not None:
x_dims = tuple(dims)
else:
if x_data.ndim == 0:
x_dims = ()
else:
raise TypeError(
"Cannot convert TensorLike constant to XTensorConstant without specifying dims."
)
try:
return XTensorConstant(
XTensorType(dtype=x_data.dtype, dims=x_dims, shape=x_data.shape),
x_data,
name=name,
)
except TypeError:
raise TypeError(f"Could not convert {x} to XTensorType")
if XARRAY_AVAILABLE:
@_as_symbolic.register(xr.DataArray)
def as_symbolic_xarray(x, **kwargs):
return xtensor_constant(x, **kwargs)
def as_xtensor(x, name=None, dims: Sequence[str] | None = None):
if isinstance(x, Apply):
if len(x.outputs) != 1:
raise ValueError(
"It is ambiguous which output of a multi-output Op has to be fetched.",
x,
)
else:
x = x.outputs[0]
if isinstance(x, Variable):
if isinstance(x.type, XTensorType):
if (dims is None) or (x.type.dims == dims):
return x
else:
raise ValueError(
f"Variable {x} has dims {x.type.dims}, but requested dims are {dims}."
)
if isinstance(x.type, TensorType):
if dims is None:
if x.type.ndim == 0:
dims = ()
else:
raise TypeError(
"non-scalar TensorVariable cannot be converted to XTensorVariable without dims."
)
return px.basic.xtensor_from_tensor(x, dims=dims, name=name)
else:
raise TypeError(
"Variable with type {x.type} cannot be converted to XTensorVariable."
)
try:
return xtensor_constant(x, dims=dims, name=name)
except TypeError as err:
raise TypeError(f"Cannot convert {x} to XTensorType {type(x)}") from err
register_view_op_c_code(
XTensorType,
# XTensorType is just TensorType under the hood
*ViewOp.c_code_and_version[TensorType],
)
register_deep_copy_op_c_code(
XTensorType,
# XTensorType is just TensorType under the hood
*DeepCopyOp.c_code_and_version[TensorType],
)
@_as_tensor_variable.register(XTensorVariable)
def _xtensor_as_tensor_variable(
x: XTensorVariable, *args, allow_xtensor_conversion: bool = False, **kwargs
) -> TensorVariable:
if not allow_xtensor_conversion:
raise TypeError(
"To avoid subtle bugs, PyTensor forbids automatic conversion of XTensorVariable to TensorVariable.\n"
"You can convert explicitly using `x.values` or pass `allow_xtensor_conversion=True`."
)
return as_tensor_variable(x.values, *args, **kwargs)
# ruff: noqa: E402
import pytest
pytest.importorskip("xarray")
import numpy as np
from xarray import DataArray
from pytensor.graph.basic import equal_computations
from pytensor.tensor import as_tensor, specify_shape, tensor
from pytensor.xtensor import xtensor
from pytensor.xtensor.type import XTensorType, as_xtensor
def test_xtensortype():
x1 = XTensorType(dtype="float64", dims=("a", "b"), shape=(2, 3))
x2 = XTensorType(dtype="float64", dims=("a", "b"), shape=(2, 3))
x3 = XTensorType(dtype="float64", dims=("a", "b"), shape=(None, 3))
y1 = XTensorType(dtype="float64", dims=("c", "d"), shape=(4, 5))
z1 = XTensorType(dtype="float32", dims=("a", "b"), shape=(2, 3))
assert x1 == x2 and x1.is_super(x2) and x2.is_super(x1)
assert x1 != x3 and not x1.is_super(x3) and x3.is_super(x1)
assert x1 != y1 and not x1.is_super(y1) and not y1.is_super(x1)
assert x1 != z1 and not x1.is_super(z1) and not z1.is_super(x1)
def test_xtensortype_filter_variable():
x = xtensor("x", dims=("a", "b"), shape=(2, 3))
y1 = xtensor("y1", dims=("a", "b"), shape=(2, 3))
assert x.type.filter_variable(y1) is y1
y2 = xtensor("y2", dims=("b", "a"), shape=(3, 2))
expected_y2 = as_xtensor(y2.values.transpose(), dims=("a", "b"))
assert equal_computations([x.type.filter_variable(y2)], [expected_y2])
y3 = xtensor("y3", dims=("b", "a"), shape=(3, None))
expected_y3 = as_xtensor(
specify_shape(
as_xtensor(y3.values.transpose(), dims=("a", "b")).values, (2, 3)
),
dims=("a", "b"),
)
assert equal_computations([x.type.filter_variable(y3)], [expected_y3])
# Cases that fail
with pytest.raises(TypeError):
y4 = xtensor("y4", dims=("a", "b"), shape=(3, 2))
x.type.filter_variable(y4)
with pytest.raises(TypeError):
y5 = xtensor("y5", dims=("a", "c"), shape=(2, 3))
x.type.filter_variable(y5)
with pytest.raises(TypeError):
y6 = xtensor("y6", dims=("a", "b", "c"), shape=(2, 3, 4))
x.type.filter_variable(y6)
with pytest.raises(TypeError):
y7 = xtensor("y7", dims=("a", "b"), shape=(2, 3), dtype="int32")
x.type.filter_variable(y7)
z1 = tensor("z1", shape=(2, None))
expected_z1 = as_xtensor(specify_shape(z1, (2, 3)), dims=("a", "b"))
assert equal_computations([x.type.filter_variable(z1)], [expected_z1])
# Cases that fail
with pytest.raises(TypeError):
z2 = tensor("z2", shape=(3, 2))
x.type.filter_variable(z2)
with pytest.raises(TypeError):
z3 = tensor("z3", shape=(1, 2, 3))
x.type.filter_variable(z3)
with pytest.raises(TypeError):
z4 = tensor("z4", shape=(2, 3), dtype="int32")
x.type.filter_variable(z4)
def test_xtensor_constant():
x = as_xtensor(DataArray(np.ones((2, 3)), dims=("a", "b")))
assert x.type == XTensorType(dtype="float64", dims=("a", "b"), shape=(2, 3))
y = as_xtensor(np.ones((2, 3)), dims=("a", "b"))
assert y.type == x.type
assert x.signature() == y.signature()
assert x.equals(y)
x_eval = x.eval()
assert isinstance(x.eval(), np.ndarray)
np.testing.assert_array_equal(x_eval, y.eval(), strict=True)
z = as_xtensor(np.ones((3, 2)), dims=("b", "a"))
assert z.type != x.type
assert z.signature() != x.signature()
assert not x.equals(z)
np.testing.assert_array_equal(x_eval, z.eval().T, strict=True)
def test_as_tensor():
x = xtensor("x", dims=("a", "b"), shape=(2, 3))
with pytest.raises(
TypeError,
match="PyTensor forbids automatic conversion of XTensorVariable to TensorVariable",
):
as_tensor(x)
x_pt = as_tensor(x, allow_xtensor_conversion=True)
assert equal_computations([x_pt], [x.values])
def test_minimum_compile():
from pytensor.compile.mode import Mode
x = xtensor("x", dims=("a", "b"), shape=(2, 3))
y = as_xtensor(x.values.transpose(), dims=("b", "a"))
minimum_mode = Mode(linker="py", optimizer="minimum_compile")
result = y.eval({"x": np.ones((2, 3))}, mode=minimum_mode)
np.testing.assert_array_equal(result, np.ones((3, 2)))
# ruff: noqa: E402
import pytest
pytest.importorskip("xarray")
import numpy as np
from xarray import DataArray
from xarray.testing import assert_allclose
from pytensor import function
from pytensor.xtensor.type import XTensorType
def xr_function(*args, **kwargs):
"""Compile and wrap a PyTensor function to return xarray DataArrays."""
fn = function(*args, **kwargs)
symbolic_outputs = fn.maker.fgraph.outputs
assert all(
isinstance(out.type, XTensorType) for out in symbolic_outputs
), "All outputs must be xtensor"
def xfn(*xr_inputs):
np_inputs = [
inp.values if isinstance(inp, DataArray) else inp for inp in xr_inputs
]
np_outputs = fn(*np_inputs)
if not isinstance(np_outputs, tuple | list):
return DataArray(np_outputs, dims=symbolic_outputs[0].type.dims)
else:
return tuple(
DataArray(res, dims=out.type.dims)
for res, out in zip(np_outputs, symbolic_outputs)
)
xfn.fn = fn
return xfn
def xr_assert_allclose(x, y, *args, **kwargs):
# Assert that two xarray DataArrays are close, ignoring coordinates
x = x.drop_vars(x.coords)
y = y.drop_vars(y.coords)
assert_allclose(x, y, *args, **kwargs)
def xr_arange_like(x):
return DataArray(
np.arange(np.prod(x.type.shape), dtype=x.type.dtype).reshape(x.type.shape),
dims=x.type.dims,
)
def xr_random_like(x, rng=None):
if rng is None:
rng = np.random.default_rng()
return DataArray(
rng.standard_normal(size=x.type.shape, dtype=x.type.dtype), dims=x.type.dims
)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论