提交 52b83f7a authored 作者: lamblin's avatar lamblin

Merge pull request #579 from HaniAlmousli/argsort

Argsort numpy wrapping
...@@ -6337,3 +6337,65 @@ def sort(a, axis=-1, kind='quicksort', order=None): ...@@ -6337,3 +6337,65 @@ def sort(a, axis=-1, kind='quicksort', order=None):
""" """
return SortOp(kind, order)(a, axis) return SortOp(kind, order)(a, axis)
class ArgSortOp(theano.Op):
"""
This class is a wrapper for numpy argsort function
"""
def __init__(self, kind, order=None):
self.kind = kind
self.order = order
def __eq__(self, other):
return (type(self) == type(other) and
self.order == other.order and
self.kind == other.kind)
def __hash__(self):
return hash(type(self)) ^ hash(self.order) ^ hash(self.kind)
def __str__(self):
return self.__class__.__name__ + "{%s, %s}" % (self.kind, str(self.order))
def make_node(self, input, axis=-1):
input = theano.tensor.as_tensor_variable(input)
if axis is None:
axis = Constant(gof.generic, None)
else:
axis = theano.tensor.as_tensor_variable(axis)
return theano.Apply(self, [input, axis],
[theano.tensor.TensorType(dtype="int64", broadcastable=input.type.broadcastable)()])
def perform(self, node, inputs, output_storage):
a = inputs[0]
axis = inputs[1]
z = output_storage[0]
z[0] = numpy.argsort(a, axis, self.kind, self.order)
def infer_shape(self, node, inputs_shapes):
return [inputs_shapes[0]]
def grad(self, inputs, output_grads):
#No grad defined for intergers.
return [None, None]
"""
def R_op(self, inputs, eval_points):
# R_op can receive None as eval_points.
# That mean there is no diferientiable path through that input
# If this imply that you cannot compute some outputs,
# return None for those.
if eval_points[0] is None:
return eval_points
return self.grad(inputs, eval_points)
"""
def argsort(a, axis=-1, kind='quicksort', order=None):
"""
Returns the indices that would sort an array.
Perform an indirect sort along the given axis using the algorithm specified by the kind keyword.
It returns an array of indices of the same shape as a that index data along the given axis in sorted order.
"""
return ArgSortOp(kind, order)(a, axis)
...@@ -34,7 +34,7 @@ from theano.tensor import (_shared, wvector, bvector, autocast_float_as, ...@@ -34,7 +34,7 @@ from theano.tensor import (_shared, wvector, bvector, autocast_float_as,
get_constant_value, ivector, reshape, scalar_from_tensor, scal, get_constant_value, ivector, reshape, scalar_from_tensor, scal,
iscalars, arange, dscalars, fvector, imatrix, numeric_grad, iscalars, arange, dscalars, fvector, imatrix, numeric_grad,
opt, ComplexError, TensorDot, lvector, true_div, max, min, Split, roll, opt, ComplexError, TensorDot, lvector, true_div, max, min, Split, roll,
tile, patternbroadcast, sort, SortOp, ) tile, patternbroadcast, sort, SortOp, argsort, ArgSortOp,)
from theano.tests import unittest_tools as utt from theano.tests import unittest_tools as utt
...@@ -5709,6 +5709,62 @@ class TensorInferShapeTester(utt.InferShapeTester): ...@@ -5709,6 +5709,62 @@ class TensorInferShapeTester(utt.InferShapeTester):
SortOp) SortOp)
def test_argsort():
#Set up
rng = numpy.random.RandomState(seed=utt.fetch_seed())
m_val = rng.rand(3, 2)
v_val = rng.rand(4)
#Example 1
a = theano.tensor.dmatrix()
w = argsort(a)
f = theano.function([a], w)
assert numpy.allclose(f(m_val), numpy.argsort(m_val))
#Example 2
a = theano.tensor.dmatrix()
axis = theano.tensor.scalar()
w = argsort(a, axis)
f = theano.function([a, axis], w)
for axis_val in 0, 1:
assert numpy.allclose(
f(m_val, axis_val),
numpy.argsort(m_val, axis_val))
#Example 3
a = theano.tensor.dvector()
w2 = argsort(a)
f = theano.function([a], w2)
assert numpy.allclose(f(v_val), numpy.argsort(v_val))
#Example 4
a = theano.tensor.dmatrix()
axis = theano.tensor.scalar()
l = argsort(a, axis, "mergesort")
f = theano.function([a, axis], l)
for axis_val in 0, 1:
assert numpy.allclose(
f(m_val, axis_val),
numpy.argsort(m_val, axis_val))
#Example 5
a = theano.tensor.dmatrix()
axis = theano.tensor.scalar()
a1 = ArgSortOp("mergesort", [])
a2 = ArgSortOp("quicksort", [])
#All the below should give true
assert a1 != a2
assert a1 == ArgSortOp("mergesort", [])
assert a2 == ArgSortOp("quicksort", [])
#Example 6: Testing axis=None
a = theano.tensor.dmatrix()
w2 = argsort(a, None)
f = theano.function([a], w2)
assert numpy.allclose(f(m_val), numpy.argsort(m_val, None))
if __name__ == '__main__': if __name__ == '__main__':
if 0: if 0:
unittest.main() unittest.main()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论