提交 346969b3 authored 作者: Frederic's avatar Frederic

Make the MaxAndArgmax op work with theano constant.

上级 92e2e0a0
......@@ -2183,6 +2183,11 @@ class MaxAndArgmax(Op):
def make_node(self, x, axis=None):
x = _as_tensor_variable(x)
if isinstance(axis, Variable):
if not isinstance(axis, Constant):
raise TypeError("MaxAndArgmax need a constant axis")
axis = [axis.data]
if isinstance(axis, int):
axis = [axis]
elif isinstance(axis, (tuple, list)):
......@@ -2192,6 +2197,7 @@ class MaxAndArgmax(Op):
assert axis == range(x.type.ndim), (
"MaxAndArgmax does not support multiple"
" axes. the max fct supports it.")
# we make the axis all positive to make the infer_shape work
# with negative axis
if x.type.ndim > 0 and axis is not None:
......
......@@ -1790,7 +1790,8 @@ class T_max_and_argmax(unittest.TestCase):
data = rand(2, 3)
n = as_tensor_variable(data)
for (axis, np_axis) in [(-1, -1), (0, 0), (1, 1), (None, None),
([0, 1], None), ([1, 0], None)]:
([0, 1], None), ([1, 0], None),
(constant(0), 0)]:
v, i = eval_outputs(max_and_argmax(n, axis))
assert i.dtype == 'int64'
self.assertTrue(numpy.all(v == numpy.max(data, np_axis)))
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论