提交 610a449d authored 作者: Frederic's avatar Frederic

Make test faster by compiling less theano fct.

上级 8af088f7
...@@ -48,20 +48,19 @@ class TestKeepDims(unittest.TestCase): ...@@ -48,20 +48,19 @@ class TestKeepDims(unittest.TestCase):
[-2, -3, 2]]: [-2, -3, 2]]:
op = tensor.max_and_argmax op = tensor.max_and_argmax
keep_param = function([x], op(x, axis=axis, keepdims=True)[0]) f = function([x], [op(x, axis=axis, keepdims=True)[0],
keep_synth = function([x], self.makeKeepDims_local(x, self.makeKeepDims_local(
op(x, axis=axis, keepdims=False)[0], axis)) x, op(x, axis=axis, keepdims=False)[0],
ans1 = keep_param(a) axis)])
ans2 = keep_synth(a) ans1, ans2 = f(a)
assert numpy.allclose(ans1, ans2) assert numpy.allclose(ans1, ans2)
assert ans1.shape == ans2.shape assert ans1.shape == ans2.shape
keep_param = function([x], op(x, axis=axis, keepdims=True)[1]) f = function([x], [op(x, axis=axis, keepdims=True)[1],
keep_synth = function([x], self.makeKeepDims_local(x, self.makeKeepDims_local(
op(x, axis=axis, keepdims=False)[1], axis)) x, op(x, axis=axis, keepdims=False)[1],
axis)])
ans1 = keep_param(a) ans1, ans2 = f(a)
ans2 = keep_synth(a)
assert numpy.allclose(ans1, ans2) assert numpy.allclose(ans1, ans2)
assert ans1.shape == ans2.shape assert ans1.shape == ans2.shape
...@@ -71,12 +70,11 @@ class TestKeepDims(unittest.TestCase): ...@@ -71,12 +70,11 @@ class TestKeepDims(unittest.TestCase):
for axis in [0, 1, 2, [0], [1], [2], None, [0, 1, 2], for axis in [0, 1, 2, [0], [1], [2], None, [0, 1, 2],
[-1], [-2], [-3], [-1, -2, -3], [0, -2, 2]]: [-1], [-2], [-3], [-1, -2, -3], [0, -2, 2]]:
keep_param = function([x], op(x, axis=axis, keepdims=True)) f = function([x], [op(x, axis=axis, keepdims=True),
keep_synth = function([x], self.makeKeepDims_local(x, self.makeKeepDims_local(
op(x, axis=axis, keepdims=False), axis)) x, op(x, axis=axis, keepdims=False),
axis)])
ans1 = keep_param(a) ans1, ans2 = f(a)
ans2 = keep_synth(a)
assert numpy.allclose(ans1, ans2) assert numpy.allclose(ans1, ans2)
assert ans1.shape == ans2.shape assert ans1.shape == ans2.shape
...@@ -89,11 +87,11 @@ class TestKeepDims(unittest.TestCase): ...@@ -89,11 +87,11 @@ class TestKeepDims(unittest.TestCase):
[0, 1], [1, 2], [0, 1, 2], [0, 1], [1, 2], [0, 1, 2],
[-1], [-2], [-3], [-1, -2], [-1, -2, -3], [0, -2, 2]]: [-1], [-2], [-3], [-1, -2], [-1, -2, -3], [0, -2, 2]]:
keep_param = function([x], op(x, axis=axis, keepdims=True)) f = function([x], [op(x, axis=axis, keepdims=True),
keep_synth = function([x], self.makeKeepDims_local(x, self.makeKeepDims_local(
op(x, axis=axis, keepdims=False), axis)) x, op(x, axis=axis, keepdims=False),
axis)])
ans1 = keep_param(a) ans1, ans2 = f(a)
ans2 = keep_synth(a)
assert numpy.allclose(ans1, ans2) assert numpy.allclose(ans1, ans2)
assert ans1.shape == ans2.shape assert ans1.shape == ans2.shape
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论