提交 11f6400a authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Fixes to argsort when axis is None.

上级 d596f2f0
......@@ -6369,10 +6369,12 @@ class ArgSortOp(theano.Op):
input = theano.tensor.as_tensor_variable(input)
if axis is None:
axis = Constant(gof.generic, None)
bcast = [False]
else:
axis = theano.tensor.as_tensor_variable(axis)
bcast = input.type.broadcastable
return theano.Apply(self, [input, axis],
[theano.tensor.TensorType(dtype="int64", broadcastable=input.type.broadcastable)()])
[theano.tensor.TensorType(dtype="int64", broadcastable=bcast)()])
def perform(self, node, inputs, output_storage):
a = inputs[0]
......@@ -6381,6 +6383,13 @@ class ArgSortOp(theano.Op):
z[0] = numpy.argsort(a, axis, self.kind, self.order)
def infer_shape(self, node, inputs_shapes):
if (isinstance(node.inputs[1], Constant) and
node.inputs[1].data is None):
return [(mul(*inputs_shapes[0]),)]
# axis should not be None, so there should be the same number of
# dimensions in the input and output
assert node.inputs[0].ndim == node.outputs[0].ndim
assert inputs_shapes[1] is ()
return [inputs_shapes[0]]
def grad(self, inputs, output_grads):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论