提交 bdfe90b7 authored 作者: Frederic Bastien's avatar Frederic Bastien

Fix comments and move code to detect those type of errors

上级 55655580
......@@ -134,6 +134,19 @@ class TestBinCountOp(utt.InferShapeTester):
def test_bincountFn(self):
w = T.vector('w')
def ref(data, w=None, minlength=None):
size = data.max() + 1
if minlength:
size = max(size, minlength)
if w is not None:
out = np.zeros(size, dtype=w.dtype)
for i in range(data.shape[0]):
out[data[i]] += w[i]
else:
out = np.zeros(size, dtype=a.dtype)
for i in range(data.shape[0]):
out[data[i]] += 1
return out
for dtype in ('int8', 'int16', 'int32', 'int64',
'uint8', 'uint16', 'uint32', 'uint64'):
x = T.vector('x', dtype=dtype)
......@@ -144,21 +157,8 @@ class TestBinCountOp(utt.InferShapeTester):
f1 = theano.function([x], bincount(x))
f2 = theano.function([x, w], bincount(x, weights=w))
def ref(data, w=None, minlength=None):
size = data.max() + 1
if minlength:
size = max(size, minlength)
if w:
out = np.zeros(size, dtype=weights.dtype)
for i in range(data.shape[0]):
out[data[i]] += weights[i]
else:
out = np.zeros(size, dtype=a.dtype)
for i in range(data.shape[0]):
out[data[i]] += 1
return out
assert (ref(a) == f1(a)).all()
assert np.allclose(ref(a, w), f2(a, weights))
assert np.allclose(ref(a, weights), f2(a, weights))
f3 = theano.function([x], bincount(x, minlength=55))
f4 = theano.function([x], bincount(x, minlength=5))
assert (ref(a, minlength=55) == f3(a)).all()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论