提交 bdfe90b7 authored 作者: Frederic Bastien's avatar Frederic Bastien

Fix comments and move code to detect those type of errors

上级 55655580
...@@ -134,31 +134,31 @@ class TestBinCountOp(utt.InferShapeTester): ...@@ -134,31 +134,31 @@ class TestBinCountOp(utt.InferShapeTester):
def test_bincountFn(self): def test_bincountFn(self):
w = T.vector('w') w = T.vector('w')
for dtype in ('int8', 'int16', 'int32', 'int64',
'uint8', 'uint16', 'uint32', 'uint64'):
x = T.vector('x', dtype=dtype)
a = np.random.random_integers(50, size=(25)).astype(dtype)
weights = np.random.random((25,)).astype(config.floatX)
f1 = theano.function([x], bincount(x))
f2 = theano.function([x, w], bincount(x, weights=w))
def ref(data, w=None, minlength=None): def ref(data, w=None, minlength=None):
size = data.max() + 1 size = data.max() + 1
if minlength: if minlength:
size = max(size, minlength) size = max(size, minlength)
if w: if w is not None:
out = np.zeros(size, dtype=weights.dtype) out = np.zeros(size, dtype=w.dtype)
for i in range(data.shape[0]): for i in range(data.shape[0]):
out[data[i]] += weights[i] out[data[i]] += w[i]
else: else:
out = np.zeros(size, dtype=a.dtype) out = np.zeros(size, dtype=a.dtype)
for i in range(data.shape[0]): for i in range(data.shape[0]):
out[data[i]] += 1 out[data[i]] += 1
return out return out
for dtype in ('int8', 'int16', 'int32', 'int64',
'uint8', 'uint16', 'uint32', 'uint64'):
x = T.vector('x', dtype=dtype)
a = np.random.random_integers(50, size=(25)).astype(dtype)
weights = np.random.random((25,)).astype(config.floatX)
f1 = theano.function([x], bincount(x))
f2 = theano.function([x, w], bincount(x, weights=w))
assert (ref(a) == f1(a)).all() assert (ref(a) == f1(a)).all()
assert np.allclose(ref(a, w), f2(a, weights)) assert np.allclose(ref(a, weights), f2(a, weights))
f3 = theano.function([x], bincount(x, minlength=55)) f3 = theano.function([x], bincount(x, minlength=55))
f4 = theano.function([x], bincount(x, minlength=5)) f4 = theano.function([x], bincount(x, minlength=5))
assert (ref(a, minlength=55) == f3(a)).all() assert (ref(a, minlength=55) == f3(a)).all()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论