提交 c6390da8 authored 作者: Hani's avatar Hani

Applying changes following reviewers.

上级 3d0164d5
...@@ -5832,11 +5832,13 @@ class ArgSortOp(theano.Op): ...@@ -5832,11 +5832,13 @@ class ArgSortOp(theano.Op):
This class is a wrapper for numpy argsort function This class is a wrapper for numpy argsort function
""" """
def __init__(self, kind, order=None): def __init__(self, kind, order=None):
self.kind = kind self.kind = kind
self.order = order self.order = order
def __eq__(self, other): def __eq__(self, other):
return type(self) == type(other) and self.order == other.order and self.kind==other.kind return (type(self) == type(other) and
self.order == other.order and
self.kind == other.kind)
def __hash__(self): def __hash__(self):
return hash(type(self)) ^ hash(self.order) ^ hash(self.kind) return hash(type(self)) ^ hash(self.order) ^ hash(self.kind)
...@@ -5845,11 +5847,11 @@ class ArgSortOp(theano.Op): ...@@ -5845,11 +5847,11 @@ class ArgSortOp(theano.Op):
return self.__class__.__name__ + "{%s, %s}" % (self.kind, str(self.order)) return self.__class__.__name__ + "{%s, %s}" % (self.kind, str(self.order))
def make_node(self, input, axis=-1): def make_node(self, input, axis=-1):
if axis is None:
raise ValueError("Current Implementation does not sipport axis=None")
return
input = theano.tensor.as_tensor_variable(input) input = theano.tensor.as_tensor_variable(input)
axis = theano.tensor.as_tensor_variable(axis) if axis is None:
axis = Constant(gof.generic, None)
else:
axis = theano.tensor.as_tensor_variable(axis)
return theano.Apply(self, [input, axis], return theano.Apply(self, [input, axis],
[theano.tensor.TensorType(dtype="int64", broadcastable=input.type.broadcastable)()]) [theano.tensor.TensorType(dtype="int64", broadcastable=input.type.broadcastable)()])
...@@ -5857,13 +5859,13 @@ class ArgSortOp(theano.Op): ...@@ -5857,13 +5859,13 @@ class ArgSortOp(theano.Op):
a = inputs[0] a = inputs[0]
axis = inputs[1] axis = inputs[1]
z = output_storage[0] z = output_storage[0]
z[0] = numpy.argsort(a,axis,self.kind,self.order) z[0] = numpy.argsort(a, axis, self.kind, self.order)
def infer_shape(self, node, inputs_shapes): def infer_shape(self, node, inputs_shapes):
return [inputs_shapes[0]] return [inputs_shapes[0]]
#**** No grad defined for intergers.
def grad(self, inputs, output_grads): def grad(self, inputs, output_grads):
#No grad defined for intergers.
return [None, None] return [None, None]
""" """
def R_op(self, inputs, eval_points): def R_op(self, inputs, eval_points):
...@@ -5877,10 +5879,11 @@ class ArgSortOp(theano.Op): ...@@ -5877,10 +5879,11 @@ class ArgSortOp(theano.Op):
""" """
def argSort(a, axis=-1, kind='quicksort', order=None): def argsort(a, axis=-1, kind='quicksort', order=None):
""" """
Returns the indices that would sort an array. Returns the indices that would sort an array.
Perform an indirect sort along the given axis using the algorithm specified by the kind keyword. Perform an indirect sort along the given axis using the algorithm specified by the kind keyword.
It returns an array of indices of the same shape as a that index data along the given axis in sorted order. It returns an array of indices of the same shape as a that index data along the given axis in sorted order.
""" """
return ArgSortOp(kind, order)(a, axis) return ArgSortOp(kind, order)(a, axis)
\ No newline at end of file
\ No newline at end of file
...@@ -34,7 +34,7 @@ from theano.tensor import (_shared, wvector, bvector, autocast_float_as, ...@@ -34,7 +34,7 @@ from theano.tensor import (_shared, wvector, bvector, autocast_float_as,
get_constant_value, ivector, reshape, scalar_from_tensor, scal, get_constant_value, ivector, reshape, scalar_from_tensor, scal,
iscalars, arange, dscalars, fvector, imatrix, numeric_grad, iscalars, arange, dscalars, fvector, imatrix, numeric_grad,
opt, ComplexError, TensorDot, lvector, true_div, max, min, Split, roll, opt, ComplexError, TensorDot, lvector, true_div, max, min, Split, roll,
tile, patternbroadcast, sort, SortOp, argSort, ArgSortOp,) tile, patternbroadcast, sort, SortOp, argsort, ArgSortOp,)
from theano.tests import unittest_tools as utt from theano.tests import unittest_tools as utt
...@@ -5648,14 +5648,14 @@ def test_argsort(): ...@@ -5648,14 +5648,14 @@ def test_argsort():
#Example 1 #Example 1
a = theano.tensor.dmatrix() a = theano.tensor.dmatrix()
w = argSort(a) w = argsort(a)
f = theano.function([a], w) f = theano.function([a], w)
assert numpy.allclose(f(m_val), numpy.argsort(m_val)) assert numpy.allclose(f(m_val), numpy.argsort(m_val))
#Example 2 #Example 2
a = theano.tensor.dmatrix() a = theano.tensor.dmatrix()
axis = theano.tensor.scalar() axis = theano.tensor.scalar()
w = argSort(a, axis) w = argsort(a, axis)
f = theano.function([a, axis], w) f = theano.function([a, axis], w)
for axis_val in 0, 1: for axis_val in 0, 1:
assert numpy.allclose( assert numpy.allclose(
...@@ -5664,14 +5664,14 @@ def test_argsort(): ...@@ -5664,14 +5664,14 @@ def test_argsort():
#Example 3 #Example 3
a = theano.tensor.dvector() a = theano.tensor.dvector()
w2 = argSort(a) w2 = argsort(a)
f = theano.function([a], w2) f = theano.function([a], w2)
assert numpy.allclose(f(v_val), numpy.argsort(v_val)) assert numpy.allclose(f(v_val), numpy.argsort(v_val))
#Example 4 #Example 4
a = theano.tensor.dmatrix() a = theano.tensor.dmatrix()
axis = theano.tensor.scalar() axis = theano.tensor.scalar()
l = argSort(a, axis, "mergesort") l = argsort(a, axis, "mergesort")
f = theano.function([a, axis], l) f = theano.function([a, axis], l)
for axis_val in 0, 1: for axis_val in 0, 1:
assert numpy.allclose( assert numpy.allclose(
...@@ -5688,6 +5688,12 @@ def test_argsort(): ...@@ -5688,6 +5688,12 @@ def test_argsort():
assert a1 == ArgSortOp("mergesort", []) assert a1 == ArgSortOp("mergesort", [])
assert a2 == ArgSortOp("quicksort", []) assert a2 == ArgSortOp("quicksort", [])
#Example 6: Testing axis=None
a = theano.tensor.dmatrix()
w2 = argsort(a, None)
f = theano.function([a], w2)
assert numpy.allclose(f(m_val), numpy.argsort(m_val, None))
if __name__ == '__main__': if __name__ == '__main__':
if 0: if 0:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论