提交 9c55d4aa authored 作者: Pascal Lamblin's avatar Pascal Lamblin

More tests, following code review.

上级 c137ab9c
......@@ -13,6 +13,7 @@ from theano.compile.mode import get_default_mode
from theano.tensor.elemwise import *
from theano.tests import unittest_tools
complex_dtypes = map(str, scalar.complex_types)
def Env(i, o):
e = gof.Env(i, o)
......@@ -538,6 +539,10 @@ class T_sum_dtype(unittest.TestCase):
):
sum_var = x.sum(dtype=output_dtype, axis=axis)
assert sum_var.dtype == output_dtype
# Check that we can take the gradient
grad_var = tensor.grad(sum_var.sum(), x,
disconnected_inputs='ignore')
else:
self.assertRaises(TypeError,
x.sum, dtype=output_dtype, axis=axis)
......@@ -582,6 +587,18 @@ class T_mean_dtype(unittest.TestCase):
else:
assert mean_var.dtype == sum_dtype, (mean_var.dtype, output_dtype)
# Check that we can take the gradient, when implemented
try:
grad_var = tensor.grad(mean_var.sum(), x,
disconnected_inputs='ignore')
except NotImplementedError:
# TrueDiv does not seem to have a gradient when
# the numerator is complex.
if mean_var.dtype in complex_dtypes:
pass
else:
raise
idx += 1
class T_prod_dtype(unittest.TestCase):
......@@ -623,12 +640,63 @@ class T_prod_dtype(unittest.TestCase):
):
prod_var = x.prod(dtype=output_dtype, axis=axis)
assert prod_var.dtype == output_dtype
# Check that we can take the gradient
grad_var = tensor.grad(prod_var.sum(), x,
disconnected_inputs='ignore')
else:
self.assertRaises(TypeError,
x.prod, dtype=output_dtype, axis=axis)
idx += 1
class T_prod_without_zeros_dtype(unittest.TestCase):
def test_prod_without_zeros_default_dtype(self):
"""
Test the default dtype of a ProdWithoutZeros().
"""
# We try multiple axis combinations even though axis should not matter.
axes = [None, 0, 1, [0], [1], [0, 1]]
for idx, dtype in enumerate(imap(str, theano.scalar.all_types)):
axis = axes[idx % len(axes)]
x = ProdWithoutZeros(axis=axis)(tensor.matrix(dtype=dtype))
assert x.dtype == dict(
int8='int64',
int16='int64',
int32='int64',
uint8='uint64',
uint16='uint64',
uint32='uint64',
).get(dtype, dtype)
def test_prod_without_zeros_custom_dtype(self):
"""
Test the ability to provide your own output dtype for a ProdWithoutZeros().
"""
# We try multiple axis combinations even though axis should not matter.
axes = [None, 0, 1, [0], [1], [0, 1]]
idx = 0
for input_dtype in imap(str, theano.scalar.all_types):
x = tensor.matrix(dtype=input_dtype)
for output_dtype in imap(str, theano.scalar.all_types):
axis = axes[idx % len(axes)]
# If output_dtype would force a downcast, we expect a TypeError
# We always allow int/uint inputs with float/complex outputs.
upcasted_dtype = scalar.upcast(input_dtype, output_dtype)
if (output_dtype == upcasted_dtype or
(input_dtype in discrete_dtypes and
output_dtype in continuous_dtypes)
):
prod_woz_var = ProdWithoutZeros(
axis=axis, dtype=output_dtype)(x)
assert prod_woz_var.dtype == output_dtype
else:
self.assertRaises(TypeError,
ProdWithoutZeros(axis=axis, dtype=output_dtype),
x)
idx += 1
if __name__ == '__main__':
#unittest.main()
suite = unittest.TestSuite([test_Prod('test_mul_without_zeros_zeros')])
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论