Commit c647bb23 authored by ricardoV94, committed by Ricardo Vieira

Implement shared XTensorVariables

Parent commit: 26d81f81
...@@ -4,6 +4,7 @@ from types import EllipsisType ...@@ -4,6 +4,7 @@ from types import EllipsisType
from pytensor.compile import ( from pytensor.compile import (
DeepCopyOp, DeepCopyOp,
SharedVariable,
ViewOp, ViewOp,
register_deep_copy_op_c_code, register_deep_copy_op_c_code,
register_view_op_c_code, register_view_op_c_code,
...@@ -32,6 +33,7 @@ import numpy as np ...@@ -32,6 +33,7 @@ import numpy as np
import pytensor.xtensor as px import pytensor.xtensor as px
from pytensor import _as_symbolic, config from pytensor import _as_symbolic, config
from pytensor.compile.sharedvalue import shared_constructor
from pytensor.graph import Apply, Constant from pytensor.graph import Apply, Constant
from pytensor.graph.basic import OptionalApplyType, Variable from pytensor.graph.basic import OptionalApplyType, Variable
from pytensor.graph.type import HasDataType, HasShape, Type from pytensor.graph.type import HasDataType, HasShape, Type
...@@ -93,6 +95,8 @@ class XTensorType(Type, HasDataType, HasShape): ...@@ -93,6 +95,8 @@ class XTensorType(Type, HasDataType, HasShape):
def filter(self, value, strict=False, allow_downcast=None): def filter(self, value, strict=False, allow_downcast=None):
# XTensorType behaves like TensorType at runtime, so we filter the same way. # XTensorType behaves like TensorType at runtime, so we filter the same way.
if XARRAY_AVAILABLE and isinstance(value, xr.DataArray):
value = value.transpose(*self.dims).values
return TensorType.filter( return TensorType.filter(
self, value, strict=strict, allow_downcast=allow_downcast self, value, strict=strict, allow_downcast=allow_downcast
) )
...@@ -105,6 +109,8 @@ class XTensorType(Type, HasDataType, HasShape): ...@@ -105,6 +109,8 @@ class XTensorType(Type, HasDataType, HasShape):
if not isinstance(other, Variable): if not isinstance(other, Variable):
# The value is not a Variable: we cast it into # The value is not a Variable: we cast it into
# a Constant of the appropriate Type. # a Constant of the appropriate Type.
if XARRAY_AVAILABLE and isinstance(other, xr.DataArray):
other = other.transpose(*self.dims).values
other = XTensorConstant(type=self, data=other) other = XTensorConstant(type=self, data=other)
if self.is_super(other.type): if self.is_super(other.type):
...@@ -929,15 +935,15 @@ XTensorType.variable_type = XTensorVariable # type: ignore ...@@ -929,15 +935,15 @@ XTensorType.variable_type = XTensorVariable # type: ignore
XTensorType.constant_type = XTensorConstant # type: ignore XTensorType.constant_type = XTensorConstant # type: ignore
def xtensor_constant(x, name=None, dims: None | Sequence[str] = None): def _extract_data_and_dims(
"""Convert a constant value to an XTensorConstant.""" x, dims: None | Sequence[str] = None
) -> tuple[np.ndarray, tuple[str, ...]]:
x_dims: tuple[str, ...] x_dims: tuple[str, ...]
if XARRAY_AVAILABLE and isinstance(x, xr.DataArray): if XARRAY_AVAILABLE and isinstance(x, xr.DataArray):
xarray_dims = x.dims xarray_dims = x.dims
if not all(isinstance(dim, str) for dim in xarray_dims): if not all(isinstance(dim, str) for dim in xarray_dims):
raise NotImplementedError( raise NotImplementedError(
"DataArray can only be converted to xtensor_constant if all dims are of string type" "DataArray can only be converted to xtensor if all dims are of string type"
) )
x_dims = tuple(typing.cast(typing.Iterable[str], xarray_dims)) x_dims = tuple(typing.cast(typing.Iterable[str], xarray_dims))
x_data = x.values x_data = x.values
...@@ -958,6 +964,13 @@ def xtensor_constant(x, name=None, dims: None | Sequence[str] = None): ...@@ -958,6 +964,13 @@ def xtensor_constant(x, name=None, dims: None | Sequence[str] = None):
raise TypeError( raise TypeError(
"Cannot convert TensorLike constant to XTensorConstant without specifying dims." "Cannot convert TensorLike constant to XTensorConstant without specifying dims."
) )
return x_data, x_dims
def xtensor_constant(x, name=None, dims: None | Sequence[str] = None):
"""Convert a constant value to an XTensorConstant."""
x_data, x_dims = _extract_data_and_dims(x, dims)
try: try:
return XTensorConstant( return XTensorConstant(
XTensorType(dtype=x_data.dtype, dims=x_dims, shape=x_data.shape), XTensorType(dtype=x_data.dtype, dims=x_dims, shape=x_data.shape),
...@@ -968,11 +981,42 @@ def xtensor_constant(x, name=None, dims: None | Sequence[str] = None): ...@@ -968,11 +981,42 @@ def xtensor_constant(x, name=None, dims: None | Sequence[str] = None):
raise TypeError(f"Could not convert {x} to XTensorType") raise TypeError(f"Could not convert {x} to XTensorType")
class XTensorSharedVariable(SharedVariable, XTensorVariable):
    """Shared variable of `XTensorType`.

    Combines the storage/update interface of `SharedVariable` with the
    dimension-aware `XTensorVariable` interface, so shared values can be
    used anywhere an xtensor graph input is expected.
    """
def xtensor_shared(
    x,
    *,
    name=None,
    shape=None,
    dims=None,
    strict=False,
    allow_downcast=None,
    borrow=False,
):
    r"""`SharedVariable` constructor for `XTensorType`\s.

    Parameters
    ----------
    x
        Initial value: an ``xarray.DataArray`` (dims are taken from it) or
        any array-like, in which case `dims` must be provided.
    name
        Name of the variable; defaults to ``x.name`` when available.
    shape
        Static shape of the resulting type; see Notes for the default.
    dims
        Dimension names; required when `x` is not a ``DataArray``.
    strict, allow_downcast
        Forwarded to the `SharedVariable` value-filtering machinery.
    borrow
        If False (default), the initial value is copied so the shared
        variable does not alias the caller's buffer.

    Notes
    -----
    The default is to assume that the `shape` value might be resized in any
    dimension, so the default shape is ``(None,) * len(value.shape)``. The
    optional `shape` argument will override this default.
    """
    x_data, x_dims = _extract_data_and_dims(x, dims)
    return XTensorSharedVariable(
        type=XTensorType(dtype=x_data.dtype, dims=x_dims, shape=shape),
        # Copy unless the caller explicitly allows aliasing the input buffer.
        value=x_data if borrow else x_data.copy(),
        strict=strict,
        allow_downcast=allow_downcast,
        # Fall back to the DataArray's own name when none is given.
        name=name if name is not None else getattr(x, "name", None),
    )
if XARRAY_AVAILABLE:
    # Let `pytensor.as_symbolic` and `pytensor.shared` accept
    # xarray.DataArray inputs by dispatching to the xtensor constructors.
    _as_symbolic.register(xr.DataArray, xtensor_constant)
    shared_constructor.register(xr.DataArray, xtensor_shared)
def as_xtensor(x, dims: Sequence[str] | None = None, *, name: str | None = None): def as_xtensor(x, dims: Sequence[str] | None = None, *, name: str | None = None):
......
...@@ -2,6 +2,9 @@ import re ...@@ -2,6 +2,9 @@ import re
import pytest import pytest
from pytensor import as_symbolic, shared
from pytensor.compile import SharedVariable
pytest.importorskip("xarray") pytest.importorskip("xarray")
pytestmark = pytest.mark.filterwarnings("error") pytestmark = pytest.mark.filterwarnings("error")
...@@ -9,10 +12,17 @@ pytestmark = pytest.mark.filterwarnings("error") ...@@ -9,10 +12,17 @@ pytestmark = pytest.mark.filterwarnings("error")
import numpy as np import numpy as np
from xarray import DataArray from xarray import DataArray
from pytensor.graph.basic import equal_computations from pytensor.graph.basic import Constant, equal_computations
from pytensor.tensor import as_tensor, specify_shape, tensor from pytensor.tensor import as_tensor, specify_shape, tensor
from pytensor.xtensor import xtensor from pytensor.xtensor import xtensor
from pytensor.xtensor.type import XTensorConstant, XTensorType, as_xtensor from pytensor.xtensor.type import (
XTensorConstant,
XTensorSharedVariable,
XTensorType,
as_xtensor,
xtensor_constant,
xtensor_shared,
)
def test_xtensortype(): def test_xtensortype():
...@@ -110,10 +120,17 @@ def test_xtensortype_filter_variable_constant(): ...@@ -110,10 +120,17 @@ def test_xtensortype_filter_variable_constant():
assert isinstance(res, XTensorConstant) and res.type == x.type assert isinstance(res, XTensorConstant) and res.type == x.type
def test_xtensor_constant(): @pytest.mark.parametrize(
x = as_xtensor(DataArray(np.ones((2, 3)), dims=("a", "b"))) "constant_constructor", (as_symbolic, as_xtensor, xtensor_constant)
)
def test_xtensor_constant(constant_constructor):
x = constant_constructor(DataArray(np.ones((2, 3)), dims=("a", "b")))
assert isinstance(x, Constant)
assert isinstance(x, XTensorConstant)
assert x.type == XTensorType(dtype="float64", dims=("a", "b"), shape=(2, 3)) assert x.type == XTensorType(dtype="float64", dims=("a", "b"), shape=(2, 3))
if constant_constructor is not as_symbolic:
# We should be able to pass numpy arrays if we pass dims
y = as_xtensor(np.ones((2, 3)), dims=("a", "b")) y = as_xtensor(np.ones((2, 3)), dims=("a", "b"))
assert y.type == x.type assert y.type == x.type
assert x.signature() == y.signature() assert x.signature() == y.signature()
...@@ -129,6 +146,52 @@ def test_xtensor_constant(): ...@@ -129,6 +146,52 @@ def test_xtensor_constant():
np.testing.assert_array_equal(x_eval, z.eval().T, strict=True) np.testing.assert_array_equal(x_eval, z.eval().T, strict=True)
@pytest.mark.parametrize("shared_constructor", (shared, xtensor_shared))
def test_xtensor_shared(shared_constructor):
    """Shared XTensorVariables can be built from DataArrays (or arrays + dims)."""
    arr = np.array([[1, 2, 3], [4, 5, 6]], dtype="int64")
    xarr = DataArray(arr, dims=("a", "b"), name="xarr")

    shared_xarr = shared_constructor(xarr)
    assert isinstance(shared_xarr, SharedVariable)
    assert isinstance(shared_xarr, XTensorSharedVariable)
    assert shared_xarr.type == XTensorType(
        dtype="int64", dims=("a", "b"), shape=(None, None)
    )
    # The shared variable inherits the DataArray's name
    # (original asserted `xarr.name`, which is trivially true by construction).
    assert shared_xarr.name == "xarr"

    shared_rrax = shared_constructor(xarr, shape=(2, None), name="rrax")
    assert isinstance(shared_rrax, XTensorSharedVariable)
    assert shared_rrax.type == XTensorType(
        dtype="int64", dims=("a", "b"), shape=(2, None)
    )
    assert shared_rrax.name == "rrax"

    if shared_constructor == xtensor_shared:
        # We should be able to pass numpy arrays, if we pass dims
        with pytest.raises(TypeError):
            shared_constructor(arr)
        shared_arr = shared_constructor(arr, dims=("a", "b"))
        assert isinstance(shared_arr, XTensorSharedVariable)
        assert shared_arr.type == shared_xarr.type

    # Test get and set_value
    retrieved_value = shared_xarr.get_value()
    assert isinstance(retrieved_value, np.ndarray)
    np.testing.assert_allclose(retrieved_value, xarr.to_numpy())

    shared_xarr.set_value(xarr[::-1])
    np.testing.assert_allclose(shared_xarr.get_value(), xarr[::-1].to_numpy())

    # Test dims in different order: values are transposed back to type order
    shared_xarr.set_value(xarr[::-1].T)
    np.testing.assert_allclose(shared_xarr.get_value(), xarr[::-1].to_numpy())

    # Mismatched dim names are rejected
    with pytest.raises(ValueError):
        shared_xarr.set_value(xarr.rename(b="c"))

    # Raw numpy values are accepted as-is (assumed already in dims order)
    shared_xarr.set_value(arr[::-1])
    np.testing.assert_allclose(shared_xarr.get_value(), arr[::-1])
def test_as_tensor(): def test_as_tensor():
x = xtensor("x", dims=("a", "b"), shape=(2, 3)) x = xtensor("x", dims=("a", "b"), shape=(2, 3))
......
Markdown formatting is supported
0%
You are adding 0 people to this discussion. Please proceed with caution.
Please finish editing this comment first!
Register or sign in to comment