提交 140e2f04 authored 作者: Frederic's avatar Frederic

[CRASH] Make MaxAndArgmax support NoneConst as axis.

上级 257a7a31
...@@ -1238,7 +1238,7 @@ class MaxAndArgmax(Op): ...@@ -1238,7 +1238,7 @@ class MaxAndArgmax(Op):
axis = range(x.type.ndim) axis = range(x.type.ndim)
assert axis == range(x.type.ndim), ( assert axis == range(x.type.ndim), (
"MaxAndArgmax does not support multiple" "MaxAndArgmax does not support multiple"
" axes. the max fct supports it.") " axes. the max fct supports it. Got %s" % axis)
axis = None axis = None
else: else:
axis = axis[0] axis = axis[0]
...@@ -1248,8 +1248,12 @@ class MaxAndArgmax(Op): ...@@ -1248,8 +1248,12 @@ class MaxAndArgmax(Op):
elif isinstance(axis, numpy.ndarray) and axis.ndim == 0: elif isinstance(axis, numpy.ndarray) and axis.ndim == 0:
axis = int(axis) axis = int(axis)
elif isinstance(axis, Variable): elif isinstance(axis, Variable):
if not isinstance(axis, TensorConstant): if NoneConst.equals(axis):
raise TypeError("MaxAndArgmax needs a constant axis") axis = None
elif not isinstance(axis, TensorConstant):
raise TypeError(
"MaxAndArgmax needs a constant axis. Got %s" % axis)
else:
assert (axis.dtype.startswith("int") assert (axis.dtype.startswith("int")
or axis.dtype.startswith("uint")) or axis.dtype.startswith("uint"))
axis = int(axis.data) axis = int(axis.data)
......
...@@ -45,7 +45,7 @@ from theano.tensor import (_shared, wvector, bvector, autocast_float_as, ...@@ -45,7 +45,7 @@ from theano.tensor import (_shared, wvector, bvector, autocast_float_as,
itensor3, Tile, switch, Diagonal, Diag, itensor3, Tile, switch, Diagonal, Diag,
nonzero, flatnonzero, nonzero_values, nonzero, flatnonzero, nonzero_values,
stacklists, DimShuffle, hessian, ptp, power, stacklists, DimShuffle, hessian, ptp, power,
swapaxes, choose, Choose swapaxes, choose, Choose, NoneConst,
) )
from theano.tests import unittest_tools as utt from theano.tests import unittest_tools as utt
...@@ -2722,6 +2722,7 @@ class T_max_and_argmax(unittest.TestCase): ...@@ -2722,6 +2722,7 @@ class T_max_and_argmax(unittest.TestCase):
n = as_tensor_variable(data) n = as_tensor_variable(data)
for (axis, np_axis) in [(-1, -1), (0, 0), (1, 1), (None, None), for (axis, np_axis) in [(-1, -1), (0, 0), (1, 1), (None, None),
([0, 1], None), ([1, 0], None), ([0, 1], None), ([1, 0], None),
(NoneConst.clone(), None),
(constant(0), 0)]: (constant(0), 0)]:
v, i = eval_outputs(max_and_argmax(n, axis)) v, i = eval_outputs(max_and_argmax(n, axis))
assert i.dtype == 'int64' assert i.dtype == 'int64'
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论