提交 46f89676 authored 作者: ricardoV94's avatar ricardoV94 提交者: Ricardo Vieira

Fix casting of boolean/unsigned integers to complex in C-backend

上级 69575d26
...@@ -511,6 +511,10 @@ class ScalarType(CType, HasDataType, HasShape): ...@@ -511,6 +511,10 @@ class ScalarType(CType, HasDataType, HasShape):
"npy_int16", "npy_int16",
"npy_int32", "npy_int32",
"npy_int64", "npy_int64",
"npy_uint8", # also covers npy_bool
"npy_uint16",
"npy_uint32",
"npy_uint64",
"npy_float32", "npy_float32",
"npy_float64", "npy_float64",
] ]
......
...@@ -3,7 +3,7 @@ import pytest ...@@ -3,7 +3,7 @@ import pytest
import pytensor import pytensor
import pytensor.tensor as pt import pytensor.tensor as pt
from pytensor.compile.mode import Mode from pytensor.compile.mode import Mode, get_default_mode
from pytensor.graph.fg import FunctionGraph from pytensor.graph.fg import FunctionGraph
from pytensor.link.c.basic import DualLinker from pytensor.link.c.basic import DualLinker
from pytensor.link.numba import NumbaLinker from pytensor.link.numba import NumbaLinker
...@@ -15,6 +15,7 @@ from pytensor.scalar.basic import ( ...@@ -15,6 +15,7 @@ from pytensor.scalar.basic import (
ScalarType, ScalarType,
TrueDiv, TrueDiv,
add, add,
all_types,
and_, and_,
arccos, arccos,
arccosh, arccosh,
...@@ -55,6 +56,7 @@ from pytensor.scalar.basic import ( ...@@ -55,6 +56,7 @@ from pytensor.scalar.basic import (
tanh, tanh,
true_div, true_div,
) )
from pytensor.tensor import tensor_from_scalar
from pytensor.tensor.type import fscalar, imatrix, matrix from pytensor.tensor.type import fscalar, imatrix, matrix
from tests.link.test_link import make_function from tests.link.test_link import make_function
...@@ -515,3 +517,16 @@ def test_rfloordiv(): ...@@ -515,3 +517,16 @@ def test_rfloordiv():
assert isinstance(y.owner.op, IntDiv) assert isinstance(y.owner.op, IntDiv)
assert isinstance(y.type, ScalarType) assert isinstance(y.type, ScalarType)
assert y.eval({x: 2.0}) == 2.0 assert y.eval({x: 2.0}) == 2.0
@pytest.mark.parametrize("inp_type", all_types, ids=lambda x: x.dtype)
def test_cast_to_complex(inp_type):
if inp_type.dtype == "float16":
if isinstance(get_default_mode().linker, NumbaLinker):
pytest.skip("Numba doesn't support float16")
x = inp_type("x")
# Output as tensor to sidestep numba issue with numpy scalar outputs
y = tensor_from_scalar(x.astype("complex64"))
res_y = y.eval({x: np.array(1.0, dtype=inp_type.dtype)})
assert res_y == 1
assert res_y.dtype == "complex64"
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论