提交 0d47e204 authored 作者: Frédéric Bastien's avatar Frédéric Bastien 提交者: GitHub

Merge pull request #6005 from naitonium/fix-min-function

Fix min function to work with uint* and bool
...@@ -1764,6 +1764,12 @@ def min(x, axis=None, keepdims=False): ...@@ -1764,6 +1764,12 @@ def min(x, axis=None, keepdims=False):
str_x_type = str(x.dtype) str_x_type = str(x.dtype)
if str_x_type.startswith('float') or str_x_type in int_dtypes: if str_x_type.startswith('float') or str_x_type in int_dtypes:
return -max(-x, axis=axis, keepdims=keepdims) return -max(-x, axis=axis, keepdims=keepdims)
elif str_x_type in uint_dtypes:
itype = np.iinfo(x.dtype)
max_val = np.array(itype.max, dtype=itype.dtype)
return max_val - max(max_val - x, axis=axis, keepdims=keepdims)
elif str_x_type == 'bool':
return ~max(~x, axis=axis, keepdims=keepdims)
else: else:
# Be careful about unsigned integers, complex # Be careful about unsigned integers, complex
raise NotImplementedError() raise NotImplementedError()
...@@ -1789,6 +1795,11 @@ def argmin(x, axis=None, keepdims=False): ...@@ -1789,6 +1795,11 @@ def argmin(x, axis=None, keepdims=False):
str_x_type = str(x.dtype) str_x_type = str(x.dtype)
if str_x_type.startswith('float') or str_x_type in int_dtypes: if str_x_type.startswith('float') or str_x_type in int_dtypes:
return argmax(-x, axis=axis, keepdims=keepdims) return argmax(-x, axis=axis, keepdims=keepdims)
elif str_x_type in uint_dtypes:
itype = np.iinfo(x.dtype)
return argmax(itype.max - x, axis=axis, keepdims=keepdims)
elif str_x_type == 'bool':
return argmax(~x, axis=axis, keepdims=keepdims)
else: else:
# Be careful about unsigned integers, complex # Be careful about unsigned integers, complex
raise NotImplementedError() raise NotImplementedError()
......
...@@ -3454,6 +3454,24 @@ class T_argmin_argmax(unittest.TestCase): ...@@ -3454,6 +3454,24 @@ class T_argmin_argmax(unittest.TestCase):
except TypeError: except TypeError:
pass pass
def test_uint(self):
for dtype in ('uint8', 'uint16', 'uint32', 'uint64'):
itype = np.iinfo(dtype)
data = np.array([itype.min + 3, itype.min, itype.max - 5, itype.max], dtype)
n = as_tensor_variable(data)
i = eval_outputs(argmin(n))
self.assertEqual(i, 1)
i = eval_outputs(argmax(n))
self.assertEqual(i, 3)
def test_bool(self):
data = np.array([True, False], 'bool')
n = as_tensor_variable(data)
i = eval_outputs(argmin(n))
self.assertEqual(i, 1)
i = eval_outputs(argmax(n))
self.assertEqual(i, 0)
class T_min_max(unittest.TestCase): class T_min_max(unittest.TestCase):
def setUp(self): def setUp(self):
...@@ -3639,6 +3657,28 @@ class T_min_max(unittest.TestCase): ...@@ -3639,6 +3657,28 @@ class T_min_max(unittest.TestCase):
# check_grad_max(data, eval_outputs(grad(max_and_argmax(n, # check_grad_max(data, eval_outputs(grad(max_and_argmax(n,
# axis=1)[0], n)),axis=1) # axis=1)[0], n)),axis=1)
def test_uint(self):
for dtype in ('uint8', 'uint16', 'uint32', 'uint64'):
itype = np.iinfo(dtype)
data = np.array([itype.min + 3, itype.min, itype.max - 5, itype.max], dtype)
n = as_tensor_variable(data)
self.assertEqual(min(n).dtype, dtype)
i = eval_outputs(min(n))
self.assertEqual(i, itype.min)
self.assertEqual(max(n).dtype, dtype)
i = eval_outputs(max(n))
self.assertEqual(i, itype.max)
def test_bool(self):
data = np.array([True, False], 'bool')
n = as_tensor_variable(data)
self.assertEqual(min(n).dtype, 'bool')
i = eval_outputs(min(n))
self.assertEqual(i, False)
self.assertEqual(max(n).dtype, 'bool')
i = eval_outputs(max(n))
self.assertEqual(i, True)
def test_basic_allclose(): def test_basic_allclose():
# This was raised by a user in https://github.com/Theano/Theano/issues/2975 # This was raised by a user in https://github.com/Theano/Theano/issues/2975
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论