提交 140e2f04 authored 作者: Frederic's avatar Frederic

[CRASH] Make MaxAndArgmax support NoneConst as axis.

上级 257a7a31
......@@ -1238,7 +1238,7 @@ class MaxAndArgmax(Op):
axis = range(x.type.ndim)
assert axis == range(x.type.ndim), (
"MaxAndArgmax does not support multiple"
" axes. the max fct supports it.")
" axes. the max fct supports it. Got %s" % axis)
axis = None
else:
axis = axis[0]
......@@ -1248,11 +1248,15 @@ class MaxAndArgmax(Op):
elif isinstance(axis, numpy.ndarray) and axis.ndim == 0:
axis = int(axis)
elif isinstance(axis, Variable):
if not isinstance(axis, TensorConstant):
raise TypeError("MaxAndArgmax needs a constant axis")
assert (axis.dtype.startswith("int")
or axis.dtype.startswith("uint"))
axis = int(axis.data)
if NoneConst.equals(axis):
axis = None
elif not isinstance(axis, TensorConstant):
raise TypeError(
"MaxAndArgmax needs a constant axis. Got %s" % axis)
else:
assert (axis.dtype.startswith("int")
or axis.dtype.startswith("uint"))
axis = int(axis.data)
# we make the axis all positive to make the infer_shape work
# with negative axis
if x.type.ndim > 0 and axis is not None:
......
......@@ -45,7 +45,7 @@ from theano.tensor import (_shared, wvector, bvector, autocast_float_as,
itensor3, Tile, switch, Diagonal, Diag,
nonzero, flatnonzero, nonzero_values,
stacklists, DimShuffle, hessian, ptp, power,
swapaxes, choose, Choose
swapaxes, choose, Choose, NoneConst,
)
from theano.tests import unittest_tools as utt
......@@ -2722,6 +2722,7 @@ class T_max_and_argmax(unittest.TestCase):
n = as_tensor_variable(data)
for (axis, np_axis) in [(-1, -1), (0, 0), (1, 1), (None, None),
([0, 1], None), ([1, 0], None),
(NoneConst.clone(), None),
(constant(0), 0)]:
v, i = eval_outputs(max_and_argmax(n, axis))
assert i.dtype == 'int64'
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论