提交 375d0c61 authored 作者: Frederic Bastien's avatar Frederic Bastien

Make MaxAndArgmax support float16

上级 f6437fd6
......@@ -1219,6 +1219,7 @@ class MaxAndArgmax(Op):
E_axis = 'invalid axis'
params_type = Generic()
__props__ = ('axis',)
_f16_ok = True
def __init__(self, axis):
assert isinstance(axis, list)
......
......@@ -3091,6 +3091,21 @@ class T_max_and_argmax(unittest.TestCase):
v_shape = eval_outputs(max_and_argmax(n, axis)[0].shape)
assert tuple(v_shape) == np.max(data, np_axis).shape
def test2_float16(self):
data = rand(2, 3).astype("float16")
n = shared(data)
mode = get_default_mode().including("local_max_and_argmax", "uncanonicalize")
for (axis, np_axis) in [(-1, -1), (0, 0), (1, 1), (None, None),
([0, 1], None), ([1, 0], None),
(NoneConst.clone(), None),
(constant(0), 0)]:
v, i = eval_outputs(max_and_argmax(n, axis), (MaxAndArgmax,))
assert i.dtype == 'int64'
self.assertTrue(np.all(v == np.max(data, np_axis)))
self.assertTrue(np.all(i == np.argmax(data, np_axis)))
v_shape = eval_outputs(max_and_argmax(n, axis)[0].shape)
assert tuple(v_shape) == np.max(data, np_axis).shape
def test2_invalid(self):
n = as_tensor_variable(rand(2, 3))
# Silence expected error messages
......@@ -3315,7 +3330,7 @@ class T_argmin_argmax(unittest.TestCase):
([0, 1], None), ([1, 0], None)]:
v = eval_outputs(fct(n, axis), (Argmax,), mode=mode)
self.assertTrue(np.all(v == nfct(data, np_axis)))
v_shape = eval_outputs(fct(n, axis).shape)
v_shape = eval_outputs(fct(n, axis).shape, mode=mode)
assert tuple(v_shape) == nfct(data, np_axis).shape
def test2_invalid(self):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论