提交 53a2b374 authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Check that TypeError is raised when sum downcasts

Update existing test to check that a TypeError is raised if an inappropriate output dtype is specified in sum().
上级 cd118af7
......@@ -499,7 +499,8 @@ class test_IsInf_IsNan(unittest.TestCase):
return self.run_isfunc('isnan')
def test_sum_default_dtype():
class T_sum_dtype(unittest.TestCase):
def test_sum_default_dtype(self):
"""
Test the default dtype of a sum().
"""
......@@ -517,8 +518,7 @@ def test_sum_default_dtype():
uint32='uint64',
).get(dtype, dtype)
def test_sum_custom_dtype():
def test_sum_custom_dtype(self):
"""
Test the ability to provide your own output dtype for a sum.
"""
......@@ -529,7 +529,19 @@ def test_sum_custom_dtype():
x = tensor.matrix(dtype=input_dtype)
for output_dtype in imap(str, theano.scalar.all_types):
axis = axes[idx % len(axes)]
assert x.sum(dtype=output_dtype, axis=axis).dtype == output_dtype
# If output_dtype would force a downcast, we expect a TypeError
# We always allow int/uint inputs with float/complex outputs.
upcasted_dtype = scalar.upcast(input_dtype, output_dtype)
if (output_dtype == upcasted_dtype or
(input_dtype in discrete_dtypes and
output_dtype in continuous_dtypes)
):
sum_var = x.sum(dtype=output_dtype, axis=axis)
assert sum_var.dtype == output_dtype
else:
self.assertRaises(TypeError,
x.sum, dtype=output_dtype, axis=axis)
idx += 1
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论