提交 847e8ceb authored 作者: Frédéric Bastien's avatar Frédéric Bastien

Merge pull request #2635 from carriepl/fix_stochastic_tests

Apply decorator for stochastic tests on test_run_nnet
...@@ -162,7 +162,7 @@ class TestGpuCumsum(theano.tensor.tests.test_extra_ops.TestCumsumOp): ...@@ -162,7 +162,7 @@ class TestGpuCumsum(theano.tensor.tests.test_extra_ops.TestCumsumOp):
utt.assert_allclose(np.cumsum(a, axis=axis), f(a)) utt.assert_allclose(np.cumsum(a, axis=axis), f(a))
# Use multiple GPU gridblocks # Use multiple GPU gridblocks
a_shape = [5, 5] a_shape = [4, 4]
a_shape[1-shape_axis] = self.max_grid_size1+1 a_shape[1-shape_axis] = self.max_grid_size1+1
a = np.random.random(a_shape).astype("float32") a = np.random.random(a_shape).astype("float32")
utt.assert_allclose(np.cumsum(a, axis=axis), f(a), rtol=5e-5) utt.assert_allclose(np.cumsum(a, axis=axis), f(a), rtol=5e-5)
......
...@@ -125,6 +125,7 @@ def run_nnet(use_gpu, n_batch=60, n_in=1024, n_hid=2048, n_out=10, ...@@ -125,6 +125,7 @@ def run_nnet(use_gpu, n_batch=60, n_in=1024, n_hid=2048, n_out=10,
return numpy.asarray(rval), dt return numpy.asarray(rval), dt
@utt.AttemptManyTimes(n_attempts=3, n_req_successes=1)
def test_run_nnet(): def test_run_nnet():
for n_in in 1024, 2048, 4096: for n_in in 1024, 2048, 4096:
for n_hid in 1024, 2048, 4096: for n_hid in 1024, 2048, 4096:
......
from copy import copy, deepcopy from copy import copy, deepcopy
from functools import wraps
import logging import logging
from StringIO import StringIO from StringIO import StringIO
import sys import sys
...@@ -370,6 +371,7 @@ class AttemptManyTimes: ...@@ -370,6 +371,7 @@ class AttemptManyTimes:
# Wrap fct in a function that will attempt to run it multiple # Wrap fct in a function that will attempt to run it multiple
# times and return the result if the test passes enough times # times and return the result if the test passes enough times
# of propagate the raised exception if it doesn't. # of propagate the raised exception if it doesn't.
@wraps(fct)
def attempt_multiple_times(*args, **kwargs): def attempt_multiple_times(*args, **kwargs):
# Keep a copy of the current seed for unittests so that we can use # Keep a copy of the current seed for unittests so that we can use
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论