提交 d8bf2677 authored 作者: Frederic's avatar Frederic

Fix crash introduced yesterday in MaxAndArgmax.

When we put a TensorVariable in a list and then call as_tensor_variable on it, this creates a 2d tensor. We do not want that.
上级 8317a572
......@@ -2183,10 +2183,6 @@ class MaxAndArgmax(Op):
def make_node(self, x, axis=None):
x = _as_tensor_variable(x)
if isinstance(axis, Variable):
if not isinstance(axis, Constant):
raise TypeError("MaxAndArgmax needs a constant axis")
axis = [axis.data]
if isinstance(axis, int):
axis = [axis]
......@@ -2197,7 +2193,12 @@ class MaxAndArgmax(Op):
assert axis == range(x.type.ndim), (
"MaxAndArgmax does not support multiple"
" axes. the max fct supports it.")
elif isinstance(axis, Variable):
if not isinstance(axis, TensorConstant):
raise TypeError("MaxAndArgmax needs a constant axis")
axis = axis.data
if axis.ndim == 0:
axis = [axis]
# we make the axis all positive to make the infer_shape work
# with negative axis
if x.type.ndim > 0 and axis is not None:
......@@ -2218,8 +2219,8 @@ class MaxAndArgmax(Op):
raise ValueError(
'Invalid axis: %s (the number of dimensions of the '
'input is: %s)' % (axis, x.type.ndim))
all_axes.add(ax)
all_axes.add(ax.item())
assert axis.ndim == 1
inputs = [x, axis]
# We keep the original broadcastable flags for dimensions on which
# we do not perform the max / argmax.
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论