提交 9aa725a9 authored 作者: Olivier Delalleau's avatar Olivier Delalleau

Merge pull request #596 from lamblin/fix_argsort_none

Fixes to argsort when axis is None.
...@@ -6305,7 +6305,7 @@ class SortOp(theano.Op): ...@@ -6305,7 +6305,7 @@ class SortOp(theano.Op):
# So there should be the same number of dimensions # So there should be the same number of dimensions
# in the input and output # in the input and output
assert node.inputs[0].ndim == node.outputs[0].ndim assert node.inputs[0].ndim == node.outputs[0].ndim
assert inputs_shapes[1] is () assert inputs_shapes[1] == ()
return [inputs_shapes[0]] return [inputs_shapes[0]]
#**** It need the argsort, so we can't do it now. #**** It need the argsort, so we can't do it now.
...@@ -6363,16 +6363,19 @@ class ArgSortOp(theano.Op): ...@@ -6363,16 +6363,19 @@ class ArgSortOp(theano.Op):
return hash(type(self)) ^ hash(self.order) ^ hash(self.kind) return hash(type(self)) ^ hash(self.order) ^ hash(self.kind)
def __str__(self): def __str__(self):
return self.__class__.__name__ + "{%s, %s}" % (self.kind, str(self.order)) return (self.__class__.__name__
+ "{%s, %s}" % (self.kind, str(self.order)))
def make_node(self, input, axis=-1): def make_node(self, input, axis=-1):
input = theano.tensor.as_tensor_variable(input) input = theano.tensor.as_tensor_variable(input)
if axis is None: if axis is None:
axis = Constant(gof.generic, None) axis = Constant(gof.generic, None)
bcast = [False]
else: else:
axis = theano.tensor.as_tensor_variable(axis) axis = theano.tensor.as_tensor_variable(axis)
bcast = input.type.broadcastable
return theano.Apply(self, [input, axis], return theano.Apply(self, [input, axis],
[theano.tensor.TensorType(dtype="int64", broadcastable=input.type.broadcastable)()]) [theano.tensor.TensorType(dtype="int64", broadcastable=bcast)()])
def perform(self, node, inputs, output_storage): def perform(self, node, inputs, output_storage):
a = inputs[0] a = inputs[0]
...@@ -6381,6 +6384,13 @@ class ArgSortOp(theano.Op): ...@@ -6381,6 +6384,13 @@ class ArgSortOp(theano.Op):
z[0] = numpy.argsort(a, axis, self.kind, self.order) z[0] = numpy.argsort(a, axis, self.kind, self.order)
def infer_shape(self, node, inputs_shapes): def infer_shape(self, node, inputs_shapes):
if (isinstance(node.inputs[1], Constant) and
node.inputs[1].data is None):
return [(mul(*inputs_shapes[0]),)]
# axis should not be None, so there should be the same number of
# dimensions in the input and output
assert node.inputs[0].ndim == node.outputs[0].ndim
assert inputs_shapes[1] == ()
return [inputs_shapes[0]] return [inputs_shapes[0]]
def grad(self, inputs, output_grads): def grad(self, inputs, output_grads):
...@@ -6401,8 +6411,10 @@ class ArgSortOp(theano.Op): ...@@ -6401,8 +6411,10 @@ class ArgSortOp(theano.Op):
def argsort(a, axis=-1, kind='quicksort', order=None): def argsort(a, axis=-1, kind='quicksort', order=None):
""" """
Returns the indices that would sort an array. Returns the indices that would sort an array.
Perform an indirect sort along the given axis using the algorithm specified by the kind keyword.
It returns an array of indices of the same shape as a that index data along the given axis in sorted order. Perform an indirect sort along the given axis using the algorithm
""" specified by the kind keyword. It returns an array of indices of
the same shape as a that index data along the given axis in sorted
order.
"""
return ArgSortOp(kind, order)(a, axis) return ArgSortOp(kind, order)(a, axis)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论