提交 08d45e3f authored 作者: Ian Goodfellow's avatar Ian Goodfellow

fixed tensor.basic.max and documented its reasoning

上级 18150a14
......@@ -2674,16 +2674,22 @@ def max(x, axis=None, keepdims=False):
:note: we return an error as numpy when we reduce a dim with a shape of 0
"""
if isinstance(axis, (list, tuple)) and len(axis) > 1:
# We have a choice of implementing this call with the
# CAReduce op or the MaxAndArgmax op.
# MaxAndArgmax supports grad and Rop, so we prefer to use that.
# CAReduce is faster, but optimizations will replace MaxAndArgmax[0]
# with CAReduce at compile time, so at this stage the important
# thing is supporting all user interface features, not speed.
# Some cases can be implemented only with CAReduce.
# We thus prefer to use MaxAndArgmax, if possible. It does not
# support all axis arguments, so we may need to fall back to CAReduce.
try:
out = max_and_argmax(x, axis)[0]
except:
out = CAReduce(scal.maximum, axis)(x)
else:
try:
const = get_scalar_constant_value(axis)
assert const.dtype in discrete_dtypes
const = int(const)
out = CAReduce(scal.maximum, [const])(x)
except NotScalarConstantError:
out = max_and_argmax(x, axis)[0]
if keepdims:
out = makeKeepDims(x, out, axis)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论