提交 c647bb23 authored 作者: ricardoV94's avatar ricardoV94 提交者: Ricardo Vieira

Implement shared XTensorVariables

上级 26d81f81
......@@ -4,6 +4,7 @@ from types import EllipsisType
from pytensor.compile import (
DeepCopyOp,
SharedVariable,
ViewOp,
register_deep_copy_op_c_code,
register_view_op_c_code,
......@@ -32,6 +33,7 @@ import numpy as np
import pytensor.xtensor as px
from pytensor import _as_symbolic, config
from pytensor.compile.sharedvalue import shared_constructor
from pytensor.graph import Apply, Constant
from pytensor.graph.basic import OptionalApplyType, Variable
from pytensor.graph.type import HasDataType, HasShape, Type
......@@ -93,6 +95,8 @@ class XTensorType(Type, HasDataType, HasShape):
def filter(self, value, strict=False, allow_downcast=None):
    """Coerce ``value`` into a runtime value acceptable for this type.

    A ``DataArray`` is first aligned to this type's dim order and unwrapped
    to its underlying ndarray; everything else is delegated untouched to
    ``TensorType.filter``, since an ``XTensorType`` is backed by a plain
    tensor at runtime.
    """
    is_dataarray = XARRAY_AVAILABLE and isinstance(value, xr.DataArray)
    if is_dataarray:
        # Reorder dims to match the type before dropping down to numpy.
        value = value.transpose(*self.dims).values
    return TensorType.filter(self, value, strict=strict, allow_downcast=allow_downcast)
......@@ -105,6 +109,8 @@ class XTensorType(Type, HasDataType, HasShape):
if not isinstance(other, Variable):
# The value is not a Variable: we cast it into
# a Constant of the appropriate Type.
if XARRAY_AVAILABLE and isinstance(other, xr.DataArray):
other = other.transpose(*self.dims).values
other = XTensorConstant(type=self, data=other)
if self.is_super(other.type):
......@@ -929,15 +935,15 @@ XTensorType.variable_type = XTensorVariable # type: ignore
XTensorType.constant_type = XTensorConstant # type: ignore
def xtensor_constant(x, name=None, dims: None | Sequence[str] = None):
"""Convert a constant value to an XTensorConstant."""
def _extract_data_and_dims(
x, dims: None | Sequence[str] = None
) -> tuple[np.ndarray, tuple[str, ...]]:
x_dims: tuple[str, ...]
if XARRAY_AVAILABLE and isinstance(x, xr.DataArray):
xarray_dims = x.dims
if not all(isinstance(dim, str) for dim in xarray_dims):
raise NotImplementedError(
"DataArray can only be converted to xtensor_constant if all dims are of string type"
"DataArray can only be converted to xtensor if all dims are of string type"
)
x_dims = tuple(typing.cast(typing.Iterable[str], xarray_dims))
x_data = x.values
......@@ -958,6 +964,13 @@ def xtensor_constant(x, name=None, dims: None | Sequence[str] = None):
raise TypeError(
"Cannot convert TensorLike constant to XTensorConstant without specifying dims."
)
return x_data, x_dims
def xtensor_constant(x, name=None, dims: None | Sequence[str] = None):
"""Convert a constant value to an XTensorConstant."""
x_data, x_dims = _extract_data_and_dims(x, dims)
try:
return XTensorConstant(
XTensorType(dtype=x_data.dtype, dims=x_dims, shape=x_data.shape),
......@@ -968,11 +981,42 @@ def xtensor_constant(x, name=None, dims: None | Sequence[str] = None):
raise TypeError(f"Could not convert {x} to XTensorType")
if XARRAY_AVAILABLE:
class XTensorSharedVariable(SharedVariable, XTensorVariable):
    """Shared variable of XTensorType.

    Combines ``SharedVariable`` storage semantics (``get_value``/``set_value``)
    with the ``XTensorVariable`` interface. Instances are normally created via
    ``xtensor_shared`` rather than constructed directly.
    """
@_as_symbolic.register(xr.DataArray)
def as_symbolic_xarray(x, **kwargs):
return xtensor_constant(x, **kwargs)
def xtensor_shared(
    x,
    *,
    name=None,
    shape=None,
    dims=None,
    strict=False,
    allow_downcast=None,
    borrow=False,
):
    r"""`SharedVariable` constructor for `XTensorType`\s.

    Notes
    -----
    The default is to assume that the `shape` value might be resized in any
    dimension, so the default shape is ``(None,) * len(value.shape)``. The
    optional `shape` argument will override this default.
    """
    data, data_dims = _extract_data_and_dims(x, dims)

    # Honor `borrow`: only reuse the caller's buffer when explicitly requested.
    if not borrow:
        data = data.copy()

    if name is None:
        # Inherit the name from the source object (e.g. a DataArray), if any.
        name = getattr(x, "name", None)

    var_type = XTensorType(dtype=data.dtype, dims=data_dims, shape=shape)
    return XTensorSharedVariable(
        type=var_type,
        value=data,
        strict=strict,
        allow_downcast=allow_downcast,
        name=name,
    )
if XARRAY_AVAILABLE:
    # Route xarray DataArrays through the xtensor constructors so that
    # `as_symbolic(DataArray)` yields an XTensorConstant and
    # `shared(DataArray)` an XTensorSharedVariable.
    _as_symbolic.register(xr.DataArray, xtensor_constant)
    shared_constructor.register(xr.DataArray, xtensor_shared)
def as_xtensor(x, dims: Sequence[str] | None = None, *, name: str | None = None):
......
......@@ -2,6 +2,9 @@ import re
import pytest
from pytensor import as_symbolic, shared
from pytensor.compile import SharedVariable
pytest.importorskip("xarray")
pytestmark = pytest.mark.filterwarnings("error")
......@@ -9,10 +12,17 @@ pytestmark = pytest.mark.filterwarnings("error")
import numpy as np
from xarray import DataArray
from pytensor.graph.basic import equal_computations
from pytensor.graph.basic import Constant, equal_computations
from pytensor.tensor import as_tensor, specify_shape, tensor
from pytensor.xtensor import xtensor
from pytensor.xtensor.type import XTensorConstant, XTensorType, as_xtensor
from pytensor.xtensor.type import (
XTensorConstant,
XTensorSharedVariable,
XTensorType,
as_xtensor,
xtensor_constant,
xtensor_shared,
)
def test_xtensortype():
......@@ -110,23 +120,76 @@ def test_xtensortype_filter_variable_constant():
assert isinstance(res, XTensorConstant) and res.type == x.type
def test_xtensor_constant():
x = as_xtensor(DataArray(np.ones((2, 3)), dims=("a", "b")))
@pytest.mark.parametrize(
    "constant_constructor", (as_symbolic, as_xtensor, xtensor_constant)
)
def test_xtensor_constant(constant_constructor):
    """Each constant constructor converts a DataArray to an XTensorConstant.

    The diff rendering duplicated the pre-change body before the
    ``if constant_constructor is not as_symbolic`` block; this is the
    deduplicated post-change test.
    """
    x = constant_constructor(DataArray(np.ones((2, 3)), dims=("a", "b")))
    assert isinstance(x, Constant)
    assert isinstance(x, XTensorConstant)
    assert x.type == XTensorType(dtype="float64", dims=("a", "b"), shape=(2, 3))

    if constant_constructor is not as_symbolic:
        # We should be able to pass numpy arrays if we pass dims
        y = as_xtensor(np.ones((2, 3)), dims=("a", "b"))
        assert y.type == x.type
        assert x.signature() == y.signature()
        assert x.equals(y)

        x_eval = x.eval()
        assert isinstance(x.eval(), np.ndarray)
        np.testing.assert_array_equal(x_eval, y.eval(), strict=True)

        # Same data under differently-ordered dims is a distinct constant.
        z = as_xtensor(np.ones((3, 2)), dims=("b", "a"))
        assert z.type != x.type
        assert z.signature() != x.signature()
        assert not x.equals(z)
        np.testing.assert_array_equal(x_eval, z.eval().T, strict=True)
@pytest.mark.parametrize("shared_constructor", (shared, xtensor_shared))
def test_xtensor_shared(shared_constructor):
    """Shared XTensorVariables can be built from DataArrays (or arrays + dims)."""
    arr = np.array([[1, 2, 3], [4, 5, 6]], dtype="int64")
    xarr = DataArray(arr, dims=("a", "b"), name="xarr")

    shared_xarr = shared_constructor(xarr)
    assert isinstance(shared_xarr, SharedVariable)
    assert isinstance(shared_xarr, XTensorSharedVariable)
    assert shared_xarr.type == XTensorType(
        dtype="int64", dims=("a", "b"), shape=(None, None)
    )
    # Fix: assert on the constructed variable, not the input DataArray —
    # the shared variable should inherit the DataArray's name by default.
    assert shared_xarr.name == "xarr"

    # Explicit shape/name override the defaults.
    shared_rrax = shared_constructor(xarr, shape=(2, None), name="rrax")
    assert isinstance(shared_rrax, XTensorSharedVariable)
    assert shared_rrax.type == XTensorType(
        dtype="int64", dims=("a", "b"), shape=(2, None)
    )
    assert shared_rrax.name == "rrax"

    if shared_constructor is xtensor_shared:
        # We should be able to pass numpy arrays, if we pass dims
        with pytest.raises(TypeError):
            shared_constructor(arr)

        shared_arr = shared_constructor(arr, dims=("a", "b"))
        assert isinstance(shared_arr, XTensorSharedVariable)
        assert shared_arr.type == shared_xarr.type

    # Test get and set_value
    retrieved_value = shared_xarr.get_value()
    assert isinstance(retrieved_value, np.ndarray)
    np.testing.assert_allclose(retrieved_value, xarr.to_numpy())

    shared_xarr.set_value(xarr[::-1])
    np.testing.assert_allclose(shared_xarr.get_value(), xarr[::-1].to_numpy())

    # Test dims in different order: filter() transposes back to the type's dims.
    shared_xarr.set_value(xarr[::-1].T)
    np.testing.assert_allclose(shared_xarr.get_value(), xarr[::-1].to_numpy())

    # Mismatched dim names are rejected.
    with pytest.raises(ValueError):
        shared_xarr.set_value(xarr.rename(b="c"))

    # Plain ndarrays are accepted as-is.
    shared_xarr.set_value(arr[::-1])
    np.testing.assert_allclose(shared_xarr.get_value(), arr[::-1])
def test_as_tensor():
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论