提交 935ce79a authored 作者: ricardoV94's avatar ricardoV94 提交者: Ricardo Vieira

Tune down TestMeandDtype.test_mean_custom_dtype

上级 3eea7d0e
...@@ -3210,52 +3210,56 @@ class TestMeanDtype: ...@@ -3210,52 +3210,56 @@ class TestMeanDtype:
# TODO FIXME: This is a bad test # TODO FIXME: This is a bad test
f(data) f(data)
@pytest.mark.slow @pytest.mark.parametrize(
def test_mean_custom_dtype(self): "input_dtype",
(
"bool",
"uint16",
"int8",
"int64",
"float16",
"float32",
"float64",
"complex64",
"complex128",
),
)
@pytest.mark.parametrize(
"sum_dtype",
(
"bool",
"uint16",
"int8",
"int64",
"float16",
"float32",
"float64",
"complex64",
"complex128",
),
)
@pytest.mark.parametrize("axis", [None, ()])
def test_mean_custom_dtype(self, input_dtype, sum_dtype, axis):
# Test the ability to provide your own output dtype for a mean. # Test the ability to provide your own output dtype for a mean.
# We try multiple axis combinations even though axis should not matter.
axes = [None, 0, 1, [], [0], [1], [0, 1]]
idx = 0
for input_dtype in map(str, ps.all_types):
x = matrix(dtype=input_dtype) x = matrix(dtype=input_dtype)
for sum_dtype in map(str, ps.all_types): # If the inner sum cannot be created, it will raise a TypeError.
axis = axes[idx % len(axes)]
# If the inner sum cannot be created, it will raise a
# TypeError.
try:
mean_var = x.mean(dtype=sum_dtype, axis=axis) mean_var = x.mean(dtype=sum_dtype, axis=axis)
except TypeError:
pass
else:
# Executed if no TypeError was raised
if sum_dtype in discrete_dtypes: if sum_dtype in discrete_dtypes:
assert mean_var.dtype == "float64", (mean_var.dtype, sum_dtype) assert mean_var.dtype == "float64", (mean_var.dtype, sum_dtype)
else: else:
assert mean_var.dtype == sum_dtype, (mean_var.dtype, sum_dtype) assert mean_var.dtype == sum_dtype, (mean_var.dtype, sum_dtype)
if (
"complex" in input_dtype or "complex" in sum_dtype f = function([x], mean_var, mode="FAST_COMPILE")
) and input_dtype != sum_dtype: data = np.ones((2, 1)).astype(input_dtype)
continue if axis != ():
f = function([x], mean_var) expected_res = np.array(2).astype(sum_dtype) / 2
data = np.random.random((3, 4)) * 10
data = data.astype(input_dtype)
# TODO FIXME: This is a bad test
f(data)
# Check that we can take the gradient, when implemented
if "complex" in mean_var.dtype:
continue
try:
grad(mean_var.sum(), x, disconnected_inputs="ignore")
except NotImplementedError:
# TrueDiv does not seem to have a gradient when
# the numerator is complex.
if mean_var.dtype in complex_dtypes:
pass
else: else:
raise expected_res = data
np.testing.assert_allclose(f(data), expected_res)
idx += 1 if "complex" not in mean_var.dtype:
grad(mean_var.sum(), x, disconnected_inputs="ignore")
def test_mean_precision(self): def test_mean_precision(self):
# Check that the default accumulator precision is sufficient # Check that the default accumulator precision is sufficient
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论