提交 c6390da8 authored 作者: Hani's avatar Hani

Applying changes following reviewers.

上级 3d0164d5
......@@ -5832,11 +5832,13 @@ class ArgSortOp(theano.Op):
This class is a wrapper for numpy argsort function
"""
def __init__(self, kind, order=None):
self.kind = kind
self.order = order
self.kind = kind
self.order = order
def __eq__(self, other):
return type(self) == type(other) and self.order == other.order and self.kind==other.kind
return (type(self) == type(other) and
self.order == other.order and
self.kind == other.kind)
def __hash__(self):
return hash(type(self)) ^ hash(self.order) ^ hash(self.kind)
......@@ -5845,11 +5847,11 @@ class ArgSortOp(theano.Op):
return self.__class__.__name__ + "{%s, %s}" % (self.kind, str(self.order))
def make_node(self, input, axis=-1):
if axis is None:
raise ValueError("Current Implementation does not sipport axis=None")
return
input = theano.tensor.as_tensor_variable(input)
axis = theano.tensor.as_tensor_variable(axis)
if axis is None:
axis = Constant(gof.generic, None)
else:
axis = theano.tensor.as_tensor_variable(axis)
return theano.Apply(self, [input, axis],
[theano.tensor.TensorType(dtype="int64", broadcastable=input.type.broadcastable)()])
......@@ -5857,13 +5859,13 @@ class ArgSortOp(theano.Op):
a = inputs[0]
axis = inputs[1]
z = output_storage[0]
z[0] = numpy.argsort(a,axis,self.kind,self.order)
z[0] = numpy.argsort(a, axis, self.kind, self.order)
def infer_shape(self, node, inputs_shapes):
return [inputs_shapes[0]]
#**** No grad defined for intergers.
def grad(self, inputs, output_grads):
#No grad defined for intergers.
return [None, None]
"""
def R_op(self, inputs, eval_points):
......@@ -5877,10 +5879,11 @@ class ArgSortOp(theano.Op):
"""
def argSort(a, axis=-1, kind='quicksort', order=None):
def argsort(a, axis=-1, kind='quicksort', order=None):
"""
Returns the indices that would sort an array.
Perform an indirect sort along the given axis using the algorithm specified by the kind keyword.
It returns an array of indices of the same shape as a that index data along the given axis in sorted order.
"""
return ArgSortOp(kind, order)(a, axis)
\ No newline at end of file
return ArgSortOp(kind, order)(a, axis)
\ No newline at end of file
......@@ -34,7 +34,7 @@ from theano.tensor import (_shared, wvector, bvector, autocast_float_as,
get_constant_value, ivector, reshape, scalar_from_tensor, scal,
iscalars, arange, dscalars, fvector, imatrix, numeric_grad,
opt, ComplexError, TensorDot, lvector, true_div, max, min, Split, roll,
tile, patternbroadcast, sort, SortOp, argSort, ArgSortOp,)
tile, patternbroadcast, sort, SortOp, argsort, ArgSortOp,)
from theano.tests import unittest_tools as utt
......@@ -5648,14 +5648,14 @@ def test_argsort():
#Example 1
a = theano.tensor.dmatrix()
w = argSort(a)
w = argsort(a)
f = theano.function([a], w)
assert numpy.allclose(f(m_val), numpy.argsort(m_val))
#Example 2
a = theano.tensor.dmatrix()
axis = theano.tensor.scalar()
w = argSort(a, axis)
w = argsort(a, axis)
f = theano.function([a, axis], w)
for axis_val in 0, 1:
assert numpy.allclose(
......@@ -5664,14 +5664,14 @@ def test_argsort():
#Example 3
a = theano.tensor.dvector()
w2 = argSort(a)
w2 = argsort(a)
f = theano.function([a], w2)
assert numpy.allclose(f(v_val), numpy.argsort(v_val))
#Example 4
a = theano.tensor.dmatrix()
axis = theano.tensor.scalar()
l = argSort(a, axis, "mergesort")
l = argsort(a, axis, "mergesort")
f = theano.function([a, axis], l)
for axis_val in 0, 1:
assert numpy.allclose(
......@@ -5688,6 +5688,12 @@ def test_argsort():
assert a1 == ArgSortOp("mergesort", [])
assert a2 == ArgSortOp("quicksort", [])
#Example 6: Testing axis=None
a = theano.tensor.dmatrix()
w2 = argsort(a, None)
f = theano.function([a], w2)
assert numpy.allclose(f(m_val), numpy.argsort(m_val, None))
if __name__ == '__main__':
if 0:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论