提交 95a63ec7 authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Enables axis=None in tensor.sort

上级 8dec19eb
......@@ -5790,12 +5790,15 @@ class SortOp(theano.Op):
str(self.order))
def make_node(self, input, axis=-1):
if axis is None:
raise ValueError("Current Implementation does not support"
" axis=None")
input = theano.tensor.as_tensor_variable(input)
axis = theano.tensor.as_tensor_variable(axis)
return theano.Apply(self, [input, axis], [input.type()])
if axis is None:
axis = Constant(gof.generic, None)
# axis=None flattens the array before sorting
out_type = tensor(dtype=input.dtype, broadcastable=[False])
else:
axis = theano.tensor.as_tensor_variable(axis)
out_type = input.type()
return theano.Apply(self, [input, axis], [out_type])
def perform(self, node, inputs, output_storage):
a = inputs[0]
......@@ -5804,6 +5807,10 @@ class SortOp(theano.Op):
z[0] = numpy.sort(a, axis, self.kind, self.order)
def infer_shape(self, node, inputs_shapes):
if inputs_shapes[1] is None:
# That probably means axis = None,
# so the array is flattened before being sorted
return [(mul(*inputs_shapes[0]),)]
return [inputs_shapes[0]]
#**** It need the argsort, so we can't do it now.
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论