提交 610a449d authored 作者: Frederic's avatar Frederic

Make test faster by compiling less theano fct.

上级 8af088f7
......@@ -48,20 +48,19 @@ class TestKeepDims(unittest.TestCase):
[-2, -3, 2]]:
op = tensor.max_and_argmax
keep_param = function([x], op(x, axis=axis, keepdims=True)[0])
keep_synth = function([x], self.makeKeepDims_local(x,
op(x, axis=axis, keepdims=False)[0], axis))
ans1 = keep_param(a)
ans2 = keep_synth(a)
f = function([x], [op(x, axis=axis, keepdims=True)[0],
self.makeKeepDims_local(
x, op(x, axis=axis, keepdims=False)[0],
axis)])
ans1, ans2 = f(a)
assert numpy.allclose(ans1, ans2)
assert ans1.shape == ans2.shape
keep_param = function([x], op(x, axis=axis, keepdims=True)[1])
keep_synth = function([x], self.makeKeepDims_local(x,
op(x, axis=axis, keepdims=False)[1], axis))
ans1 = keep_param(a)
ans2 = keep_synth(a)
f = function([x], [op(x, axis=axis, keepdims=True)[1],
self.makeKeepDims_local(
x, op(x, axis=axis, keepdims=False)[1],
axis)])
ans1, ans2 = f(a)
assert numpy.allclose(ans1, ans2)
assert ans1.shape == ans2.shape
......@@ -71,12 +70,11 @@ class TestKeepDims(unittest.TestCase):
for axis in [0, 1, 2, [0], [1], [2], None, [0, 1, 2],
[-1], [-2], [-3], [-1, -2, -3], [0, -2, 2]]:
keep_param = function([x], op(x, axis=axis, keepdims=True))
keep_synth = function([x], self.makeKeepDims_local(x,
op(x, axis=axis, keepdims=False), axis))
ans1 = keep_param(a)
ans2 = keep_synth(a)
f = function([x], [op(x, axis=axis, keepdims=True),
self.makeKeepDims_local(
x, op(x, axis=axis, keepdims=False),
axis)])
ans1, ans2 = f(a)
assert numpy.allclose(ans1, ans2)
assert ans1.shape == ans2.shape
......@@ -89,11 +87,11 @@ class TestKeepDims(unittest.TestCase):
[0, 1], [1, 2], [0, 1, 2],
[-1], [-2], [-3], [-1, -2], [-1, -2, -3], [0, -2, 2]]:
keep_param = function([x], op(x, axis=axis, keepdims=True))
keep_synth = function([x], self.makeKeepDims_local(x,
op(x, axis=axis, keepdims=False), axis))
f = function([x], [op(x, axis=axis, keepdims=True),
self.makeKeepDims_local(
x, op(x, axis=axis, keepdims=False),
axis)])
ans1 = keep_param(a)
ans2 = keep_synth(a)
ans1, ans2 = f(a)
assert numpy.allclose(ans1, ans2)
assert ans1.shape == ans2.shape
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论