提交 a84654ad authored 作者: Frederic Bastien's avatar Frederic Bastien

Make SearchSortedOp always return int64 as I'm not able to always figure out…

Make SearchSortedOp always return int64 as I'm not able to always figure out what is the output dtype. In particular, in Python 32bit, it don't return int64, but if I make it return intN where N is the python bit version, it also don't work.
上级 03f1e9c9
......@@ -102,6 +102,11 @@ class SearchsortedOp(theano.Op):
return theano.Apply(self, [x, v], [out_type()])
else:
sorter = basic.as_tensor(sorter, ndim=1)
if (theano.configdefaults.python_int_bitwidth() == 32 and
sorter.dtype == 'int64'):
raise TypeError(
"numpy.searchsorted with Python 32bit do not support a"
" sorter of int64.")
if sorter.type not in basic.int_vector_types:
raise TypeError('sorter must be an integer vector',
sorter.type)
......@@ -119,7 +124,8 @@ class SearchsortedOp(theano.Op):
sorter = None
z = output_storage[0]
z[0] = np.searchsorted(x, v, side=params, sorter=sorter)
z[0] = np.searchsorted(x, v, side=params, sorter=sorter).astype(
node.outputs[0].dtype)
def c_support_code_struct(self, node, name):
return """
......@@ -154,10 +160,15 @@ class SearchsortedOp(theano.Op):
right_%(name)s ? NPY_SEARCHLEFT : NPY_SEARCHRIGHT, (PyObject*) %(sorter)s);
if (!%(z)s)
%(fail)s;
if (PyArray_TYPE(%(z)s) != NPY_INT64){
PyObject * tmp = PyArray_Cast(%(z)s, NPY_INT64);
Py_XDECREF(%(z)s);
%(z)s = (PyArrayObject*) tmp;
}
""" % locals()
def c_code_cache_version(self):
return (1,)
return (2,)
def grad(self, inputs, output_gradients):
num_ins = len(inputs)
......
......@@ -50,14 +50,14 @@ class TestSearchsortedOp(utt.InferShapeTester):
self.a = 30 * np.random.random(50).astype(config.floatX)
self.b = 30 * np.random.random((8, 10, 5)).astype(config.floatX)
self.idx_sorted = np.argsort(self.a)
self.idx_sorted = np.argsort(self.a).astype('int32')
def test_searchsortedOp_on_sorted_input(self):
f = theano.function([self.x, self.v], searchsorted(self.x, self.v))
assert np.allclose(np.searchsorted(self.a[self.idx_sorted], self.b),
f(self.a[self.idx_sorted], self.b))
sorter = T.vector('sorter', dtype='int64')
sorter = T.vector('sorter', dtype='int32')
f = theano.function([self.x, self.v, sorter], self.x.searchsorted(self.v, sorter=sorter, side='right'))
assert np.allclose(self.a.searchsorted(self.b, sorter=self.idx_sorted, side='right'),
f(self.a, self.b, self.idx_sorted))
......@@ -80,7 +80,9 @@ class TestSearchsortedOp(utt.InferShapeTester):
self.x, self.v, sorter=sorter)
def test_searchsortedOp_on_int_sorter(self):
compatible_types = ('int8', 'int16', 'int32', 'int64',)
compatible_types = ('int8', 'int16', 'int32')
if theano.configdefaults.python_int_bitwidth() == 64:
compatible_types += ('int64',)
# 'uint8', 'uint16', 'uint32', 'uint64')
for dtype in compatible_types:
sorter = T.vector('sorter', dtype=dtype)
......@@ -104,7 +106,7 @@ class TestSearchsortedOp(utt.InferShapeTester):
self.op_class)
# Test parameter ``sorter``
sorter = T.vector('sorter', dtype="int64")
sorter = T.vector('sorter', dtype="int32")
self._compile_and_check([self.x, self.v, sorter],
[searchsorted(self.x, self.v, sorter=sorter)],
[self.a, self.b, self.idx_sorted],
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论