提交 a7fb0554 authored 作者: naitonium's avatar naitonium

fix min function to work with uint* and bool

上级 b6b5f4d1
......@@ -1715,6 +1715,11 @@ def min(x, axis=None, keepdims=False):
str_x_type = str(x.dtype)
if str_x_type.startswith('float') or str_x_type in int_dtypes:
return -max(-x, axis=axis, keepdims=keepdims)
elif str_x_type in uint_dtypes:
itype = np.iinfo(x.dtype)
return itype.max - max(itype.max - x, axis=axis, keepdims=keepdims)
elif str_x_type == 'bool':
return ~max(~x, axis=axis, keepdims=keepdims)
else:
# Be careful about unsigned integers, complex
raise NotImplementedError()
......@@ -1740,6 +1745,11 @@ def argmin(x, axis=None, keepdims=False):
str_x_type = str(x.dtype)
if str_x_type.startswith('float') or str_x_type in int_dtypes:
return argmax(-x, axis=axis, keepdims=keepdims)
elif str_x_type in uint_dtypes:
itype = np.iinfo(x.dtype)
return argmax(itype.max - x, axis=axis, keepdims=keepdims)
elif str_x_type == 'bool':
return argmax(~x, axis=axis, keepdims=keepdims)
else:
# Be careful about unsigned integers, complex
raise NotImplementedError()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论