提交 973fce3e authored 作者: Frederic's avatar Frederic

fix test and raise a good error if we have an older numpy version.

上级 4c0c2258
...@@ -91,6 +91,12 @@ class BinCountOp(theano.Op): ...@@ -91,6 +91,12 @@ class BinCountOp(theano.Op):
def __init__(self, minlength=None): def __init__(self, minlength=None):
self.minlength = minlength self.minlength = minlength
if minlength is not None:
numpy_ver = [int(n) for n in numpy.__version__.split('.')[:2]]
if not bool(numpy_ver >= [1, 6]):
raise NotImplementedError(
"BinCountOp with minlength attribute"
" need NumPy 1.6 or higher.")
def __eq__(self, other): def __eq__(self, other):
return (type(self) == type(other) and return (type(self) == type(other) and
...@@ -145,8 +151,11 @@ class BinCountOp(theano.Op): ...@@ -145,8 +151,11 @@ class BinCountOp(theano.Op):
if weights is not None and weights.shape != x.shape: if weights is not None and weights.shape != x.shape:
raise TypeError("All inputs must have the same shape.") raise TypeError("All inputs must have the same shape.")
#Needed for numpy 1.4.1 compatibility
if self.minlength:
z[0] = np.bincount(x, weights=weights, minlength=self.minlength) z[0] = np.bincount(x, weights=weights, minlength=self.minlength)
else:
z[0] = np.bincount(x, weights=weights)
def grad(self, inputs, outputs_gradients): def grad(self, inputs, outputs_gradients):
output = self(*inputs) output = self(*inputs)
......
...@@ -10,6 +10,9 @@ from theano import tensor as T ...@@ -10,6 +10,9 @@ from theano import tensor as T
from theano import config, tensor, function from theano import config, tensor, function
numpy_ver = [int(n) for n in numpy.__version__.split('.')[:2]]
numpy_16 = bool(numpy_ver >= [1, 6])
class TestBinCountOp(utt.InferShapeTester): class TestBinCountOp(utt.InferShapeTester):
def setUp(self): def setUp(self):
super(TestBinCountOp, self).setUp() super(TestBinCountOp, self).setUp()
...@@ -39,12 +42,14 @@ class TestBinCountOp(utt.InferShapeTester): ...@@ -39,12 +42,14 @@ class TestBinCountOp(utt.InferShapeTester):
f1 = theano.function([x], bincount(x)) f1 = theano.function([x], bincount(x))
f2 = theano.function([x, w], bincount(x, weights=w)) f2 = theano.function([x, w], bincount(x, weights=w))
f3 = theano.function([x], bincount(x, minlength=23))
f4 = theano.function([x], bincount(x, minlength=5))
assert (np.bincount(a) == f1(a)).all() assert (np.bincount(a) == f1(a)).all()
assert np.allclose(np.bincount(a, weights=weights), assert np.allclose(np.bincount(a, weights=weights),
f2(a, weights)) f2(a, weights))
if not numpy_16:
continue
f3 = theano.function([x], bincount(x, minlength=23))
f4 = theano.function([x], bincount(x, minlength=5))
assert (np.bincount(a, minlength=23) == f3(a)).all() assert (np.bincount(a, minlength=23) == f3(a)).all()
assert (np.bincount(a, minlength=5) == f4(a)).all() assert (np.bincount(a, minlength=5) == f4(a)).all()
...@@ -79,6 +84,8 @@ class TestBinCountOp(utt.InferShapeTester): ...@@ -79,6 +84,8 @@ class TestBinCountOp(utt.InferShapeTester):
50, size=(25,)).astype(dtype)], 50, size=(25,)).astype(dtype)],
self.op_class) self.op_class)
if not numpy_16:
continue
self._compile_and_check( self._compile_and_check(
[x], [x],
[bincount(x, minlength=60)], [bincount(x, minlength=60)],
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论