提交 006eaf78 authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Enables specifying batch size

上级 e88dfffe
......@@ -29,11 +29,29 @@ import sys
def main():
if '--batch' in sys.argv:
# Handle --batch[=n] arguments
batch_args = [arg for arg in sys.argv if arg.startswith('--batch')]
for arg in batch_args:
sys.argv.remove(arg)
if len(batch_args):
if len(batch_args) > 1:
_logger.warn(
'Multiple --batch arguments detected, using the last one '
'and ignoring the first ones.')
batch_arg = batch_args[-1]
elems = batch_arg.split('=', 1)
if len(elems) == 2:
batch_size = int(elems[1])
else:
# Use run_tests_in_batch's default
batch_size = None
from theano.tests import run_tests_in_batch
sys.argv.remove('--batch')
return run_tests_in_batch.main()
return run_tests_in_batch.main(batch_size=batch_size)
# Non-batch mode.
addplugins = []
# We include KnownFailure plugin by default, unless
# it is disabled by the "--without-knownfailure" arg.
......
......@@ -36,7 +36,7 @@ import cPickle, os, subprocess, sys
import theano
def main(stdout=None, stderr=None, argv=None, theano_nose=None):
def main(stdout=None, stderr=None, argv=None, theano_nose=None, batch_size=None):
"""
Run tests with optional output redirection.
......@@ -48,6 +48,8 @@ def main(stdout=None, stderr=None, argv=None, theano_nose=None):
If theano_nose is None, then we use the theano-nose script found in
Theano/bin to call nosetests. Otherwise we call the provided script.
If batch_size is None, we use a default value of 100.
"""
if stdout is None:
stdout = sys.stdout
......@@ -57,17 +59,19 @@ def main(stdout=None, stderr=None, argv=None, theano_nose=None):
argv = sys.argv
if theano_nose is None:
theano_nose = os.path.join(theano.__path__[0], '..', 'bin', 'theano-nose')
if batch_size is None:
batch_size = 100
stdout_backup = sys.stdout
stderr_backup = sys.stderr
try:
sys.stdout = stdout
sys.stderr = stderr
run(stdout, stderr, argv, theano_nose)
run(stdout, stderr, argv, theano_nose, batch_size)
finally:
sys.stdout = stdout_backup
sys.stderr = stderr_backup
def run(stdout, stderr, argv, theano_nose):
def run(stdout, stderr, argv, theano_nose, batch_size):
if len(argv) == 1:
tests_dir = theano.__path__[0]
else:
......@@ -101,16 +105,15 @@ def run(stdout, stderr, argv, theano_nose):
n_tests = len(ids)
assert n_tests == max(ids)
# Run tests.
n_batch = 100
failed = set()
print """\
###################################
# RUNNING TESTS IN BATCHES OF %s #
###################################""" % n_batch
for test_id in xrange(1, n_tests + 1, n_batch):
###################################""" % batch_size
for test_id in xrange(1, n_tests + 1, batch_size):
stdout.flush()
stderr.flush()
test_range = range(test_id, min(test_id + n_batch, n_tests + 1))
test_range = range(test_id, min(test_id + batch_size, n_tests + 1))
# We suppress all output because we want the user to focus only on the
# failed tests, which are re-run (with output) below.
dummy_out = open(os.devnull, 'w')
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论