提交 07e7d732 authored 作者: Frederic's avatar Frederic

repair just introduced tensor.max and tensor.min to make them work with more then 1 axis.

上级 b5a05a87
...@@ -2379,11 +2379,12 @@ def max(x, axis=None, keepdims=False): ...@@ -2379,11 +2379,12 @@ def max(x, axis=None, keepdims=False):
if isinstance(axis, (list, tuple)) and len(axis) > 1: if isinstance(axis, (list, tuple)) and len(axis) > 1:
out = CAReduce(scal.maximum, axis)(x) out = CAReduce(scal.maximum, axis)(x)
try: else:
const = get_constant_value(axis) try:
out = CAReduce(scal.maximum, list(const))(x) const = get_constant_value(axis)
except Exception: out = CAReduce(scal.maximum, list(const))(x)
out = max_and_argmax(x, axis)[0] except Exception:
out = max_and_argmax(x, axis)[0]
if keepdims: if keepdims:
out = makeKeepDims(x, out, axis) out = makeKeepDims(x, out, axis)
......
...@@ -28,7 +28,7 @@ class TestKeepDims: ...@@ -28,7 +28,7 @@ class TestKeepDims:
# 'max_and_argmax' has two outputs and can be specified with either # 'max_and_argmax' has two outputs and can be specified with either
# a single or every axis: # a single or every axis:
for axis in [[0], [1], [2]]: for axis in [[0], [1], [2], None, [0, 1, 2]]:
op = tensor.max_and_argmax op = tensor.max_and_argmax
keep_param = function([x], op(x, axis=axis, keepdims=True)[0]) keep_param = function([x], op(x, axis=axis, keepdims=True)[0])
...@@ -47,9 +47,9 @@ class TestKeepDims: ...@@ -47,9 +47,9 @@ class TestKeepDims:
# the following ops can be specified with either a single axis or every # the following ops can be specified with either a single axis or every
# axis: # axis:
for op in ([tensor.argmax, tensor.max, tensor.argmin, tensor.min]): for op in ([tensor.argmax, tensor.argmin]):
for axis in [[0], [1], [2]]: for axis in [[0], [1], [2], None, [0, 1, 2]]:
keep_param = function([x], op(x, axis=axis, keepdims=True)) keep_param = function([x], op(x, axis=axis, keepdims=True))
keep_synth = function([x], self.makeKeepDims_local(x, keep_synth = function([x], self.makeKeepDims_local(x,
...@@ -68,7 +68,8 @@ class TestKeepDims: ...@@ -68,7 +68,8 @@ class TestKeepDims:
# the following ops can be specified with a freely specified axis # the following ops can be specified with a freely specified axis
# parameter # parameter
for op in ([tensor.sum, tensor.prod, tensor.mean, tensor.var, for op in ([tensor.sum, tensor.prod, tensor.mean, tensor.var,
tensor.std, tensor.all, tensor.any]): tensor.std, tensor.all, tensor.any,
tensor.max, tensor.min]):
for axis in [[0], [1], [2], [0, 1], [1, 2], [0, 1, 2]]: for axis in [[0], [1], [2], [0, 1], [1, 2], [0, 1, 2]]:
keep_param = function([x], op(x, axis=axis, keepdims=True)) keep_param = function([x], op(x, axis=axis, keepdims=True))
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论