提交 be2ab8dd authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

XFAIL/SKIP float16 tests

上级 b9468e04
......@@ -18,6 +18,7 @@ from pytensor.graph.fg import FunctionGraph
from pytensor.graph.rewriting.basic import check_stack_trace, out2in
from pytensor.graph.rewriting.db import RewriteDatabaseQuery
from pytensor.graph.rewriting.utils import rewrite_graph
from pytensor.link.numba import NumbaLinker
from pytensor.printing import debugprint, pprint
from pytensor.raise_op import Assert, CheckAndRaise
from pytensor.scalar import Composite, float64
......@@ -1206,6 +1207,10 @@ class TestLocalOptAlloc:
f(5)
@pytest.mark.xfail(
condition=isinstance(get_default_mode().linker, NumbaLinker),
reason="Numba does not support float16",
)
class TestLocalOptAllocF16(TestLocalOptAlloc):
dtype = "float16"
......
......@@ -24,6 +24,7 @@ from pytensor.graph.fg import FunctionGraph
from pytensor.graph.replace import vectorize_node
from pytensor.graph.traversal import ancestors, applys_between
from pytensor.link.c.basic import DualLinker
from pytensor.link.numba import NumbaLinker
from pytensor.printing import pprint
from pytensor.raise_op import Assert
from pytensor.tensor import blas, blas_c
......@@ -858,6 +859,10 @@ class TestMaxAndArgmax:
([1, 0], None),
],
)
@pytest.mark.xfail(
condition=isinstance(get_default_mode().linker, NumbaLinker),
reason="Numba does not support float16",
)
def test_basic_2_float16(self, axis, np_axis):
# Test negative values and bigger range to make sure numpy don't do the argmax as on uint16
data = (random(20, 30).astype("float16") - 0.5) * 20
......@@ -1114,6 +1119,10 @@ class TestArgminArgmax:
v_shape = eval_outputs(fct(n, axis).shape)
assert tuple(v_shape) == nfct(data, np_axis).shape
@pytest.mark.xfail(
condition=isinstance(get_default_mode().linker, NumbaLinker),
reason="Numba does not support float16",
)
def test2_float16(self):
# Test negative values and bigger range to make sure numpy don't do the argmax as on uint16
data = (random(20, 30).astype("float16") - 0.5) * 20
......@@ -1981,6 +1990,10 @@ class TestMean:
res = mean(np.zeros(1))
assert res.eval() == 0.0
@pytest.mark.xfail(
condition=isinstance(get_default_mode().linker, NumbaLinker),
reason="Numba does not support float16",
)
def test_mean_f16(self):
x = vector(dtype="float16")
y = x.mean()
......@@ -3153,7 +3166,9 @@ class TestSumProdReduceDtype:
op = CAReduce
axes = [None, 0, 1, [], [0], [1], [0, 1]]
methods = ["sum", "prod"]
dtypes = list(map(str, ps.all_types))
dtypes = tuple(map(str, ps.all_types))
if isinstance(mode.linker, NumbaLinker):
dtypes = tuple(d for d in dtypes if d != "float16")
# Test the default dtype of a method().
def test_reduce_default_dtype(self):
......@@ -3313,10 +3328,13 @@ class TestSumProdReduceDtype:
class TestMeanDtype:
def test_mean_default_dtype(self):
# Test the default dtype of a mean().
is_numba = isinstance(get_default_mode().linker, NumbaLinker)
# We try multiple axis combinations even though axis should not matter.
axes = [None, 0, 1, [], [0], [1], [0, 1]]
for idx, dtype in enumerate(map(str, ps.all_types)):
if is_numba and dtype == "float16":
continue
axis = axes[idx % len(axes)]
x = matrix(dtype=dtype)
m = x.mean(axis=axis)
......@@ -3411,10 +3429,13 @@ class TestProdWithoutZerosDtype:
def test_prod_without_zeros_default_acc_dtype(self):
# Test the default dtype of a ProdWithoutZeros().
is_numba = isinstance(get_default_mode().linker, NumbaLinker)
# We try multiple axis combinations even though axis should not matter.
axes = [None, 0, 1, [], [0], [1], [0, 1]]
for idx, dtype in enumerate(map(str, ps.all_types)):
if is_numba and dtype == "float16":
continue
axis = axes[idx % len(axes)]
x = matrix(dtype=dtype)
p = ProdWithoutZeros(axis=axis)(x)
......@@ -3442,13 +3463,17 @@ class TestProdWithoutZerosDtype:
@pytest.mark.slow
def test_prod_without_zeros_custom_dtype(self):
# Test ability to provide your own output dtype for a ProdWithoutZeros().
is_numba = isinstance(get_default_mode().linker, NumbaLinker)
# We try multiple axis combinations even though axis should not matter.
axes = [None, 0, 1, [], [0], [1], [0, 1]]
idx = 0
for input_dtype in map(str, ps.all_types):
if is_numba and input_dtype == "float16":
continue
x = matrix(dtype=input_dtype)
for output_dtype in map(str, ps.all_types):
if is_numba and output_dtype == "float16":
continue
axis = axes[idx % len(axes)]
prod_woz_var = ProdWithoutZeros(axis=axis, dtype=output_dtype)(x)
assert prod_woz_var.dtype == output_dtype
......@@ -3464,13 +3489,18 @@ class TestProdWithoutZerosDtype:
@pytest.mark.slow
def test_prod_without_zeros_custom_acc_dtype(self):
# Test ability to provide your own acc_dtype for a ProdWithoutZeros().
is_numba = isinstance(get_default_mode().linker, NumbaLinker)
# We try multiple axis combinations even though axis should not matter.
axes = [None, 0, 1, [], [0], [1], [0, 1]]
idx = 0
for input_dtype in map(str, ps.all_types):
if is_numba and input_dtype == "float16":
continue
x = matrix(dtype=input_dtype)
for acc_dtype in map(str, ps.all_types):
if is_numba and acc_dtype == "float16":
continue
axis = axes[idx % len(axes)]
# If acc_dtype would force a downcast, we expect a TypeError
# We always allow int/uint inputs with float/complex outputs.
......@@ -3746,7 +3776,20 @@ class TestMatMul:
with pytest.raises(ValueError, match="cannot be scalar"):
self.op(4, [4, 1])
@pytest.mark.parametrize("dtype", (np.float16, np.float32, np.float64))
@pytest.mark.parametrize(
"dtype",
(
pytest.param(
np.float16,
marks=pytest.mark.xfail(
condition=isinstance(get_default_mode().linker, NumbaLinker),
reason="Numba does not support float16",
),
),
np.float32,
np.float64,
),
)
def test_dtype_param(self, dtype):
sol = self.op([1, 2, 3], [3, 2, 1], dtype=dtype)
assert sol.eval().dtype == dtype
......
......@@ -10,8 +10,10 @@ from scipy import linalg as scipy_linalg
from pytensor import function, grad
from pytensor import tensor as pt
from pytensor.compile import get_default_mode
from pytensor.configdefaults import config
from pytensor.graph.basic import equal_computations
from pytensor.link.numba import NumbaLinker
from pytensor.tensor import TensorVariable
from pytensor.tensor.slinalg import (
Cholesky,
......@@ -606,6 +608,8 @@ class TestCholeskySolve(utt.InferShapeTester):
)
def test_solve_dtype(self):
is_numba = isinstance(get_default_mode().linker, NumbaLinker)
dtypes = [
"uint8",
"uint16",
......@@ -626,6 +630,9 @@ class TestCholeskySolve(utt.InferShapeTester):
# try all dtype combinations
for A_dtype, b_dtype in itertools.product(dtypes, dtypes):
if is_numba and (A_dtype == "float16" or b_dtype == "float16"):
# Numba does not support float16
continue
A = matrix(dtype=A_dtype)
b = matrix(dtype=b_dtype)
x = op(A, b)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论