提交 46f89676 authored 作者: ricardoV94's avatar ricardoV94 提交者: Ricardo Vieira

Fix casting of boolean/unsigned integers to complex in C-backend

上级 69575d26
......@@ -511,6 +511,10 @@ class ScalarType(CType, HasDataType, HasShape):
"npy_int16",
"npy_int32",
"npy_int64",
"npy_uint8", # also covers npy_bool
"npy_uint16",
"npy_uint32",
"npy_uint64",
"npy_float32",
"npy_float64",
]
......
......@@ -3,7 +3,7 @@ import pytest
import pytensor
import pytensor.tensor as pt
from pytensor.compile.mode import Mode
from pytensor.compile.mode import Mode, get_default_mode
from pytensor.graph.fg import FunctionGraph
from pytensor.link.c.basic import DualLinker
from pytensor.link.numba import NumbaLinker
......@@ -15,6 +15,7 @@ from pytensor.scalar.basic import (
ScalarType,
TrueDiv,
add,
all_types,
and_,
arccos,
arccosh,
......@@ -55,6 +56,7 @@ from pytensor.scalar.basic import (
tanh,
true_div,
)
from pytensor.tensor import tensor_from_scalar
from pytensor.tensor.type import fscalar, imatrix, matrix
from tests.link.test_link import make_function
......@@ -515,3 +517,16 @@ def test_rfloordiv():
assert isinstance(y.owner.op, IntDiv)
assert isinstance(y.type, ScalarType)
assert y.eval({x: 2.0}) == 2.0
@pytest.mark.parametrize("inp_type", all_types, ids=lambda x: x.dtype)
def test_cast_to_complex(inp_type):
if inp_type.dtype == "float16":
if isinstance(get_default_mode().linker, NumbaLinker):
pytest.skip("Numba doesn't support float16")
x = inp_type("x")
# Output as tensor to sidestep numba issue with numpy scalar outputs
y = tensor_from_scalar(x.astype("complex64"))
res_y = y.eval({x: np.array(1.0, dtype=inp_type.dtype)})
assert res_y == 1
assert res_y.dtype == "complex64"
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论