提交 1005832b authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Add tests for arg "dtype" of tensor.mean.

上级 d7cf98bf
...@@ -544,6 +544,45 @@ class T_sum_dtype(unittest.TestCase): ...@@ -544,6 +544,45 @@ class T_sum_dtype(unittest.TestCase):
idx += 1 idx += 1
class T_mean_dtype(unittest.TestCase):
def test_mean_default_dtype(self):
"""
Test the default dtype of a mean().
"""
# We try multiple axis combinations even though axis should not matter.
axes = [None, 0, 1, [0], [1], [0, 1]]
for idx, dtype in enumerate(imap(str, theano.scalar.all_types)):
axis = axes[idx % len(axes)]
x = tensor.matrix(dtype=dtype).mean(axis=axis)
if dtype in discrete_dtypes:
assert x.dtype == 'float64'
else:
assert x.dtype == dtype, (x, x.dtype, dtype)
def test_mean_custom_dtype(self):
"""
Test the ability to provide your own output dtype for a mean.
"""
# We try multiple axis combinations even though axis should not matter.
axes = [None, 0, 1, [0], [1], [0, 1]]
idx = 0
for input_dtype in imap(str, theano.scalar.all_types):
x = tensor.matrix(dtype=input_dtype)
for sum_dtype in imap(str, theano.scalar.all_types):
axis = axes[idx % len(axes)]
# If the inner sum cannot be created, it will raise a TypeError.
try:
mean_var = x.mean(dtype=sum_dtype, axis=axis)
except TypeError:
pass
else:
# Executed if no TypeError was raised
if sum_dtype in discrete_dtypes:
assert mean_var.dtype == 'float64', (mean_var.dtype, sum_dtype)
else:
assert mean_var.dtype == sum_dtype, (mean_var.dtype, output_dtype)
idx += 1
if __name__ == '__main__': if __name__ == '__main__':
#unittest.main() #unittest.main()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论