提交 7af328c7 authored 作者: Frederic's avatar Frederic

Add many tests related to the reduction on 0 axis and upcast bug.

上级 5e02eb8b
...@@ -747,11 +747,16 @@ class T_mean_dtype(unittest.TestCase): ...@@ -747,11 +747,16 @@ class T_mean_dtype(unittest.TestCase):
axes = [None, 0, 1, [], [0], [1], [0, 1]] axes = [None, 0, 1, [], [0], [1], [0, 1]]
for idx, dtype in enumerate(imap(str, theano.scalar.all_types)): for idx, dtype in enumerate(imap(str, theano.scalar.all_types)):
axis = axes[idx % len(axes)] axis = axes[idx % len(axes)]
x = tensor.matrix(dtype=dtype).mean(axis=axis) x = tensor.matrix(dtype=dtype)
if dtype in tensor.discrete_dtypes: m = x.mean(axis=axis)
assert x.dtype == 'float64' if dtype in tensor.discrete_dtypes and axis != []:
assert m.dtype == 'float64'
else: else:
assert x.dtype == dtype, (x, x.dtype, dtype) assert m.dtype == dtype, (m, m.dtype, dtype)
f = theano.function([x], m)
data = numpy.random.rand(3, 4) * 10
data = data.astype(dtype)
f(data)
def test_mean_custom_dtype(self): def test_mean_custom_dtype(self):
""" """
...@@ -772,12 +777,17 @@ class T_mean_dtype(unittest.TestCase): ...@@ -772,12 +777,17 @@ class T_mean_dtype(unittest.TestCase):
pass pass
else: else:
# Executed if no TypeError was raised # Executed if no TypeError was raised
if sum_dtype in tensor.discrete_dtypes: if sum_dtype in tensor.discrete_dtypes and axis != []:
assert mean_var.dtype == 'float64', ( assert mean_var.dtype == 'float64', (
(mean_var.dtype, sum_dtype)) (mean_var.dtype, sum_dtype))
else: else:
assert mean_var.dtype == sum_dtype, ( assert mean_var.dtype == sum_dtype, (
(mean_var.dtype, sum_dtype)) (mean_var.dtype, sum_dtype))
if ("complex" not in sum_dtype and "complex" not in input_dtype):
f = theano.function([x], mean_var)
data = numpy.random.rand(3, 4) * 10
data = data.astype(input_dtype)
f(data)
# Check that we can take the gradient, when implemented # Check that we can take the gradient, when implemented
if "complex" in mean_var.dtype: if "complex" in mean_var.dtype:
continue continue
...@@ -812,8 +822,9 @@ class T_prod_dtype(unittest.TestCase): ...@@ -812,8 +822,9 @@ class T_prod_dtype(unittest.TestCase):
axes = [None, 0, 1, [], [0], [1], [0, 1]] axes = [None, 0, 1, [], [0], [1], [0, 1]]
for idx, dtype in enumerate(imap(str, theano.scalar.all_types)): for idx, dtype in enumerate(imap(str, theano.scalar.all_types)):
axis = axes[idx % len(axes)] axis = axes[idx % len(axes)]
x = tensor.matrix(dtype=dtype).prod(axis=axis) x = tensor.matrix(dtype=dtype)
assert x.dtype == dict( p = x.prod(axis=axis)
assert p.dtype == dict(
int8='int64', int8='int64',
int16='int64', int16='int64',
int32='int64', int32='int64',
...@@ -821,6 +832,10 @@ class T_prod_dtype(unittest.TestCase): ...@@ -821,6 +832,10 @@ class T_prod_dtype(unittest.TestCase):
uint16='uint64', uint16='uint64',
uint32='uint64', uint32='uint64',
).get(dtype, dtype) ).get(dtype, dtype)
f = theano.function([x], p)
data = numpy.random.rand(3, 4) * 10
data = data.astype(dtype)
f(data)
def test_prod_default_acc_dtype(self): def test_prod_default_acc_dtype(self):
""" """
...@@ -830,8 +845,9 @@ class T_prod_dtype(unittest.TestCase): ...@@ -830,8 +845,9 @@ class T_prod_dtype(unittest.TestCase):
axes = [None, 0, 1, [], [0], [1], [0, 1]] axes = [None, 0, 1, [], [0], [1], [0, 1]]
for idx, dtype in enumerate(imap(str, theano.scalar.all_types)): for idx, dtype in enumerate(imap(str, theano.scalar.all_types)):
axis = axes[idx % len(axes)] axis = axes[idx % len(axes)]
x = tensor.matrix(dtype=dtype).prod(axis=axis) x = tensor.matrix(dtype=dtype)
assert x.owner.op.acc_dtype == dict( p = x.prod(axis=axis)
assert p.owner.op.acc_dtype == dict(
int8='int64', int8='int64',
int16='int64', int16='int64',
int32='int64', int32='int64',
...@@ -841,6 +857,10 @@ class T_prod_dtype(unittest.TestCase): ...@@ -841,6 +857,10 @@ class T_prod_dtype(unittest.TestCase):
float32='float64', float32='float64',
complex64='complex128', complex64='complex128',
).get(dtype, dtype) ).get(dtype, dtype)
f = theano.function([x], p)
data = numpy.random.rand(3, 4) * 10
data = data.astype(dtype)
f(data)
def test_prod_custom_dtype(self): def test_prod_custom_dtype(self):
""" """
...@@ -861,6 +881,10 @@ class T_prod_dtype(unittest.TestCase): ...@@ -861,6 +881,10 @@ class T_prod_dtype(unittest.TestCase):
# Check that we can take the gradient # Check that we can take the gradient
tensor.grad(prod_var.sum(), x, tensor.grad(prod_var.sum(), x,
disconnected_inputs='ignore') disconnected_inputs='ignore')
f = theano.function([x], prod_var)
data = numpy.random.rand(3, 4) * 10
data = data.astype(input_dtype)
f(data)
idx += 1 idx += 1
def test_prod_custom_acc_dtype(self): def test_prod_custom_acc_dtype(self):
...@@ -889,6 +913,10 @@ class T_prod_dtype(unittest.TestCase): ...@@ -889,6 +913,10 @@ class T_prod_dtype(unittest.TestCase):
# Check that we can take the gradient # Check that we can take the gradient
tensor.grad(prod_var.sum(), x, tensor.grad(prod_var.sum(), x,
disconnected_inputs='ignore') disconnected_inputs='ignore')
f = theano.function([x], prod_var)
data = numpy.random.rand(3, 4) * 10
data = data.astype(input_dtype)
f(data)
else: else:
self.assertRaises(TypeError, self.assertRaises(TypeError,
x.prod, acc_dtype=acc_dtype, axis=axis) x.prod, acc_dtype=acc_dtype, axis=axis)
...@@ -923,8 +951,9 @@ class T_prod_without_zeros_dtype(unittest.TestCase): ...@@ -923,8 +951,9 @@ class T_prod_without_zeros_dtype(unittest.TestCase):
axes = [None, 0, 1, [], [0], [1], [0, 1]] axes = [None, 0, 1, [], [0], [1], [0, 1]]
for idx, dtype in enumerate(imap(str, theano.scalar.all_types)): for idx, dtype in enumerate(imap(str, theano.scalar.all_types)):
axis = axes[idx % len(axes)] axis = axes[idx % len(axes)]
x = ProdWithoutZeros(axis=axis)(tensor.matrix(dtype=dtype)) x = tensor.matrix(dtype=dtype)
assert x.owner.op.acc_dtype == dict( p = ProdWithoutZeros(axis=axis)(x)
assert p.owner.op.acc_dtype == dict(
int8='int64', int8='int64',
int16='int64', int16='int64',
int32='int64', int32='int64',
...@@ -935,6 +964,13 @@ class T_prod_without_zeros_dtype(unittest.TestCase): ...@@ -935,6 +964,13 @@ class T_prod_without_zeros_dtype(unittest.TestCase):
complex64='complex128' complex64='complex128'
).get(dtype, dtype) ).get(dtype, dtype)
if 'complex' in dtype:
continue
f = theano.function([x], p)
data = numpy.random.rand(3, 4) * 10
data = data.astype(dtype)
f(data)
def test_prod_without_zeros_custom_dtype(self): def test_prod_without_zeros_custom_dtype(self):
""" """
Test ability to provide your own output dtype for a ProdWithoutZeros(). Test ability to provide your own output dtype for a ProdWithoutZeros().
...@@ -949,6 +985,12 @@ class T_prod_without_zeros_dtype(unittest.TestCase): ...@@ -949,6 +985,12 @@ class T_prod_without_zeros_dtype(unittest.TestCase):
prod_woz_var = ProdWithoutZeros( prod_woz_var = ProdWithoutZeros(
axis=axis, dtype=output_dtype)(x) axis=axis, dtype=output_dtype)(x)
assert prod_woz_var.dtype == output_dtype assert prod_woz_var.dtype == output_dtype
if ('complex' not in input_dtype and
'complex' not in output_dtype):
f = theano.function([x], prod_woz_var)
data = numpy.random.rand(3, 4) * 10
data = data.astype(input_dtype)
f(data)
idx += 1 idx += 1
def test_prod_without_zeros_custom_acc_dtype(self): def test_prod_without_zeros_custom_acc_dtype(self):
...@@ -972,6 +1014,13 @@ class T_prod_without_zeros_dtype(unittest.TestCase): ...@@ -972,6 +1014,13 @@ class T_prod_without_zeros_dtype(unittest.TestCase):
prod_woz_var = ProdWithoutZeros( prod_woz_var = ProdWithoutZeros(
axis=axis, acc_dtype=acc_dtype)(x) axis=axis, acc_dtype=acc_dtype)(x)
assert prod_woz_var.owner.op.acc_dtype == acc_dtype assert prod_woz_var.owner.op.acc_dtype == acc_dtype
if acc_dtype.startswith('complex') and input_dtype != acc_dtype:
continue
f = theano.function([x], prod_woz_var)
data = numpy.random.rand(2, 3) * 3
data = data.astype(input_dtype)
f(data)
else: else:
self.assertRaises(TypeError, self.assertRaises(TypeError,
ProdWithoutZeros(axis=axis, acc_dtype=acc_dtype), ProdWithoutZeros(axis=axis, acc_dtype=acc_dtype),
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论