提交 6506ecd9 authored 作者: Christos Tsirigotis's avatar Christos Tsirigotis

Raise ValueError on invalid `side` kwd of searchsorted

- Add extra tests for searchsorted
上级 8bffb72e
......@@ -85,7 +85,11 @@ class SearchsortedOp(theano.Op):
__props__ = ("side", )
def __init__(self, side='left'):
self.side = side
if side == 'left' or side == 'right':
self.side = side
else:
raise ValueError('\'%(side)s\' is an invalid value for keyword \'side\''
% locals())
def get_params(self, node):
return self.side
......
......@@ -52,18 +52,28 @@ class TestSearchsortedOp(utt.InferShapeTester):
self.b = 30 * np.random.random((8, 10, 5)).astype(config.floatX)
self.idx_sorted = np.argsort(self.a)
def tearDown(self):
self.x = None
self.v = None
self.a = None
self.b = None
self.idx_sorted = None
def test_searchsortedOp_on_sorted_input(self):
f = theano.function([self.x, self.v], searchsorted(self.x, self.v))
assert np.allclose(np.searchsorted(self.a[self.idx_sorted], self.b),
f(self.a[self.idx_sorted], self.b))
sorter = T.vector('sorter', dtype='int64')
f = theano.function([self.x, self.v, sorter], self.x.searchsorted(self.v, sorter=sorter, side='right'))
assert np.allclose(self.a.searchsorted(self.b, sorter=self.idx_sorted, side='right'),
f(self.a, self.b, self.idx_sorted))
sa = self.a[self.idx_sorted]
f = theano.function([self.x, self.v], self.x.searchsorted(self.v, side='right'))
assert np.allclose(sa.searchsorted(self.b, side='right'), f(sa, self.b))
def test_searchsortedOp_wrong_side_kwd(self):
self.assertRaises(ValueError, searchsorted, self.x, self.v, side='asdfa')
def test_searchsortedOp_on_no_1d_inp(self):
no_1d = T.dmatrix('no_1d')
self.assertRaises(ValueError, searchsorted, no_1d, self.v)
self.assertRaises(ValueError, searchsorted, self.x, self.v, sorter=no_1d)
def test_searchsortedOp_on_float_sorter(self):
sorter = T.vector('sorter', dtype="float32")
self.assertRaises(TypeError, searchsorted,
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论