提交 006eaf78 authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Enables specifying batch size

上级 e88dfffe
...@@ -29,11 +29,29 @@ import sys ...@@ -29,11 +29,29 @@ import sys
def main(): def main():
if '--batch' in sys.argv: # Handle --batch[=n] arguments
batch_args = [arg for arg in sys.argv if arg.startswith('--batch')]
for arg in batch_args:
sys.argv.remove(arg)
if len(batch_args):
if len(batch_args) > 1:
_logger.warn(
'Multiple --batch arguments detected, using the last one '
'and ignoring the first ones.')
batch_arg = batch_args[-1]
elems = batch_arg.split('=', 1)
if len(elems) == 2:
batch_size = int(elems[1])
else:
# Use run_tests_in_batch's default
batch_size = None
from theano.tests import run_tests_in_batch from theano.tests import run_tests_in_batch
sys.argv.remove('--batch') return run_tests_in_batch.main(batch_size=batch_size)
return run_tests_in_batch.main()
# Non-batch mode.
addplugins = [] addplugins = []
# We include KnownFailure plugin by default, unless # We include KnownFailure plugin by default, unless
# it is disabled by the "--without-knownfailure" arg. # it is disabled by the "--without-knownfailure" arg.
......
...@@ -36,7 +36,7 @@ import cPickle, os, subprocess, sys ...@@ -36,7 +36,7 @@ import cPickle, os, subprocess, sys
import theano import theano
def main(stdout=None, stderr=None, argv=None, theano_nose=None): def main(stdout=None, stderr=None, argv=None, theano_nose=None, batch_size=None):
""" """
Run tests with optional output redirection. Run tests with optional output redirection.
...@@ -48,6 +48,8 @@ def main(stdout=None, stderr=None, argv=None, theano_nose=None): ...@@ -48,6 +48,8 @@ def main(stdout=None, stderr=None, argv=None, theano_nose=None):
If theano_nose is None, then we use the theano-nose script found in If theano_nose is None, then we use the theano-nose script found in
Theano/bin to call nosetests. Otherwise we call the provided script. Theano/bin to call nosetests. Otherwise we call the provided script.
If batch_size is None, we use a default value of 100.
""" """
if stdout is None: if stdout is None:
stdout = sys.stdout stdout = sys.stdout
...@@ -57,17 +59,19 @@ def main(stdout=None, stderr=None, argv=None, theano_nose=None): ...@@ -57,17 +59,19 @@ def main(stdout=None, stderr=None, argv=None, theano_nose=None):
argv = sys.argv argv = sys.argv
if theano_nose is None: if theano_nose is None:
theano_nose = os.path.join(theano.__path__[0], '..', 'bin', 'theano-nose') theano_nose = os.path.join(theano.__path__[0], '..', 'bin', 'theano-nose')
if batch_size is None:
batch_size = 100
stdout_backup = sys.stdout stdout_backup = sys.stdout
stderr_backup = sys.stderr stderr_backup = sys.stderr
try: try:
sys.stdout = stdout sys.stdout = stdout
sys.stderr = stderr sys.stderr = stderr
run(stdout, stderr, argv, theano_nose) run(stdout, stderr, argv, theano_nose, batch_size)
finally: finally:
sys.stdout = stdout_backup sys.stdout = stdout_backup
sys.stderr = stderr_backup sys.stderr = stderr_backup
def run(stdout, stderr, argv, theano_nose): def run(stdout, stderr, argv, theano_nose, batch_size):
if len(argv) == 1: if len(argv) == 1:
tests_dir = theano.__path__[0] tests_dir = theano.__path__[0]
else: else:
...@@ -101,16 +105,15 @@ def run(stdout, stderr, argv, theano_nose): ...@@ -101,16 +105,15 @@ def run(stdout, stderr, argv, theano_nose):
n_tests = len(ids) n_tests = len(ids)
assert n_tests == max(ids) assert n_tests == max(ids)
# Run tests. # Run tests.
n_batch = 100
failed = set() failed = set()
print """\ print """\
################################### ###################################
# RUNNING TESTS IN BATCHES OF %s # # RUNNING TESTS IN BATCHES OF %s #
###################################""" % n_batch ###################################""" % batch_size
for test_id in xrange(1, n_tests + 1, n_batch): for test_id in xrange(1, n_tests + 1, batch_size):
stdout.flush() stdout.flush()
stderr.flush() stderr.flush()
test_range = range(test_id, min(test_id + n_batch, n_tests + 1)) test_range = range(test_id, min(test_id + batch_size, n_tests + 1))
# We suppress all output because we want the user to focus only on the # We suppress all output because we want the user to focus only on the
# failed tests, which are re-run (with output) below. # failed tests, which are re-run (with output) below.
dummy_out = open(os.devnull, 'w') dummy_out = open(os.devnull, 'w')
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论