提交 f9d3e701 authored 作者: Frederic's avatar Frederic

Small test speed up

上级 3bbc2f2c
...@@ -51,16 +51,19 @@ class TestKeepDims(unittest.TestCase): ...@@ -51,16 +51,19 @@ class TestKeepDims(unittest.TestCase):
keep_param = function([x], op(x, axis=axis, keepdims=True)[0]) keep_param = function([x], op(x, axis=axis, keepdims=True)[0])
keep_synth = function([x], self.makeKeepDims_local(x, keep_synth = function([x], self.makeKeepDims_local(x,
op(x, axis=axis, keepdims=False)[0], axis)) op(x, axis=axis, keepdims=False)[0], axis))
ans1 = keep_param(a)
assert numpy.allclose(keep_param(a), keep_synth(a)) ans2 = keep_synth(a)
assert keep_param(a).shape == keep_synth(a).shape assert numpy.allclose(ans1, ans2)
assert ans1.shape == ans2.shape
keep_param = function([x], op(x, axis=axis, keepdims=True)[1]) keep_param = function([x], op(x, axis=axis, keepdims=True)[1])
keep_synth = function([x], self.makeKeepDims_local(x, keep_synth = function([x], self.makeKeepDims_local(x,
op(x, axis=axis, keepdims=False)[1], axis)) op(x, axis=axis, keepdims=False)[1], axis))
assert numpy.allclose(keep_param(a), keep_synth(a)) ans1 = keep_param(a)
assert keep_param(a).shape == keep_synth(a).shape ans2 = keep_synth(a)
assert numpy.allclose(ans1, ans2)
assert ans1.shape == ans2.shape
# the following ops can be specified with either a single axis or every # the following ops can be specified with either a single axis or every
# axis: # axis:
...@@ -72,15 +75,19 @@ class TestKeepDims(unittest.TestCase): ...@@ -72,15 +75,19 @@ class TestKeepDims(unittest.TestCase):
keep_synth = function([x], self.makeKeepDims_local(x, keep_synth = function([x], self.makeKeepDims_local(x,
op(x, axis=axis, keepdims=False), axis)) op(x, axis=axis, keepdims=False), axis))
assert numpy.allclose(keep_param(a), keep_synth(a)) ans1 = keep_param(a)
assert keep_param(a).shape == keep_synth(a).shape ans2 = keep_synth(a)
assert numpy.allclose(ans1, ans2)
assert ans1.shape == ans2.shape
keep_param = function([x], op(x, axis=None, keepdims=True)) keep_param = function([x], op(x, axis=None, keepdims=True))
keep_synth = function([x], self.makeKeepDims_local(x, keep_synth = function([x], self.makeKeepDims_local(x,
op(x, axis=None, keepdims=False), None)) op(x, axis=None, keepdims=False), None))
assert numpy.allclose(keep_param(a), keep_synth(a)) ans1 = keep_param(a)
assert keep_param(a).shape == keep_synth(a).shape ans2 = keep_synth(a)
assert numpy.allclose(ans1, ans2)
assert ans1.shape == ans2.shape
# the following ops can be specified with a freely specified axis # the following ops can be specified with a freely specified axis
# parameter # parameter
...@@ -94,12 +101,16 @@ class TestKeepDims(unittest.TestCase): ...@@ -94,12 +101,16 @@ class TestKeepDims(unittest.TestCase):
keep_synth = function([x], self.makeKeepDims_local(x, keep_synth = function([x], self.makeKeepDims_local(x,
op(x, axis=axis, keepdims=False), axis)) op(x, axis=axis, keepdims=False), axis))
assert numpy.allclose(keep_param(a), keep_synth(a)) ans1 = keep_param(a)
assert keep_param(a).shape == keep_synth(a).shape ans2 = keep_synth(a)
assert numpy.allclose(ans1, ans2)
assert ans1.shape == ans2.shape
keep_param = function([x], op(x, axis=None, keepdims=True)) keep_param = function([x], op(x, axis=None, keepdims=True))
keep_synth = function([x], self.makeKeepDims_local(x, keep_synth = function([x], self.makeKeepDims_local(x,
op(x, axis=None, keepdims=False), None)) op(x, axis=None, keepdims=False), None))
assert numpy.allclose(keep_param(a), keep_synth(a)) ans1 = keep_param(a)
assert keep_param(a).shape == keep_synth(a).shape ans2 = keep_synth(a)
assert numpy.allclose(ans1, ans2)
assert ans1.shape == ans2.shape
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论