提交 e554dd5a authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Merge pull request #4540 from nouiz/fix_test

Finish the regression fix in gh-4533
...@@ -975,7 +975,8 @@ def local_assert_no_cpu_op(node): ...@@ -975,7 +975,8 @@ def local_assert_no_cpu_op(node):
assert_no_cpu_op = theano.tensor.opt.in2out(local_assert_no_cpu_op, assert_no_cpu_op = theano.tensor.opt.in2out(local_assert_no_cpu_op,
name='assert_no_cpu_op') name='assert_no_cpu_op')
# 49.2 is after device specialization & fusion optimizations for last transfers # 49.2 is after device specialization & fusion optimizations for last transfers
optdb.register('gpua_assert_no_cpu_op', assert_no_cpu_op, 49.2) optdb.register('gpua_assert_no_cpu_op', assert_no_cpu_op, 49.2,
'assert_no_cpu_op')
def tensor_to_gpu(x, context_name): def tensor_to_gpu(x, context_name):
......
...@@ -102,6 +102,11 @@ class SearchsortedOp(theano.Op): ...@@ -102,6 +102,11 @@ class SearchsortedOp(theano.Op):
return theano.Apply(self, [x, v], [out_type()]) return theano.Apply(self, [x, v], [out_type()])
else: else:
sorter = basic.as_tensor(sorter, ndim=1) sorter = basic.as_tensor(sorter, ndim=1)
if (theano.configdefaults.python_int_bitwidth() == 32 and
sorter.dtype == 'int64'):
raise TypeError(
"numpy.searchsorted with Python 32bit do not support a"
" sorter of int64.")
if sorter.type not in basic.int_vector_types: if sorter.type not in basic.int_vector_types:
raise TypeError('sorter must be an integer vector', raise TypeError('sorter must be an integer vector',
sorter.type) sorter.type)
...@@ -119,7 +124,8 @@ class SearchsortedOp(theano.Op): ...@@ -119,7 +124,8 @@ class SearchsortedOp(theano.Op):
sorter = None sorter = None
z = output_storage[0] z = output_storage[0]
z[0] = np.searchsorted(x, v, side=params, sorter=sorter) z[0] = np.searchsorted(x, v, side=params, sorter=sorter).astype(
node.outputs[0].dtype)
def c_support_code_struct(self, node, name): def c_support_code_struct(self, node, name):
return """ return """
...@@ -154,10 +160,15 @@ class SearchsortedOp(theano.Op): ...@@ -154,10 +160,15 @@ class SearchsortedOp(theano.Op):
right_%(name)s ? NPY_SEARCHLEFT : NPY_SEARCHRIGHT, (PyObject*) %(sorter)s); right_%(name)s ? NPY_SEARCHLEFT : NPY_SEARCHRIGHT, (PyObject*) %(sorter)s);
if (!%(z)s) if (!%(z)s)
%(fail)s; %(fail)s;
if (PyArray_TYPE(%(z)s) != NPY_INT64){
PyObject * tmp = PyArray_Cast(%(z)s, NPY_INT64);
Py_XDECREF(%(z)s);
%(z)s = (PyArrayObject*) tmp;
}
""" % locals() """ % locals()
def c_code_cache_version(self): def c_code_cache_version(self):
return (1,) return (2,)
def grad(self, inputs, output_gradients): def grad(self, inputs, output_gradients):
num_ins = len(inputs) num_ins = len(inputs)
......
...@@ -50,14 +50,14 @@ class TestSearchsortedOp(utt.InferShapeTester): ...@@ -50,14 +50,14 @@ class TestSearchsortedOp(utt.InferShapeTester):
self.a = 30 * np.random.random(50).astype(config.floatX) self.a = 30 * np.random.random(50).astype(config.floatX)
self.b = 30 * np.random.random((8, 10, 5)).astype(config.floatX) self.b = 30 * np.random.random((8, 10, 5)).astype(config.floatX)
self.idx_sorted = np.argsort(self.a) self.idx_sorted = np.argsort(self.a).astype('int32')
def test_searchsortedOp_on_sorted_input(self): def test_searchsortedOp_on_sorted_input(self):
f = theano.function([self.x, self.v], searchsorted(self.x, self.v)) f = theano.function([self.x, self.v], searchsorted(self.x, self.v))
assert np.allclose(np.searchsorted(self.a[self.idx_sorted], self.b), assert np.allclose(np.searchsorted(self.a[self.idx_sorted], self.b),
f(self.a[self.idx_sorted], self.b)) f(self.a[self.idx_sorted], self.b))
sorter = T.vector('sorter', dtype='int64') sorter = T.vector('sorter', dtype='int32')
f = theano.function([self.x, self.v, sorter], self.x.searchsorted(self.v, sorter=sorter, side='right')) f = theano.function([self.x, self.v, sorter], self.x.searchsorted(self.v, sorter=sorter, side='right'))
assert np.allclose(self.a.searchsorted(self.b, sorter=self.idx_sorted, side='right'), assert np.allclose(self.a.searchsorted(self.b, sorter=self.idx_sorted, side='right'),
f(self.a, self.b, self.idx_sorted)) f(self.a, self.b, self.idx_sorted))
...@@ -80,7 +80,9 @@ class TestSearchsortedOp(utt.InferShapeTester): ...@@ -80,7 +80,9 @@ class TestSearchsortedOp(utt.InferShapeTester):
self.x, self.v, sorter=sorter) self.x, self.v, sorter=sorter)
def test_searchsortedOp_on_int_sorter(self): def test_searchsortedOp_on_int_sorter(self):
compatible_types = ('int8', 'int16', 'int32', 'int64',) compatible_types = ('int8', 'int16', 'int32')
if theano.configdefaults.python_int_bitwidth() == 64:
compatible_types += ('int64',)
# 'uint8', 'uint16', 'uint32', 'uint64') # 'uint8', 'uint16', 'uint32', 'uint64')
for dtype in compatible_types: for dtype in compatible_types:
sorter = T.vector('sorter', dtype=dtype) sorter = T.vector('sorter', dtype=dtype)
...@@ -104,7 +106,7 @@ class TestSearchsortedOp(utt.InferShapeTester): ...@@ -104,7 +106,7 @@ class TestSearchsortedOp(utt.InferShapeTester):
self.op_class) self.op_class)
# Test parameter ``sorter`` # Test parameter ``sorter``
sorter = T.vector('sorter', dtype="int64") sorter = T.vector('sorter', dtype="int32")
self._compile_and_check([self.x, self.v, sorter], self._compile_and_check([self.x, self.v, sorter],
[searchsorted(self.x, self.v, sorter=sorter)], [searchsorted(self.x, self.v, sorter=sorter)],
[self.a, self.b, self.idx_sorted], [self.a, self.b, self.idx_sorted],
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论