提交 45c7b96a authored 作者: Frederic Bastien's avatar Frederic Bastien

Make test of bincount pass with 32bit python. We don't have the numpy limit anymore

上级 2bbcab02
...@@ -156,13 +156,25 @@ class TestBinCountOp(utt.InferShapeTester): ...@@ -156,13 +156,25 @@ class TestBinCountOp(utt.InferShapeTester):
f1 = theano.function([x], bincount(x)) f1 = theano.function([x], bincount(x))
f2 = theano.function([x, w], bincount(x, weights=w)) f2 = theano.function([x, w], bincount(x, weights=w))
assert (np.bincount(a) == f1(a)).all() def ref(data, w=None, minlength=None):
assert np.allclose(np.bincount(a, weights=weights), size = data.max() + 1
f2(a, weights)) if minlength:
f3 = theano.function([x], bincount(x, minlength=23)) size = max(size, minlength)
if w:
out = np.zeros(size, dtype=weights.dtype)
for i in range(data.shape[0]):
out[data[i]] += weights[i]
else:
out = np.zeros(size, dtype=a.dtype)
for i in range(data.shape[0]):
out[data[i]] += 1
return out
assert (ref(a) == f1(a)).all()
assert np.allclose(ref(a, w), f2(a, weights))
f3 = theano.function([x], bincount(x, minlength=55))
f4 = theano.function([x], bincount(x, minlength=5)) f4 = theano.function([x], bincount(x, minlength=5))
assert (np.bincount(a, minlength=23) == f3(a)).all() assert (ref(a, minlength=55) == f3(a)).all()
assert (np.bincount(a, minlength=5) == f4(a)).all() assert (ref(a, minlength=5) == f4(a)).all()
# skip the following test when using unsigned ints # skip the following test when using unsigned ints
if not dtype.startswith('u'): if not dtype.startswith('u'):
a[0] = -1 a[0] = -1
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论