提交 3d0164d5 authored 作者: Hani's avatar Hani

Argsort numpy wrapping

上级 00f26d88
...@@ -5825,3 +5825,62 @@ def sort(a, axis=-1, kind='quicksort', order=None): ...@@ -5825,3 +5825,62 @@ def sort(a, axis=-1, kind='quicksort', order=None):
""" """
return SortOp(kind, order)(a, axis) return SortOp(kind, order)(a, axis)
class ArgSortOp(theano.Op):
"""
This class is a wrapper for numpy argsort function
"""
def __init__(self, kind, order=None):
self.kind = kind
self.order = order
def __eq__(self, other):
return type(self) == type(other) and self.order == other.order and self.kind==other.kind
def __hash__(self):
return hash(type(self)) ^ hash(self.order) ^ hash(self.kind)
def __str__(self):
return self.__class__.__name__ + "{%s, %s}" % (self.kind, str(self.order))
def make_node(self, input, axis=-1):
if axis is None:
raise ValueError("Current Implementation does not sipport axis=None")
return
input = theano.tensor.as_tensor_variable(input)
axis = theano.tensor.as_tensor_variable(axis)
return theano.Apply(self, [input, axis],
[theano.tensor.TensorType(dtype="int64", broadcastable=input.type.broadcastable)()])
def perform(self, node, inputs, output_storage):
a = inputs[0]
axis = inputs[1]
z = output_storage[0]
z[0] = numpy.argsort(a,axis,self.kind,self.order)
def infer_shape(self, node, inputs_shapes):
return [inputs_shapes[0]]
#**** No grad defined for intergers.
def grad(self, inputs, output_grads):
return [None, None]
"""
def R_op(self, inputs, eval_points):
# R_op can receive None as eval_points.
# That mean there is no diferientiable path through that input
# If this imply that you cannot compute some outputs,
# return None for those.
if eval_points[0] is None:
return eval_points
return self.grad(inputs, eval_points)
"""
def argSort(a, axis=-1, kind='quicksort', order=None):
"""
Returns the indices that would sort an array.
Perform an indirect sort along the given axis using the algorithm specified by the kind keyword.
It returns an array of indices of the same shape as a that index data along the given axis in sorted order.
"""
return ArgSortOp(kind, order)(a, axis)
\ No newline at end of file
...@@ -34,7 +34,7 @@ from theano.tensor import (_shared, wvector, bvector, autocast_float_as, ...@@ -34,7 +34,7 @@ from theano.tensor import (_shared, wvector, bvector, autocast_float_as,
get_constant_value, ivector, reshape, scalar_from_tensor, scal, get_constant_value, ivector, reshape, scalar_from_tensor, scal,
iscalars, arange, dscalars, fvector, imatrix, numeric_grad, iscalars, arange, dscalars, fvector, imatrix, numeric_grad,
opt, ComplexError, TensorDot, lvector, true_div, max, min, Split, roll, opt, ComplexError, TensorDot, lvector, true_div, max, min, Split, roll,
tile, patternbroadcast, sort, SortOp, ) tile, patternbroadcast, sort, SortOp, argSort, ArgSortOp,)
from theano.tests import unittest_tools as utt from theano.tests import unittest_tools as utt
...@@ -5639,6 +5639,56 @@ def test_sort(): ...@@ -5639,6 +5639,56 @@ def test_sort():
else: else:
assert False assert False
def test_argsort():
#Set up
rng = numpy.random.RandomState(seed=utt.fetch_seed())
m_val = rng.rand(3, 2)
v_val = rng.rand(4)
#Example 1
a = theano.tensor.dmatrix()
w = argSort(a)
f = theano.function([a], w)
assert numpy.allclose(f(m_val), numpy.argsort(m_val))
#Example 2
a = theano.tensor.dmatrix()
axis = theano.tensor.scalar()
w = argSort(a, axis)
f = theano.function([a, axis], w)
for axis_val in 0, 1:
assert numpy.allclose(
f(m_val, axis_val),
numpy.argsort(m_val, axis_val))
#Example 3
a = theano.tensor.dvector()
w2 = argSort(a)
f = theano.function([a], w2)
assert numpy.allclose(f(v_val), numpy.argsort(v_val))
#Example 4
a = theano.tensor.dmatrix()
axis = theano.tensor.scalar()
l = argSort(a, axis, "mergesort")
f = theano.function([a, axis], l)
for axis_val in 0, 1:
assert numpy.allclose(
f(m_val, axis_val),
numpy.argsort(m_val, axis_val))
#Example 5
a = theano.tensor.dmatrix()
axis = theano.tensor.scalar()
a1 = ArgSortOp("mergesort", [])
a2 = ArgSortOp("quicksort", [])
#All the below should give true
assert a1 != a2
assert a1 == ArgSortOp("mergesort", [])
assert a2 == ArgSortOp("quicksort", [])
if __name__ == '__main__': if __name__ == '__main__':
if 0: if 0:
unittest.main() unittest.main()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论