提交 935ce79a authored 作者: ricardoV94's avatar ricardoV94 提交者: Ricardo Vieira

Tune down TestMeandDtype.test_mean_custom_dtype

上级 3eea7d0e
...@@ -3210,52 +3210,56 @@ class TestMeanDtype: ...@@ -3210,52 +3210,56 @@ class TestMeanDtype:
# TODO FIXME: This is a bad test # TODO FIXME: This is a bad test
f(data) f(data)
@pytest.mark.slow @pytest.mark.parametrize(
def test_mean_custom_dtype(self): "input_dtype",
(
"bool",
"uint16",
"int8",
"int64",
"float16",
"float32",
"float64",
"complex64",
"complex128",
),
)
@pytest.mark.parametrize(
"sum_dtype",
(
"bool",
"uint16",
"int8",
"int64",
"float16",
"float32",
"float64",
"complex64",
"complex128",
),
)
@pytest.mark.parametrize("axis", [None, ()])
def test_mean_custom_dtype(self, input_dtype, sum_dtype, axis):
# Test the ability to provide your own output dtype for a mean. # Test the ability to provide your own output dtype for a mean.
# We try multiple axis combinations even though axis should not matter. x = matrix(dtype=input_dtype)
axes = [None, 0, 1, [], [0], [1], [0, 1]] # If the inner sum cannot be created, it will raise a TypeError.
idx = 0 mean_var = x.mean(dtype=sum_dtype, axis=axis)
for input_dtype in map(str, ps.all_types): if sum_dtype in discrete_dtypes:
x = matrix(dtype=input_dtype) assert mean_var.dtype == "float64", (mean_var.dtype, sum_dtype)
for sum_dtype in map(str, ps.all_types): else:
axis = axes[idx % len(axes)] assert mean_var.dtype == sum_dtype, (mean_var.dtype, sum_dtype)
# If the inner sum cannot be created, it will raise a
# TypeError.
try:
mean_var = x.mean(dtype=sum_dtype, axis=axis)
except TypeError:
pass
else:
# Executed if no TypeError was raised
if sum_dtype in discrete_dtypes:
assert mean_var.dtype == "float64", (mean_var.dtype, sum_dtype)
else:
assert mean_var.dtype == sum_dtype, (mean_var.dtype, sum_dtype)
if (
"complex" in input_dtype or "complex" in sum_dtype
) and input_dtype != sum_dtype:
continue
f = function([x], mean_var)
data = np.random.random((3, 4)) * 10
data = data.astype(input_dtype)
# TODO FIXME: This is a bad test
f(data)
# Check that we can take the gradient, when implemented
if "complex" in mean_var.dtype:
continue
try:
grad(mean_var.sum(), x, disconnected_inputs="ignore")
except NotImplementedError:
# TrueDiv does not seem to have a gradient when
# the numerator is complex.
if mean_var.dtype in complex_dtypes:
pass
else:
raise
idx += 1 f = function([x], mean_var, mode="FAST_COMPILE")
data = np.ones((2, 1)).astype(input_dtype)
if axis != ():
expected_res = np.array(2).astype(sum_dtype) / 2
else:
expected_res = data
np.testing.assert_allclose(f(data), expected_res)
if "complex" not in mean_var.dtype:
grad(mean_var.sum(), x, disconnected_inputs="ignore")
def test_mean_precision(self): def test_mean_precision(self):
# Check that the default accumulator precision is sufficient # Check that the default accumulator precision is sufficient
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论