提交 cdb026f7 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Implement cast for XTensorVariables

上级 be1330bf
import sys
import numpy as np
import pytensor.scalar as ps
from pytensor import config
from pytensor.scalar import ScalarOp
from pytensor.scalar.basic import _cast_mapping
from pytensor.xtensor.basic import as_xtensor
from pytensor.xtensor.vectorization import XElemwise
......@@ -107,3 +112,25 @@ tri_gamma = _as_xelemwise(ps.tri_gamma)
true_divide = true_div = _as_xelemwise(ps.true_div)
trunc = _as_xelemwise(ps.trunc)
logical_xor = bitwise_xor = xor = _as_xelemwise(ps.xor)
_xelemwise_cast_op: dict[str, XElemwise] = {}
def cast(x, dtype):
if dtype == "floatX":
dtype = config.floatX
else:
dtype = np.dtype(dtype).name
x = as_xtensor(x)
if x.type.dtype == dtype:
return x
if x.type.dtype.startswith("complex") and not dtype.startswith("complex"):
raise TypeError(
"Casting from complex to real is ambiguous: consider"
" real(), imag(), angle() or abs()"
)
if dtype not in _xelemwise_cast_op:
_xelemwise_cast_op[dtype] = XElemwise(scalar_op=_cast_mapping[dtype])
return _xelemwise_cast_op[dtype](x)
......@@ -401,6 +401,9 @@ class XTensorVariable(Variable[_XTensorTypeType, OptionalApplyType]):
out.name = name # type: ignore
return out
def astype(self, dtype):
return px.math.cast(self, dtype)
def item(self):
raise NotImplementedError("item not implemented for XTensorVariable")
......
......@@ -16,7 +16,7 @@ from pytensor.scalar import ScalarOp
from pytensor.xtensor.basic import rename
from pytensor.xtensor.math import add, exp
from pytensor.xtensor.type import xtensor
from tests.xtensor.util import xr_assert_allclose, xr_function
from tests.xtensor.util import xr_arange_like, xr_assert_allclose, xr_function
def test_all_scalar_ops_are_wrapped():
......@@ -132,3 +132,21 @@ def test_multiple_constant():
res = fn(x_test)
expected_res = np.exp(x_test * 2) + 2
np.testing.assert_allclose(res, expected_res)
def test_cast():
x = xtensor("x", shape=(2, 3), dims=("a", "b"), dtype="float32")
yf64 = x.astype("float64")
yi16 = x.astype("int16")
ybool = x.astype("bool")
fn = xr_function([x], [yf64, yi16, ybool])
x_test = xr_arange_like(x)
res_f64, res_i16, res_bool = fn(x_test)
xr_assert_allclose(res_f64, x_test.astype("float64"))
xr_assert_allclose(res_i16, x_test.astype("int16"))
xr_assert_allclose(res_bool, x_test.astype("bool"))
yc64 = x.astype("complex64")
with pytest.raises(TypeError, match="Casting from complex to real is ambiguous"):
yc64.astype("float64")
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论