提交 73a6651f authored 作者: carriepl's avatar carriepl

Merge pull request #2628 from aalmah/bincount-op

fixes #2580: theano function for bincount
import numpy as np import numpy as np
import numpy import numpy
import warnings
import theano import theano
from theano.tensor import basic from theano.tensor import basic
...@@ -332,8 +332,11 @@ def diff(x, n=1, axis=-1): ...@@ -332,8 +332,11 @@ def diff(x, n=1, axis=-1):
class BinCountOp(theano.Op): class BinCountOp(theano.Op):
# See function bincount for docstring """
DEPRECATED: use bincount() instead.
See function bincount for docstring
"""
compatible_type = ('int8', 'int16', 'int32', 'int64', compatible_type = ('int8', 'int16', 'int32', 'int64',
'uint8', 'uint16', 'uint32', 'uint64') 'uint8', 'uint16', 'uint32', 'uint64')
"""Tuple of all compatible dtype for the parameter of this op.""" """Tuple of all compatible dtype for the parameter of this op."""
...@@ -355,6 +358,10 @@ class BinCountOp(theano.Op): ...@@ -355,6 +358,10 @@ class BinCountOp(theano.Op):
return hash(type(self)) ^ hash(self.minlength) return hash(type(self)) ^ hash(self.minlength)
def make_node(self, x, weights): def make_node(self, x, weights):
warnings.warn((
"Tile op is deprecated, use tile function instead."),
stacklevel=3)
x = basic.as_tensor_variable(x) x = basic.as_tensor_variable(x)
if x.dtype not in BinCountOp.compatible_type: if x.dtype not in BinCountOp.compatible_type:
...@@ -429,8 +436,8 @@ class BinCountOp(theano.Op): ...@@ -429,8 +436,8 @@ class BinCountOp(theano.Op):
return self.__class__.__name__ return self.__class__.__name__
def bincount(x, weights=None, minlength=None): def bincount(x, weights=None, minlength=None, assert_nonneg=False):
"""Count number of occurrences of each value in array of non-negative ints. """Count number of occurrences of each value in array of ints.
The number of bins (of size 1) is one larger than the largest The number of bins (of size 1) is one larger than the largest
value in x. If minlength is specified, there will be at least value in x. If minlength is specified, there will be at least
...@@ -439,7 +446,6 @@ def bincount(x, weights=None, minlength=None): ...@@ -439,7 +446,6 @@ def bincount(x, weights=None, minlength=None):
number of occurrences of its index value in x. If weights is number of occurrences of its index value in x. If weights is
specified the input array is weighted by it, i.e. if a value n specified the input array is weighted by it, i.e. if a value n
is found at position i, out[n] += weight[i] instead of out[n] += 1. is found at position i, out[n] += weight[i] instead of out[n] += 1.
Wraping of numpy.bincount
:param x: 1 dimension, nonnegative ints :param x: 1 dimension, nonnegative ints
...@@ -447,10 +453,43 @@ def bincount(x, weights=None, minlength=None): ...@@ -447,10 +453,43 @@ def bincount(x, weights=None, minlength=None):
Optional. Optional.
:param minlength: A minimum number of bins for the output array. :param minlength: A minimum number of bins for the output array.
Optional. Optional.
:param assert_nonneg: A flag that inserts an assert_op to check if
every input x is nonnegative.
Optional.
.. versionadded:: 0.6 .. versionadded:: 0.6
""" """
return BinCountOp(minlength=minlength)(x, weights) compatible_type = ('int8', 'int16', 'int32', 'int64',
'uint8', 'uint16', 'uint32')
unsupported_dtypes = ('uint64',)
if x.dtype in unsupported_dtypes:
raise TypeError(
("Input dtype %s is not supported, "
% unsupported_dtypes), x.dtype)
if x.dtype not in compatible_type:
raise TypeError("Inputs dtype must be an integer.")
if x.ndim != 1:
raise TypeError("Inputs must be of dimension 1.")
if assert_nonneg:
from theano.tensor.opt import Assert
assert_op = Assert('Input to bincount has negative values!')
x = assert_op(x, theano.tensor.all(x >= 0))
max_value = theano.tensor.cast(x.max() + 1, 'int64')
if minlength is not None:
max_value = theano.tensor.maximum(max_value, minlength)
if weights is None:
out = theano.tensor.zeros([max_value], dtype=x.dtype)
out = theano.tensor.inc_subtensor(out[x], 1)
else:
out = theano.tensor.zeros([max_value], dtype=weights.dtype)
out = theano.tensor.inc_subtensor(out[x], weights)
return out
def squeeze(x): def squeeze(x):
......
...@@ -115,6 +115,36 @@ class TestBinCountOp(utt.InferShapeTester): ...@@ -115,6 +115,36 @@ class TestBinCountOp(utt.InferShapeTester):
self.op_class = BinCountOp self.op_class = BinCountOp
self.op = BinCountOp() self.op = BinCountOp()
def test_bincountFn(self):
w = T.vector('w')
for dtype in ('int8', 'int16', 'int32', 'int64',
'uint8', 'uint16', 'uint32', 'uint64'):
x = T.vector('x', dtype=dtype)
# uint64 always fails
if dtype in ('uint64',):
self.assertRaises(TypeError, bincount, x)
else:
a = np.random.random_integers(50, size=(25)).astype(dtype)
weights = np.random.random((25,)).astype(config.floatX)
f1 = theano.function([x], bincount(x))
f2 = theano.function([x, w], bincount(x, weights=w))
assert (np.bincount(a) == f1(a)).all()
assert np.allclose(np.bincount(a, weights=weights),
f2(a, weights))
f3 = theano.function([x], bincount(x, minlength=23))
f4 = theano.function([x], bincount(x, minlength=5))
assert (np.bincount(a, minlength=23) == f3(a)).all()
assert (np.bincount(a, minlength=5) == f4(a)).all()
# skip the following test when using unsigned ints
if not dtype.startswith('u'):
a[0] = -1
f5 = theano.function([x], bincount(x, assert_nonneg=True))
self.assertRaises(AssertionError, f5, a)
def test_bincountOp(self): def test_bincountOp(self):
w = T.vector('w') w = T.vector('w')
for dtype in ('int8', 'int16', 'int32', 'int64', for dtype in ('int8', 'int16', 'int32', 'int64',
...@@ -130,22 +160,22 @@ class TestBinCountOp(utt.InferShapeTester): ...@@ -130,22 +160,22 @@ class TestBinCountOp(utt.InferShapeTester):
x = T.vector('x', dtype=dtype) x = T.vector('x', dtype=dtype)
if dtype in numpy_unsupported_dtypes: if dtype in numpy_unsupported_dtypes:
self.assertRaises(TypeError, bincount, x) self.assertRaises(TypeError, BinCountOp(), x)
else: else:
a = np.random.random_integers(50, size=(25)).astype(dtype) a = np.random.random_integers(50, size=(25)).astype(dtype)
weights = np.random.random((25,)).astype(config.floatX) weights = np.random.random((25,)).astype(config.floatX)
f1 = theano.function([x], bincount(x)) f1 = theano.function([x], BinCountOp()(x, weights=None))
f2 = theano.function([x, w], bincount(x, weights=w)) f2 = theano.function([x, w], BinCountOp()(x, weights=w))
assert (np.bincount(a) == f1(a)).all() assert (np.bincount(a) == f1(a)).all()
assert np.allclose(np.bincount(a, weights=weights), assert np.allclose(np.bincount(a, weights=weights),
f2(a, weights)) f2(a, weights))
if not numpy_16: if not numpy_16:
continue continue
f3 = theano.function([x], bincount(x, minlength=23)) f3 = theano.function([x], BinCountOp(minlength=23)(x, weights=None))
f4 = theano.function([x], bincount(x, minlength=5)) f4 = theano.function([x], BinCountOp(minlength=5)(x, weights=None))
assert (np.bincount(a, minlength=23) == f3(a)).all() assert (np.bincount(a, minlength=23) == f3(a)).all()
assert (np.bincount(a, minlength=5) == f4(a)).all() assert (np.bincount(a, minlength=5) == f4(a)).all()
...@@ -162,12 +192,12 @@ class TestBinCountOp(utt.InferShapeTester): ...@@ -162,12 +192,12 @@ class TestBinCountOp(utt.InferShapeTester):
x = T.vector('x', dtype=dtype) x = T.vector('x', dtype=dtype)
if dtype in numpy_unsupported_dtypes: if dtype in numpy_unsupported_dtypes:
self.assertRaises(TypeError, bincount, x) self.assertRaises(TypeError, BinCountOp(), x)
else: else:
self._compile_and_check( self._compile_and_check(
[x], [x],
[bincount(x)], [BinCountOp()(x,None)],
[np.random.random_integers( [np.random.random_integers(
50, size=(25,)).astype(dtype)], 50, size=(25,)).astype(dtype)],
self.op_class) self.op_class)
...@@ -175,7 +205,7 @@ class TestBinCountOp(utt.InferShapeTester): ...@@ -175,7 +205,7 @@ class TestBinCountOp(utt.InferShapeTester):
weights = np.random.random((25,)).astype(config.floatX) weights = np.random.random((25,)).astype(config.floatX)
self._compile_and_check( self._compile_and_check(
[x], [x],
[bincount(x, weights=weights)], [BinCountOp()(x, weights=weights)],
[np.random.random_integers( [np.random.random_integers(
50, size=(25,)).astype(dtype)], 50, size=(25,)).astype(dtype)],
self.op_class) self.op_class)
...@@ -184,14 +214,14 @@ class TestBinCountOp(utt.InferShapeTester): ...@@ -184,14 +214,14 @@ class TestBinCountOp(utt.InferShapeTester):
continue continue
self._compile_and_check( self._compile_and_check(
[x], [x],
[bincount(x, minlength=60)], [BinCountOp(minlength=60)(x, weights=weights)],
[np.random.random_integers( [np.random.random_integers(
50, size=(25,)).astype(dtype)], 50, size=(25,)).astype(dtype)],
self.op_class) self.op_class)
self._compile_and_check( self._compile_and_check(
[x], [x],
[bincount(x, minlength=5)], [BinCountOp(minlength=5)(x, weights=weights)],
[np.random.random_integers( [np.random.random_integers(
50, size=(25,)).astype(dtype)], 50, size=(25,)).astype(dtype)],
self.op_class) self.op_class)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论