提交 49c0ca0d authored 作者: Frederic's avatar Frederic

Make min(tuple(x, y)) and argmin(...) work as max and argmax.

上级 46b19c24
...@@ -1650,7 +1650,7 @@ def min(x, axis=None, keepdims=False): ...@@ -1650,7 +1650,7 @@ def min(x, axis=None, keepdims=False):
the result as dimensions with size one. With this option, the result the result as dimensions with size one. With this option, the result
will broadcast correctly against the original tensor. will broadcast correctly against the original tensor.
""" """
x = as_tensor_variable(x)
str_x_type = str(x.dtype) str_x_type = str(x.dtype)
if str_x_type.startswith('float') or str_x_type in int_dtypes: if str_x_type.startswith('float') or str_x_type in int_dtypes:
return -max(-x, axis=axis, keepdims=keepdims) return -max(-x, axis=axis, keepdims=keepdims)
...@@ -1671,7 +1671,7 @@ def argmin(x, axis=None, keepdims=False): ...@@ -1671,7 +1671,7 @@ def argmin(x, axis=None, keepdims=False):
the result as dimensions with size one. With this option, the result the result as dimensions with size one. With this option, the result
will broadcast correctly against the original tensor. will broadcast correctly against the original tensor.
""" """
x = as_tensor_variable(x)
str_x_type = str(x.dtype) str_x_type = str(x.dtype)
if str_x_type.startswith('float') or str_x_type in int_dtypes: if str_x_type.startswith('float') or str_x_type in int_dtypes:
return argmax(-x, axis=axis, keepdims=keepdims) return argmax(-x, axis=axis, keepdims=keepdims)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论