提交 a84654ad authored 作者: Frederic Bastien's avatar Frederic Bastien

Make SearchSortedOp always return int64 as I'm not able to always figure out…

Make SearchSortedOp always return int64 as I'm not able to always figure out what is the output dtype. In particular, in Python 32bit, it don't return int64, but if I make it return intN where N is the python bit version, it also don't work.
上级 03f1e9c9
...@@ -102,6 +102,11 @@ class SearchsortedOp(theano.Op): ...@@ -102,6 +102,11 @@ class SearchsortedOp(theano.Op):
return theano.Apply(self, [x, v], [out_type()]) return theano.Apply(self, [x, v], [out_type()])
else: else:
sorter = basic.as_tensor(sorter, ndim=1) sorter = basic.as_tensor(sorter, ndim=1)
if (theano.configdefaults.python_int_bitwidth() == 32 and
sorter.dtype == 'int64'):
raise TypeError(
"numpy.searchsorted with Python 32bit do not support a"
" sorter of int64.")
if sorter.type not in basic.int_vector_types: if sorter.type not in basic.int_vector_types:
raise TypeError('sorter must be an integer vector', raise TypeError('sorter must be an integer vector',
sorter.type) sorter.type)
...@@ -119,7 +124,8 @@ class SearchsortedOp(theano.Op): ...@@ -119,7 +124,8 @@ class SearchsortedOp(theano.Op):
sorter = None sorter = None
z = output_storage[0] z = output_storage[0]
z[0] = np.searchsorted(x, v, side=params, sorter=sorter) z[0] = np.searchsorted(x, v, side=params, sorter=sorter).astype(
node.outputs[0].dtype)
def c_support_code_struct(self, node, name): def c_support_code_struct(self, node, name):
return """ return """
...@@ -154,10 +160,15 @@ class SearchsortedOp(theano.Op): ...@@ -154,10 +160,15 @@ class SearchsortedOp(theano.Op):
right_%(name)s ? NPY_SEARCHLEFT : NPY_SEARCHRIGHT, (PyObject*) %(sorter)s); right_%(name)s ? NPY_SEARCHLEFT : NPY_SEARCHRIGHT, (PyObject*) %(sorter)s);
if (!%(z)s) if (!%(z)s)
%(fail)s; %(fail)s;
if (PyArray_TYPE(%(z)s) != NPY_INT64){
PyObject * tmp = PyArray_Cast(%(z)s, NPY_INT64);
Py_XDECREF(%(z)s);
%(z)s = (PyArrayObject*) tmp;
}
""" % locals() """ % locals()
def c_code_cache_version(self): def c_code_cache_version(self):
return (1,) return (2,)
def grad(self, inputs, output_gradients): def grad(self, inputs, output_gradients):
num_ins = len(inputs) num_ins = len(inputs)
......
...@@ -50,14 +50,14 @@ class TestSearchsortedOp(utt.InferShapeTester): ...@@ -50,14 +50,14 @@ class TestSearchsortedOp(utt.InferShapeTester):
self.a = 30 * np.random.random(50).astype(config.floatX) self.a = 30 * np.random.random(50).astype(config.floatX)
self.b = 30 * np.random.random((8, 10, 5)).astype(config.floatX) self.b = 30 * np.random.random((8, 10, 5)).astype(config.floatX)
self.idx_sorted = np.argsort(self.a) self.idx_sorted = np.argsort(self.a).astype('int32')
def test_searchsortedOp_on_sorted_input(self): def test_searchsortedOp_on_sorted_input(self):
f = theano.function([self.x, self.v], searchsorted(self.x, self.v)) f = theano.function([self.x, self.v], searchsorted(self.x, self.v))
assert np.allclose(np.searchsorted(self.a[self.idx_sorted], self.b), assert np.allclose(np.searchsorted(self.a[self.idx_sorted], self.b),
f(self.a[self.idx_sorted], self.b)) f(self.a[self.idx_sorted], self.b))
sorter = T.vector('sorter', dtype='int64') sorter = T.vector('sorter', dtype='int32')
f = theano.function([self.x, self.v, sorter], self.x.searchsorted(self.v, sorter=sorter, side='right')) f = theano.function([self.x, self.v, sorter], self.x.searchsorted(self.v, sorter=sorter, side='right'))
assert np.allclose(self.a.searchsorted(self.b, sorter=self.idx_sorted, side='right'), assert np.allclose(self.a.searchsorted(self.b, sorter=self.idx_sorted, side='right'),
f(self.a, self.b, self.idx_sorted)) f(self.a, self.b, self.idx_sorted))
...@@ -80,7 +80,9 @@ class TestSearchsortedOp(utt.InferShapeTester): ...@@ -80,7 +80,9 @@ class TestSearchsortedOp(utt.InferShapeTester):
self.x, self.v, sorter=sorter) self.x, self.v, sorter=sorter)
def test_searchsortedOp_on_int_sorter(self): def test_searchsortedOp_on_int_sorter(self):
compatible_types = ('int8', 'int16', 'int32', 'int64',) compatible_types = ('int8', 'int16', 'int32')
if theano.configdefaults.python_int_bitwidth() == 64:
compatible_types += ('int64',)
# 'uint8', 'uint16', 'uint32', 'uint64') # 'uint8', 'uint16', 'uint32', 'uint64')
for dtype in compatible_types: for dtype in compatible_types:
sorter = T.vector('sorter', dtype=dtype) sorter = T.vector('sorter', dtype=dtype)
...@@ -104,7 +106,7 @@ class TestSearchsortedOp(utt.InferShapeTester): ...@@ -104,7 +106,7 @@ class TestSearchsortedOp(utt.InferShapeTester):
self.op_class) self.op_class)
# Test parameter ``sorter`` # Test parameter ``sorter``
sorter = T.vector('sorter', dtype="int64") sorter = T.vector('sorter', dtype="int32")
self._compile_and_check([self.x, self.v, sorter], self._compile_and_check([self.x, self.v, sorter],
[searchsorted(self.x, self.v, sorter=sorter)], [searchsorted(self.x, self.v, sorter=sorter)],
[self.a, self.b, self.idx_sorted], [self.a, self.b, self.idx_sorted],
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论