提交 8cb589f1 authored 作者: Amjad Almahairi's avatar Amjad Almahairi

checked if all elements are nonnegative, added an argument to enable the check, added a test

上级 f9d4ec48
...@@ -436,7 +436,7 @@ class BinCountOp(theano.Op): ...@@ -436,7 +436,7 @@ class BinCountOp(theano.Op):
return self.__class__.__name__ return self.__class__.__name__
def bincount(x, weights=None, minlength=None): def bincount(x, weights=None, minlength=None, assert_nonneg=False):
"""Count number of occurrences of each value in array of ints. """Count number of occurrences of each value in array of ints.
The number of bins (of size 1) is one larger than the largest The number of bins (of size 1) is one larger than the largest
...@@ -453,7 +453,9 @@ def bincount(x, weights=None, minlength=None): ...@@ -453,7 +453,9 @@ def bincount(x, weights=None, minlength=None):
Optional. Optional.
:param minlength: A minimum number of bins for the output array. :param minlength: A minimum number of bins for the output array.
Optional. Optional.
:param assert_nonneg: A flag that inserts an assert_op to check if
every input x is nonnegative.
Optional.
.. versionadded:: 0.6 .. versionadded:: 0.6
""" """
compatible_type = ('int8', 'int16', 'int32', 'int64', compatible_type = ('int8', 'int16', 'int32', 'int64',
...@@ -471,6 +473,11 @@ def bincount(x, weights=None, minlength=None): ...@@ -471,6 +473,11 @@ def bincount(x, weights=None, minlength=None):
if x.ndim != 1: if x.ndim != 1:
raise TypeError("Inputs must be of dimension 1.") raise TypeError("Inputs must be of dimension 1.")
if assert_nonneg:
from theano.tensor.opt import Assert
assert_op = Assert('Input to bincount has negative values!')
x = assert_op(x, theano.tensor.all(x >= 0))
max_value = theano.tensor.cast(x.max() + 1, 'int64') max_value = theano.tensor.cast(x.max() + 1, 'int64')
if minlength is not None: if minlength is not None:
......
...@@ -14,6 +14,7 @@ from theano.tensor.extra_ops import (CumsumOp, cumsum, CumprodOp, cumprod, ...@@ -14,6 +14,7 @@ from theano.tensor.extra_ops import (CumsumOp, cumsum, CumprodOp, cumprod,
to_one_hot) to_one_hot)
from theano import tensor as T from theano import tensor as T
from theano import config, tensor, function from theano import config, tensor, function
from nose.tools import assert_raises
numpy_ver = [int(n) for n in numpy.__version__.split('.')[:2]] numpy_ver = [int(n) for n in numpy.__version__.split('.')[:2]]
...@@ -138,6 +139,9 @@ class TestBinCountOp(utt.InferShapeTester): ...@@ -138,6 +139,9 @@ class TestBinCountOp(utt.InferShapeTester):
f4 = theano.function([x], bincount(x, minlength=5)) f4 = theano.function([x], bincount(x, minlength=5))
assert (np.bincount(a, minlength=23) == f3(a)).all() assert (np.bincount(a, minlength=23) == f3(a)).all()
assert (np.bincount(a, minlength=5) == f4(a)).all() assert (np.bincount(a, minlength=5) == f4(a)).all()
a[0] = -1
f5 = theano.function([x], bincount(x, assert_nonneg=True))
self.assertRaises(AssertionError, f5, a)
def test_bincountOp(self): def test_bincountOp(self):
w = T.vector('w') w = T.vector('w')
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论