提交 b47dcc0b authored 作者: Nicolas Bouchard's avatar Nicolas Bouchard

Add unsigneg integers to sparse.

上级 4cc35522
...@@ -108,7 +108,8 @@ The set of all accepted ``dtype`` for the sparse matrices can be found in ...@@ -108,7 +108,8 @@ The set of all accepted ``dtype`` for the sparse matrices can be found in
``sparse.all_dtypes``. ``sparse.all_dtypes``.
>>> sparse.all_dtypes >>> sparse.all_dtypes
set(['int8', 'int16', 'int32', 'int64', 'float32', 'float64', 'complex64', 'complex128']) set(['int8', 'int16', 'int32', 'int64', 'uint8', 'uint16', 'uint32', 'uint64',
'float32', 'float64', 'complex64', 'complex128'])
To and Fro To and Fro
---------- ----------
......
...@@ -535,7 +535,11 @@ class upgrade_to_float(object): ...@@ -535,7 +535,11 @@ class upgrade_to_float(object):
conv = {int8: float32, conv = {int8: float32,
int16: float32, int16: float32,
int32: float64, int32: float64,
int64: float64} int64: float64,
uint8: float32,
uint16: float32,
uint32: float64,
uint64: float64}
return Scalar(Scalar.upcast(*[conv.get(type, type) return Scalar(Scalar.upcast(*[conv.get(type, type)
for type in types])), for type in types])),
......
...@@ -403,6 +403,7 @@ class SparseType(gof.Type): ...@@ -403,6 +403,7 @@ class SparseType(gof.Type):
format_cls = {'csr': scipy.sparse.csr_matrix, format_cls = {'csr': scipy.sparse.csr_matrix,
'csc': scipy.sparse.csc_matrix} 'csc': scipy.sparse.csc_matrix}
dtype_set = set(['int8', 'int16', 'int32', 'int64', 'float32', dtype_set = set(['int8', 'int16', 'int32', 'int64', 'float32',
'uint8', 'uint16', 'uint32', 'uint64',
'float64', 'complex64', 'complex128']) 'float64', 'complex64', 'complex128'])
ndim = 2 ndim = 2
......
...@@ -104,6 +104,8 @@ def sparse_random_inputs(format, shape, n=1, out_dtype=None, p=0.5, gap=None): ...@@ -104,6 +104,8 @@ def sparse_random_inputs(format, shape, n=1, out_dtype=None, p=0.5, gap=None):
assert len(shape) == 2 assert len(shape) == 2
assert out_dtype in sparse.all_dtypes assert out_dtype in sparse.all_dtypes
assert gap is None or isinstance(gap, (tuple, list)) assert gap is None or isinstance(gap, (tuple, list))
if gap is not None and out_dtype.startswith('u'):
assert gap[0] >= 0
def _rand(): def _rand():
where = numpy.random.binomial(1, p, size=shape).astype('int8') where = numpy.random.binomial(1, p, size=shape).astype('int8')
...@@ -2277,17 +2279,29 @@ def elemwise_checker(op, expected_f, gap=None, test_dtypes=None, ...@@ -2277,17 +2279,29 @@ def elemwise_checker(op, expected_f, gap=None, test_dtypes=None,
super(Tester, self).setUp() super(Tester, self).setUp()
self.op = op self.op = op
self.expected_f = expected_f self.expected_f = expected_f
self.gap = gap
def test_op(self): def test_op(self):
for format in sparse.sparse_formats: for format in sparse.sparse_formats:
for dtype in test_dtypes: for dtype in test_dtypes:
if dtype == 'int8': if dtype == 'int8' or dtype == 'uint8':
continue continue
# When testing with unsigned integers,
# we must check if the gap contains
# negative numbers.
if dtype.startswith('uint'):
if self.gap and len(self.gap) == 2 and self.gap[0] < 0:
if self.gap[1] > 1:
self.gap = (0, self.gap[1])
else:
continue
variable, data = sparse_random_inputs( variable, data = sparse_random_inputs(
format, format,
shape=(4, 7), shape=(4, 7),
out_dtype=dtype, out_dtype=dtype,
gap=gap) gap=self.gap)
f = theano.function(variable, self.op(*variable)) f = theano.function(variable, self.op(*variable))
...@@ -2312,15 +2326,26 @@ def elemwise_checker(op, expected_f, gap=None, test_dtypes=None, ...@@ -2312,15 +2326,26 @@ def elemwise_checker(op, expected_f, gap=None, test_dtypes=None,
# function. # function.
# Second, the tolerance for the checkup in DebugMode # Second, the tolerance for the checkup in DebugMode
# is too high. # is too high.
if 'int8' in test_dtypes: for dtype in ['int8', 'uint8']:
if gap: if dtype in test_dtypes:
domain = gap if self.gap:
domain = self.gap
# When testing with unsigned integers,
# we must check if the gap contains
# negative numbers.
if dtype == 'uint8':
if len(domain) == 2 and domain[0] < 0:
if domain[1] > 1:
domain = (0, domain[1])
else:
continue
else: else:
domain = (0, 5) domain = (0, 5)
variable, data = sparse_random_inputs( variable, data = sparse_random_inputs(
format, format,
shape=(4, 7), shape=(4, 7),
out_dtype='int8', out_dtype=dtype,
gap=domain) gap=domain)
f = theano.function(variable, self.op(*variable)) f = theano.function(variable, self.op(*variable))
...@@ -2394,7 +2419,9 @@ StructuredSigmoidTester = elemwise_checker( ...@@ -2394,7 +2419,9 @@ StructuredSigmoidTester = elemwise_checker(
sparse.structured_sigmoid, sparse.structured_sigmoid,
structure_function(lambda x: 1.0 / (1.0 + numpy.exp(-x))), structure_function(lambda x: 1.0 / (1.0 + numpy.exp(-x))),
test_dtypes=[m for m in sparse.all_dtypes test_dtypes=[m for m in sparse.all_dtypes
if not m in sparse.complex_dtypes]) if (not m in sparse.complex_dtypes and
not m.startswith('uint'))],
gap=(-5, 5))
StructuredExpTester = elemwise_checker( StructuredExpTester = elemwise_checker(
sparse.structured_exp, sparse.structured_exp,
...@@ -2469,7 +2496,8 @@ SgnTester = elemwise_checker( ...@@ -2469,7 +2496,8 @@ SgnTester = elemwise_checker(
numpy.sign, numpy.sign,
grad_test=False, grad_test=False,
test_dtypes=[m for m in sparse.all_dtypes test_dtypes=[m for m in sparse.all_dtypes
if not m in sparse.complex_dtypes]) if (not m in sparse.complex_dtypes and
not m.startswith('uint'))])
CeilTester = elemwise_checker( CeilTester = elemwise_checker(
sparse.ceil, sparse.ceil,
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论