提交 53a2b374 authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Check that TypeError is raised when sum downcasts

Update existing test to check that a TypeError is raised if an inappropriate output dtype is specified in sum().
上级 cd118af7
...@@ -499,7 +499,8 @@ class test_IsInf_IsNan(unittest.TestCase): ...@@ -499,7 +499,8 @@ class test_IsInf_IsNan(unittest.TestCase):
return self.run_isfunc('isnan') return self.run_isfunc('isnan')
def test_sum_default_dtype(): class T_sum_dtype(unittest.TestCase):
def test_sum_default_dtype(self):
""" """
Test the default dtype of a sum(). Test the default dtype of a sum().
""" """
...@@ -517,8 +518,7 @@ def test_sum_default_dtype(): ...@@ -517,8 +518,7 @@ def test_sum_default_dtype():
uint32='uint64', uint32='uint64',
).get(dtype, dtype) ).get(dtype, dtype)
def test_sum_custom_dtype(self):
def test_sum_custom_dtype():
""" """
Test the ability to provide your own output dtype for a sum. Test the ability to provide your own output dtype for a sum.
""" """
...@@ -529,7 +529,19 @@ def test_sum_custom_dtype(): ...@@ -529,7 +529,19 @@ def test_sum_custom_dtype():
x = tensor.matrix(dtype=input_dtype) x = tensor.matrix(dtype=input_dtype)
for output_dtype in imap(str, theano.scalar.all_types): for output_dtype in imap(str, theano.scalar.all_types):
axis = axes[idx % len(axes)] axis = axes[idx % len(axes)]
assert x.sum(dtype=output_dtype, axis=axis).dtype == output_dtype # If output_dtype would force a downcast, we expect a TypeError
# We always allow int/uint inputs with float/complex outputs.
upcasted_dtype = scalar.upcast(input_dtype, output_dtype)
if (output_dtype == upcasted_dtype or
(input_dtype in discrete_dtypes and
output_dtype in continuous_dtypes)
):
sum_var = x.sum(dtype=output_dtype, axis=axis)
assert sum_var.dtype == output_dtype
else:
self.assertRaises(TypeError,
x.sum, dtype=output_dtype, axis=axis)
idx += 1 idx += 1
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论