提交 d1c8ad7e authored 作者: Frederic's avatar Frederic

Make sure the axis is None in the Sort.infer_shape.

上级 aae8382f
...@@ -5807,10 +5807,16 @@ class SortOp(theano.Op): ...@@ -5807,10 +5807,16 @@ class SortOp(theano.Op):
z[0] = numpy.sort(a, axis, self.kind, self.order) z[0] = numpy.sort(a, axis, self.kind, self.order)
def infer_shape(self, node, inputs_shapes): def infer_shape(self, node, inputs_shapes):
if inputs_shapes[1] is None: if (isinstance(node.inputs[1], Constant) and
# That probably means axis = None, node.inputs[1].data is None):
# so the array is flattened before being sorted # That means axis = None,
# So the array is flattened before being sorted
return [(mul(*inputs_shapes[0]),)] return [(mul(*inputs_shapes[0]),)]
# axis should not be None
# So there should be the same number of dimensions
# in the input and output
assert node.inputs[0].ndim == node.outputs[0].ndim
assert inputs_shapes[1] is ()
return [inputs_shapes[0]] return [inputs_shapes[0]]
#**** It need the argsort, so we can't do it now. #**** It need the argsort, so we can't do it now.
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论