提交 f9d4ec48 authored 作者: Amjad Almahairi's avatar Amjad Almahairi

bincount now supports for more dtypes

上级 3c6c3367
...@@ -437,7 +437,7 @@ class BinCountOp(theano.Op): ...@@ -437,7 +437,7 @@ class BinCountOp(theano.Op):
def bincount(x, weights=None, minlength=None): def bincount(x, weights=None, minlength=None):
"""Count number of occurrences of each value in array of non-negative ints. """Count number of occurrences of each value in array of ints.
The number of bins (of size 1) is one larger than the largest The number of bins (of size 1) is one larger than the largest
value in x. If minlength is specified, there will be at least value in x. If minlength is specified, there will be at least
...@@ -446,7 +446,6 @@ def bincount(x, weights=None, minlength=None): ...@@ -446,7 +446,6 @@ def bincount(x, weights=None, minlength=None):
number of occurrences of its index value in x. If weights is number of occurrences of its index value in x. If weights is
specified the input array is weighted by it, i.e. if a value n specified the input array is weighted by it, i.e. if a value n
is found at position i, out[n] += weight[i] instead of out[n] += 1. is found at position i, out[n] += weight[i] instead of out[n] += 1.
Wraping of numpy.bincount
:param x: 1 dimension, nonnegative ints :param x: 1 dimension, nonnegative ints
...@@ -458,45 +457,30 @@ def bincount(x, weights=None, minlength=None): ...@@ -458,45 +457,30 @@ def bincount(x, weights=None, minlength=None):
.. versionadded:: 0.6 .. versionadded:: 0.6
""" """
compatible_type = ('int8', 'int16', 'int32', 'int64', compatible_type = ('int8', 'int16', 'int32', 'int64',
'uint8', 'uint16', 'uint32', 'uint64') 'uint8', 'uint16', 'uint32')
unsupported_dtypes = ('uint64',)
if x.dtype in unsupported_dtypes:
raise TypeError(
("Input dtype %s is not supported, "
% unsupported_dtypes), x.dtype)
if x.dtype not in compatible_type: if x.dtype not in compatible_type:
raise TypeError("Inputs dtype must be an integer.") raise TypeError("Inputs dtype must be an integer.")
# Some dtypes are not supported by numpy's implementation of bincount.
# Until another one is available, we should fail at graph construction
# time, not wait for execution.
int_bitwidth = theano.gof.python_int_bitwidth()
if int_bitwidth == 64:
numpy_unsupported_dtypes = ('uint64',)
if int_bitwidth == 32:
numpy_unsupported_dtypes = ('uint32', 'int64', 'uint64')
intp_bitwidth = theano.gof.local_bitwidth()
if intp_bitwidth == 32:
out_type = basic.ivector()
elif intp_bitwidth == 64:
out_type = basic.lvector()
if x.dtype in numpy_unsupported_dtypes:
raise TypeError(
("Input dtypes %s are not supported by numpy.bincount, "
% numpy_unsupported_dtypes), x.dtype)
if x.ndim != 1: if x.ndim != 1:
raise TypeError("Inputs must be of dimension 1.") raise TypeError("Inputs must be of dimension 1.")
max_value = x.max() + 1 max_value = theano.tensor.cast(x.max() + 1, 'int64')
if minlength is not None: if minlength is not None:
max_value = theano.tensor.maximum(max_value, minlength) max_value = theano.tensor.maximum(max_value, minlength)
if weights is None: if weights is None:
out = theano.tensor.zeros([max_value], dtype=out_type.dtype) out = theano.tensor.zeros([max_value], dtype=x.dtype)
out = theano.tensor.inc_subtensor(out[x], 1) out = theano.tensor.inc_subtensor(out[x], 1)
else: else:
out_type = basic.dvector() out = theano.tensor.zeros([max_value], dtype=weights.dtype)
out = theano.tensor.zeros([max_value], dtype=out_type.dtype)
out = theano.tensor.inc_subtensor(out[x], weights) out = theano.tensor.inc_subtensor(out[x], weights)
return out return out
......
...@@ -114,6 +114,31 @@ class TestBinCountOp(utt.InferShapeTester): ...@@ -114,6 +114,31 @@ class TestBinCountOp(utt.InferShapeTester):
self.op_class = BinCountOp self.op_class = BinCountOp
self.op = BinCountOp() self.op = BinCountOp()
def test_bincountFn(self):
w = T.vector('w')
for dtype in ('int8', 'int16', 'int32', 'int64',
'uint8', 'uint16', 'uint32', 'uint64'):
x = T.vector('x', dtype=dtype)
# uint64 always fails
if dtype in ('uint64',):
self.assertRaises(TypeError, bincount, x)
else:
a = np.random.random_integers(50, size=(25)).astype(dtype)
weights = np.random.random((25,)).astype(config.floatX)
f1 = theano.function([x], bincount(x))
f2 = theano.function([x, w], bincount(x, weights=w))
assert (np.bincount(a) == f1(a)).all()
assert np.allclose(np.bincount(a, weights=weights),
f2(a, weights))
f3 = theano.function([x], bincount(x, minlength=23))
f4 = theano.function([x], bincount(x, minlength=5))
assert (np.bincount(a, minlength=23) == f3(a)).all()
assert (np.bincount(a, minlength=5) == f4(a)).all()
def test_bincountOp(self): def test_bincountOp(self):
w = T.vector('w') w = T.vector('w')
for dtype in ('int8', 'int16', 'int32', 'int64', for dtype in ('int8', 'int16', 'int32', 'int64',
...@@ -129,22 +154,22 @@ class TestBinCountOp(utt.InferShapeTester): ...@@ -129,22 +154,22 @@ class TestBinCountOp(utt.InferShapeTester):
x = T.vector('x', dtype=dtype) x = T.vector('x', dtype=dtype)
if dtype in numpy_unsupported_dtypes: if dtype in numpy_unsupported_dtypes:
self.assertRaises(TypeError, bincount, x) self.assertRaises(TypeError, BinCountOp(), x)
else: else:
a = np.random.random_integers(50, size=(25)).astype(dtype) a = np.random.random_integers(50, size=(25)).astype(dtype)
weights = np.random.random((25,)).astype(config.floatX) weights = np.random.random((25,)).astype(config.floatX)
f1 = theano.function([x], bincount(x)) f1 = theano.function([x], BinCountOp()(x, weights=None))
f2 = theano.function([x, w], bincount(x, weights=w)) f2 = theano.function([x, w], BinCountOp()(x, weights=w))
assert (np.bincount(a) == f1(a)).all() assert (np.bincount(a) == f1(a)).all()
assert np.allclose(np.bincount(a, weights=weights), assert np.allclose(np.bincount(a, weights=weights),
f2(a, weights)) f2(a, weights))
if not numpy_16: if not numpy_16:
continue continue
f3 = theano.function([x], bincount(x, minlength=23)) f3 = theano.function([x], BinCountOp(minlength=23)(x, weights=None))
f4 = theano.function([x], bincount(x, minlength=5)) f4 = theano.function([x], BinCountOp(minlength=5)(x, weights=None))
assert (np.bincount(a, minlength=23) == f3(a)).all() assert (np.bincount(a, minlength=23) == f3(a)).all()
assert (np.bincount(a, minlength=5) == f4(a)).all() assert (np.bincount(a, minlength=5) == f4(a)).all()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论