提交 5c7a72fe authored 作者: Jeremiah Lowin's avatar Jeremiah Lowin

expand unit tests to include integer axes

上级 1d8f0780
...@@ -10,6 +10,8 @@ class TestKeepDims: ...@@ -10,6 +10,8 @@ class TestKeepDims:
if axis is None: if axis is None:
axis = numpy.arange(x.ndim) axis = numpy.arange(x.ndim)
elif isinstance(axis, int):
axis = [axis]
i = 0 i = 0
new_dims = [] new_dims = []
for j, _ in enumerate(x.shape): for j, _ in enumerate(x.shape):
...@@ -28,7 +30,7 @@ class TestKeepDims: ...@@ -28,7 +30,7 @@ class TestKeepDims:
# 'max_and_argmax' has two outputs and can be specified with either # 'max_and_argmax' has two outputs and can be specified with either
# a single or every axis: # a single or every axis:
for axis in [[0], [1], [2], None, [0, 1, 2]]: for axis in [0, 1, 2, [0], [1], [2], None, [0, 1, 2]]:
op = tensor.max_and_argmax op = tensor.max_and_argmax
keep_param = function([x], op(x, axis=axis, keepdims=True)[0]) keep_param = function([x], op(x, axis=axis, keepdims=True)[0])
...@@ -49,7 +51,7 @@ class TestKeepDims: ...@@ -49,7 +51,7 @@ class TestKeepDims:
# axis: # axis:
for op in ([tensor.argmax, tensor.argmin]): for op in ([tensor.argmax, tensor.argmin]):
for axis in [[0], [1], [2], None, [0, 1, 2]]: for axis in [0, 1, 2, [0], [1], [2], None, [0, 1, 2]]:
keep_param = function([x], op(x, axis=axis, keepdims=True)) keep_param = function([x], op(x, axis=axis, keepdims=True))
keep_synth = function([x], self.makeKeepDims_local(x, keep_synth = function([x], self.makeKeepDims_local(x,
...@@ -70,7 +72,7 @@ class TestKeepDims: ...@@ -70,7 +72,7 @@ class TestKeepDims:
for op in ([tensor.sum, tensor.prod, tensor.mean, tensor.var, for op in ([tensor.sum, tensor.prod, tensor.mean, tensor.var,
tensor.std, tensor.all, tensor.any, tensor.std, tensor.all, tensor.any,
tensor.max, tensor.min]): tensor.max, tensor.min]):
for axis in [[0], [1], [2], [0, 1], [1, 2], [0, 1, 2]]: for axis in [0, 1, 2, [0], [1], [2], [0, 1], [1, 2], [0, 1, 2]]:
keep_param = function([x], op(x, axis=axis, keepdims=True)) keep_param = function([x], op(x, axis=axis, keepdims=True))
keep_synth = function([x], self.makeKeepDims_local(x, keep_synth = function([x], self.makeKeepDims_local(x,
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论