提交 9aa725a9 authored 作者: Olivier Delalleau's avatar Olivier Delalleau

Merge pull request #596 from lamblin/fix_argsort_none

Fixes to argsort when axis is None.
......@@ -6305,7 +6305,7 @@ class SortOp(theano.Op):
# So there should be the same number of dimensions
# in the input and output
assert node.inputs[0].ndim == node.outputs[0].ndim
assert inputs_shapes[1] is ()
assert inputs_shapes[1] == ()
return [inputs_shapes[0]]
#**** It needs the argsort, so we can't do it now.
......@@ -6363,16 +6363,19 @@ class ArgSortOp(theano.Op):
return hash(type(self)) ^ hash(self.order) ^ hash(self.kind)
def __str__(self):
    """Return the class name followed by the op's ``{kind, order}`` settings."""
    # A single return statement: an earlier one-line form of the same
    # expression made this wrapped form unreachable.
    return (self.__class__.__name__
            + "{%s, %s}" % (self.kind, str(self.order)))
def make_node(self, input, axis=-1):
    """Build the Apply node computing the argsort of ``input`` along ``axis``.

    :param input: value to sort; converted with ``as_tensor_variable``.
    :param axis: axis to sort along. ``None`` is wrapped in a generic
        ``Constant`` instead of being converted to a tensor variable.
    :return: an Apply node with inputs ``[input, axis]`` and one int64
        tensor output.
    """
    input = theano.tensor.as_tensor_variable(input)
    if axis is None:
        axis = Constant(gof.generic, None)
        # With axis=None the output has a single, non-broadcastable
        # dimension, whatever the number of dimensions of the input.
        bcast = [False]
    else:
        axis = theano.tensor.as_tensor_variable(axis)
        # Otherwise the output keeps the broadcast pattern of the input.
        bcast = input.type.broadcastable
    # The output type must be built from `bcast`, not directly from
    # input.type.broadcastable, so that the axis=None case is typed as 1-d.
    return theano.Apply(self, [input, axis],
        [theano.tensor.TensorType(dtype="int64", broadcastable=bcast)()])
def perform(self, node, inputs, output_storage):
a = inputs[0]
......@@ -6381,6 +6384,13 @@ class ArgSortOp(theano.Op):
z[0] = numpy.argsort(a, axis, self.kind, self.order)
def infer_shape(self, node, inputs_shapes):
    """Return the shape of the output, as a one-element list.

    :param node: the Apply node; ``node.inputs[1]`` is the axis variable.
    :param inputs_shapes: shapes of the node's inputs, in order.
    """
    axis_var = node.inputs[1]
    axis_is_none = isinstance(axis_var, Constant) and axis_var.data is None
    if axis_is_none:
        # The output is a vector whose length is the product of all the
        # dimensions of the input.
        flat_length = mul(*inputs_shapes[0])
        return [(flat_length,)]

    # From here on the axis is not None, so input and output must have
    # the same number of dimensions, and the axis itself is a scalar.
    assert node.inputs[0].ndim == node.outputs[0].ndim
    assert inputs_shapes[1] == ()
    return [inputs_shapes[0]]
def grad(self, inputs, output_grads):
......@@ -6401,8 +6411,10 @@ class ArgSortOp(theano.Op):
def argsort(a, axis=-1, kind='quicksort', order=None):
    """
    Returns the indices that would sort an array.

    Perform an indirect sort along the given axis using the algorithm
    specified by the kind keyword. It returns an array of indices of
    the same shape as a that index data along the given axis in sorted
    order.

    When axis is None, the result is instead a 1-d vector of indices
    with as many elements as a.

    :param a: the value to sort.
    :param axis: axis along which to sort, or None (default: -1).
    :param kind: sorting algorithm, forwarded to ArgSortOp.
    :param order: field order, forwarded to ArgSortOp.
    """
    return ArgSortOp(kind, order)(a, axis)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论