提交 2c69de91 authored 作者: Frédéric Bastien's avatar Frédéric Bastien 提交者: GitHub

Merge pull request #5232 from ferrine/max_and_argmax_fail

noticed that max_and_argmax fails, test + fix
......@@ -1602,6 +1602,7 @@ def max_and_argmax(a, axis=None, keepdims=False):
"""
# Check axis and convert it to a Python list of integers.
# Axis will be used as an op param of MaxAndArgmax.
a = as_tensor_variable(a)
if axis is None:
axis = list(range(a.type.ndim))
elif (isinstance(axis, (integer_types, numpy.integer)) or
......
......@@ -3091,6 +3091,12 @@ class T_max_and_argmax(unittest.TestCase):
assert mv.shape == (0,)
assert iv.shape == (0,)
def test_numpy_input(self):
ar = numpy.array([1, 2, 3])
max, argmax = max_and_argmax(ar, axis=None)
self.assertEqual(max.eval(), 3)
self.assertEqual(argmax.eval(), 2)
class T_argmin_argmax(unittest.TestCase):
def setUp(self):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论