提交 df63bd24 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Test: Reduce number of parametrizations

2304 -> 288
上级 ae20b6ab
...@@ -9,6 +9,7 @@ import pytensor ...@@ -9,6 +9,7 @@ import pytensor
import pytensor.sparse.math as psm import pytensor.sparse.math as psm
import pytensor.tensor as pt import pytensor.tensor as pt
from pytensor.configdefaults import config from pytensor.configdefaults import config
from pytensor.scalar import upcast
from pytensor.sparse.basic import ( from pytensor.sparse.basic import (
CSR, CSR,
CSMProperties, CSMProperties,
...@@ -367,7 +368,7 @@ class TestStructuredDot: ...@@ -367,7 +368,7 @@ class TestStructuredDot:
) )
for dense_dtype in typenames: for dense_dtype in typenames:
for sparse_dtype in typenames: for sparse_dtype in typenames:
correct_dtype = pytensor.scalar.upcast(sparse_dtype, dense_dtype) correct_dtype = upcast(sparse_dtype, dense_dtype)
a = SparseTensorType("csc", dtype=sparse_dtype)() a = SparseTensorType("csc", dtype=sparse_dtype)()
b = matrix(dtype=dense_dtype) b = matrix(dtype=dense_dtype)
d = structured_dot(a, b) d = structured_dot(a, b)
...@@ -736,11 +737,10 @@ class TestUsmm: ...@@ -736,11 +737,10 @@ class TestUsmm:
@pytest.mark.slow @pytest.mark.slow
@pytest.mark.parametrize("dtype1", ["float32", "float64", "int16", "complex64"]) @pytest.mark.parametrize("dtype1", ["float32", "float64", "int16", "complex64"])
@pytest.mark.parametrize("dtype2", ["float32", "float64", "int16", "complex64"]) @pytest.mark.parametrize("dtype2", ["float32", "float64", "int16", "complex64"])
@pytest.mark.parametrize("dtype3", ["float32", "float64", "int16", "complex64"]) @pytest.mark.parametrize("can_inplace", [False, True])
@pytest.mark.parametrize("dtype4", ["float32", "float64", "int16", "complex64"])
@pytest.mark.parametrize("format1", ["dense", "csc", "csr"]) @pytest.mark.parametrize("format1", ["dense", "csc", "csr"])
@pytest.mark.parametrize("format2", ["dense", "csc", "csr"]) @pytest.mark.parametrize("format2", ["dense", "csc", "csr"])
def test_basic(self, dtype1, dtype2, dtype3, dtype4, format1, format2): def test_basic(self, dtype1, dtype2, can_inplace, format1, format2):
def mat(format, name, dtype): def mat(format, name, dtype):
if format == "dense": if format == "dense":
return matrix(name, dtype=dtype) return matrix(name, dtype=dtype)
...@@ -750,8 +750,13 @@ class TestUsmm: ...@@ -750,8 +750,13 @@ class TestUsmm:
if format1 == "dense" and format2 == "dense": if format1 == "dense" and format2 == "dense":
pytest.skip("Skipping dense-dense case") pytest.skip("Skipping dense-dense case")
dtype3 = upcast(dtype1, dtype2)
dtype4 = dtype3 if can_inplace else "int32"
inplace = can_inplace
x = mat(format1, "x", dtype1) x = mat(format1, "x", dtype1)
y = mat(format2, "y", dtype2) y = mat(format2, "y", dtype2)
a = scalar("a", dtype=dtype3) a = scalar("a", dtype=dtype3)
z = pytensor.shared(np.asarray(self.z, dtype=dtype4).copy()) z = pytensor.shared(np.asarray(self.z, dtype=dtype4).copy())
...@@ -769,9 +774,6 @@ class TestUsmm: ...@@ -769,9 +774,6 @@ class TestUsmm:
f_b_out = f_b(z_data, a_data, x_data, y_data) f_b_out = f_b(z_data, a_data, x_data, y_data)
# Can it work inplace?
inplace = dtype4 == pytensor.scalar.upcast(dtype1, dtype2, dtype3)
# To make it easier to check the toposort # To make it easier to check the toposort
mode = pytensor.compile.mode.get_default_mode().excluding("fusion") mode = pytensor.compile.mode.get_default_mode().excluding("fusion")
...@@ -782,17 +784,7 @@ class TestUsmm: ...@@ -782,17 +784,7 @@ class TestUsmm:
f_a_out = z.get_value(borrow=True) f_a_out = z.get_value(borrow=True)
else: else:
f_a = pytensor.function([a, x, y], z - a * psm.dot(x, y), mode=mode) f_a = pytensor.function([a, x, y], z - a * psm.dot(x, y), mode=mode)
# In DebugMode there is a strange difference with complex
# So we raise a little the threshold a little.
try:
orig_atol = pytensor.tensor.math.float64_atol
orig_rtol = pytensor.tensor.math.float64_rtol
pytensor.tensor.math.float64_atol = 1e-7
pytensor.tensor.math.float64_rtol = 1e-6
f_a_out = f_a(a_data, x_data, y_data) f_a_out = f_a(a_data, x_data, y_data)
finally:
pytensor.tensor.math.float64_atol = orig_atol
pytensor.tensor.math.float64_rtol = orig_rtol
# As we do a dot product of 2 vector of 100 element, # As we do a dot product of 2 vector of 100 element,
# This mean we can have 2*100*eps abs error. # This mean we can have 2*100*eps abs error.
...@@ -804,7 +796,7 @@ class TestUsmm: ...@@ -804,7 +796,7 @@ class TestUsmm:
rtol = None rtol = None
utt.assert_allclose(f_a_out, f_b_out, rtol=rtol, atol=atol) utt.assert_allclose(f_a_out, f_b_out, rtol=rtol, atol=atol)
topo = f_a.maker.fgraph.toposort() topo = f_a.maker.fgraph.toposort()
up = pytensor.scalar.upcast(dtype1, dtype2, dtype3, dtype4) up = upcast(dtype1, dtype2, dtype3, dtype4)
fast_compile = pytensor.config.mode == "FAST_COMPILE" fast_compile = pytensor.config.mode == "FAST_COMPILE"
...@@ -906,9 +898,6 @@ class TestUsmm: ...@@ -906,9 +898,6 @@ class TestUsmm:
f_b_out = f_b(z_data, a_data, x_data, y_data) f_b_out = f_b(z_data, a_data, x_data, y_data)
# Can it work inplace?
# inplace = dtype4 == pytensor.scalar.upcast(dtype1, dtype2, dtype3)
# To make it easier to check the toposort # To make it easier to check the toposort
mode = pytensor.compile.mode.get_default_mode().excluding("fusion") mode = pytensor.compile.mode.get_default_mode().excluding("fusion")
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论