提交 cdb026f7 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Implement cast for XTensorVariables

上级 be1330bf
import sys import sys
import numpy as np
import pytensor.scalar as ps import pytensor.scalar as ps
from pytensor import config
from pytensor.scalar import ScalarOp from pytensor.scalar import ScalarOp
from pytensor.scalar.basic import _cast_mapping
from pytensor.xtensor.basic import as_xtensor
from pytensor.xtensor.vectorization import XElemwise from pytensor.xtensor.vectorization import XElemwise
...@@ -107,3 +112,25 @@ tri_gamma = _as_xelemwise(ps.tri_gamma) ...@@ -107,3 +112,25 @@ tri_gamma = _as_xelemwise(ps.tri_gamma)
true_divide = true_div = _as_xelemwise(ps.true_div) true_divide = true_div = _as_xelemwise(ps.true_div)
trunc = _as_xelemwise(ps.trunc) trunc = _as_xelemwise(ps.trunc)
logical_xor = bitwise_xor = xor = _as_xelemwise(ps.xor) logical_xor = bitwise_xor = xor = _as_xelemwise(ps.xor)
_xelemwise_cast_op: dict[str, XElemwise] = {}
def cast(x, dtype):
if dtype == "floatX":
dtype = config.floatX
else:
dtype = np.dtype(dtype).name
x = as_xtensor(x)
if x.type.dtype == dtype:
return x
if x.type.dtype.startswith("complex") and not dtype.startswith("complex"):
raise TypeError(
"Casting from complex to real is ambiguous: consider"
" real(), imag(), angle() or abs()"
)
if dtype not in _xelemwise_cast_op:
_xelemwise_cast_op[dtype] = XElemwise(scalar_op=_cast_mapping[dtype])
return _xelemwise_cast_op[dtype](x)
...@@ -401,6 +401,9 @@ class XTensorVariable(Variable[_XTensorTypeType, OptionalApplyType]): ...@@ -401,6 +401,9 @@ class XTensorVariable(Variable[_XTensorTypeType, OptionalApplyType]):
out.name = name # type: ignore out.name = name # type: ignore
return out return out
def astype(self, dtype):
return px.math.cast(self, dtype)
def item(self): def item(self):
raise NotImplementedError("item not implemented for XTensorVariable") raise NotImplementedError("item not implemented for XTensorVariable")
......
...@@ -16,7 +16,7 @@ from pytensor.scalar import ScalarOp ...@@ -16,7 +16,7 @@ from pytensor.scalar import ScalarOp
from pytensor.xtensor.basic import rename from pytensor.xtensor.basic import rename
from pytensor.xtensor.math import add, exp from pytensor.xtensor.math import add, exp
from pytensor.xtensor.type import xtensor from pytensor.xtensor.type import xtensor
from tests.xtensor.util import xr_assert_allclose, xr_function from tests.xtensor.util import xr_arange_like, xr_assert_allclose, xr_function
def test_all_scalar_ops_are_wrapped(): def test_all_scalar_ops_are_wrapped():
...@@ -132,3 +132,21 @@ def test_multiple_constant(): ...@@ -132,3 +132,21 @@ def test_multiple_constant():
res = fn(x_test) res = fn(x_test)
expected_res = np.exp(x_test * 2) + 2 expected_res = np.exp(x_test * 2) + 2
np.testing.assert_allclose(res, expected_res) np.testing.assert_allclose(res, expected_res)
def test_cast():
x = xtensor("x", shape=(2, 3), dims=("a", "b"), dtype="float32")
yf64 = x.astype("float64")
yi16 = x.astype("int16")
ybool = x.astype("bool")
fn = xr_function([x], [yf64, yi16, ybool])
x_test = xr_arange_like(x)
res_f64, res_i16, res_bool = fn(x_test)
xr_assert_allclose(res_f64, x_test.astype("float64"))
xr_assert_allclose(res_i16, x_test.astype("int16"))
xr_assert_allclose(res_bool, x_test.astype("bool"))
yc64 = x.astype("complex64")
with pytest.raises(TypeError, match="Casting from complex to real is ambiguous"):
yc64.astype("float64")
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论