提交 8cb589f1 authored 作者: Amjad Almahairi's avatar Amjad Almahairi

checked if all elements are nonnegative, added an argument to enable the check, added a test

上级 f9d4ec48
......@@ -436,7 +436,7 @@ class BinCountOp(theano.Op):
return self.__class__.__name__
def bincount(x, weights=None, minlength=None):
def bincount(x, weights=None, minlength=None, assert_nonneg=False):
"""Count number of occurrences of each value in array of ints.
The number of bins (of size 1) is one larger than the largest
......@@ -453,7 +453,9 @@ def bincount(x, weights=None, minlength=None):
Optional.
:param minlength: A minimum number of bins for the output array.
Optional.
:param assert_nonneg: A flag that inserts an assert_op to check if
every input x is nonnegative.
Optional.
.. versionadded:: 0.6
"""
compatible_type = ('int8', 'int16', 'int32', 'int64',
......@@ -471,6 +473,11 @@ def bincount(x, weights=None, minlength=None):
if x.ndim != 1:
raise TypeError("Inputs must be of dimension 1.")
if assert_nonneg:
from theano.tensor.opt import Assert
assert_op = Assert('Input to bincount has negative values!')
x = assert_op(x, theano.tensor.all(x >= 0))
max_value = theano.tensor.cast(x.max() + 1, 'int64')
if minlength is not None:
......
......@@ -14,6 +14,7 @@ from theano.tensor.extra_ops import (CumsumOp, cumsum, CumprodOp, cumprod,
to_one_hot)
from theano import tensor as T
from theano import config, tensor, function
from nose.tools import assert_raises
numpy_ver = [int(n) for n in numpy.__version__.split('.')[:2]]
......@@ -138,6 +139,9 @@ class TestBinCountOp(utt.InferShapeTester):
f4 = theano.function([x], bincount(x, minlength=5))
assert (np.bincount(a, minlength=23) == f3(a)).all()
assert (np.bincount(a, minlength=5) == f4(a)).all()
a[0] = -1
f5 = theano.function([x], bincount(x, assert_nonneg=True))
self.assertRaises(AssertionError, f5, a)
def test_bincountOp(self):
w = T.vector('w')
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论