提交 c137ab9c authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Add test for dtype keyword of Prod.

上级 563e4adf
...@@ -584,6 +584,51 @@ class T_mean_dtype(unittest.TestCase): ...@@ -584,6 +584,51 @@ class T_mean_dtype(unittest.TestCase):
idx += 1 idx += 1
class T_prod_dtype(unittest.TestCase):
def test_prod_default_dtype(self):
"""
Test the default dtype of a prod().
"""
# We try multiple axis combinations even though axis should not matter.
axes = [None, 0, 1, [0], [1], [0, 1]]
for idx, dtype in enumerate(imap(str, theano.scalar.all_types)):
axis = axes[idx % len(axes)]
x = tensor.matrix(dtype=dtype).prod(axis=axis)
assert x.dtype == dict(
int8='int64',
int16='int64',
int32='int64',
uint8='uint64',
uint16='uint64',
uint32='uint64',
).get(dtype, dtype)
def test_prod_custom_dtype(self):
"""
Test the ability to provide your own output dtype for a prod.
"""
# We try multiple axis combinations even though axis should not matter.
axes = [None, 0, 1, [0], [1], [0, 1]]
idx = 0
for input_dtype in imap(str, theano.scalar.all_types):
x = tensor.matrix(dtype=input_dtype)
for output_dtype in imap(str, theano.scalar.all_types):
axis = axes[idx % len(axes)]
# If output_dtype would force a downcast, we expect a TypeError
# We always allow int/uint inputs with float/complex outputs.
upcasted_dtype = scalar.upcast(input_dtype, output_dtype)
if (output_dtype == upcasted_dtype or
(input_dtype in discrete_dtypes and
output_dtype in continuous_dtypes)
):
prod_var = x.prod(dtype=output_dtype, axis=axis)
assert prod_var.dtype == output_dtype
else:
self.assertRaises(TypeError,
x.prod, dtype=output_dtype, axis=axis)
idx += 1
if __name__ == '__main__': if __name__ == '__main__':
#unittest.main() #unittest.main()
suite = unittest.TestSuite([test_Prod('test_mul_without_zeros_zeros')]) suite = unittest.TestSuite([test_Prod('test_mul_without_zeros_zeros')])
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论