提交 df63bd24 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Test: Reduce number of parametrizations

2304 -> 288
上级 ae20b6ab
......@@ -9,6 +9,7 @@ import pytensor
import pytensor.sparse.math as psm
import pytensor.tensor as pt
from pytensor.configdefaults import config
from pytensor.scalar import upcast
from pytensor.sparse.basic import (
CSR,
CSMProperties,
......@@ -367,7 +368,7 @@ class TestStructuredDot:
)
for dense_dtype in typenames:
for sparse_dtype in typenames:
correct_dtype = pytensor.scalar.upcast(sparse_dtype, dense_dtype)
correct_dtype = upcast(sparse_dtype, dense_dtype)
a = SparseTensorType("csc", dtype=sparse_dtype)()
b = matrix(dtype=dense_dtype)
d = structured_dot(a, b)
......@@ -736,11 +737,10 @@ class TestUsmm:
@pytest.mark.slow
@pytest.mark.parametrize("dtype1", ["float32", "float64", "int16", "complex64"])
@pytest.mark.parametrize("dtype2", ["float32", "float64", "int16", "complex64"])
@pytest.mark.parametrize("dtype3", ["float32", "float64", "int16", "complex64"])
@pytest.mark.parametrize("dtype4", ["float32", "float64", "int16", "complex64"])
@pytest.mark.parametrize("can_inplace", [False, True])
@pytest.mark.parametrize("format1", ["dense", "csc", "csr"])
@pytest.mark.parametrize("format2", ["dense", "csc", "csr"])
def test_basic(self, dtype1, dtype2, dtype3, dtype4, format1, format2):
def test_basic(self, dtype1, dtype2, can_inplace, format1, format2):
def mat(format, name, dtype):
if format == "dense":
return matrix(name, dtype=dtype)
......@@ -750,8 +750,13 @@ class TestUsmm:
if format1 == "dense" and format2 == "dense":
pytest.skip("Skipping dense-dense case")
dtype3 = upcast(dtype1, dtype2)
dtype4 = dtype3 if can_inplace else "int32"
inplace = can_inplace
x = mat(format1, "x", dtype1)
y = mat(format2, "y", dtype2)
a = scalar("a", dtype=dtype3)
z = pytensor.shared(np.asarray(self.z, dtype=dtype4).copy())
......@@ -769,9 +774,6 @@ class TestUsmm:
f_b_out = f_b(z_data, a_data, x_data, y_data)
# Can it work inplace?
inplace = dtype4 == pytensor.scalar.upcast(dtype1, dtype2, dtype3)
# To make it easier to check the toposort
mode = pytensor.compile.mode.get_default_mode().excluding("fusion")
......@@ -782,17 +784,7 @@ class TestUsmm:
f_a_out = z.get_value(borrow=True)
else:
f_a = pytensor.function([a, x, y], z - a * psm.dot(x, y), mode=mode)
# In DebugMode there is a strange difference with complex
# So we raise a little the threshold a little.
try:
orig_atol = pytensor.tensor.math.float64_atol
orig_rtol = pytensor.tensor.math.float64_rtol
pytensor.tensor.math.float64_atol = 1e-7
pytensor.tensor.math.float64_rtol = 1e-6
f_a_out = f_a(a_data, x_data, y_data)
finally:
pytensor.tensor.math.float64_atol = orig_atol
pytensor.tensor.math.float64_rtol = orig_rtol
f_a_out = f_a(a_data, x_data, y_data)
# As we do a dot product of 2 vector of 100 element,
# This mean we can have 2*100*eps abs error.
......@@ -804,7 +796,7 @@ class TestUsmm:
rtol = None
utt.assert_allclose(f_a_out, f_b_out, rtol=rtol, atol=atol)
topo = f_a.maker.fgraph.toposort()
up = pytensor.scalar.upcast(dtype1, dtype2, dtype3, dtype4)
up = upcast(dtype1, dtype2, dtype3, dtype4)
fast_compile = pytensor.config.mode == "FAST_COMPILE"
......@@ -906,9 +898,6 @@ class TestUsmm:
f_b_out = f_b(z_data, a_data, x_data, y_data)
# Can it work inplace?
# inplace = dtype4 == pytensor.scalar.upcast(dtype1, dtype2, dtype3)
# To make it easier to check the toposort
mode = pytensor.compile.mode.get_default_mode().excluding("fusion")
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论