提交 93159e52 authored 作者: abergeron's avatar abergeron

Merge pull request #4144 from ChihebTrabelsi/ccw

Test files have been modified in order to respect the flake8 style.
...@@ -47,18 +47,18 @@ class T_scipy(unittest.TestCase): ...@@ -47,18 +47,18 @@ class T_scipy(unittest.TestCase):
b = shared(numpy.zeros(())) b = shared(numpy.zeros(()))
# Construct Theano expression graph # Construct Theano expression graph
p_1 = 1 / (1 + T.exp(-T.dot(x, w)-b)) p_1 = 1 / (1 + T.exp(-T.dot(x, w) - b))
xent = -y*T.log(p_1) - (1-y)*T.log(1-p_1) xent = -y * T.log(p_1) - (1 - y) * T.log(1 - p_1)
prediction = p_1 > 0.5 prediction = p_1 > 0.5
cost = xent.mean() + 0.01*(w**2).sum() cost = xent.mean() + 0.01 * (w ** 2).sum()
gw, gb = T.grad(cost, [w, b]) gw, gb = T.grad(cost, [w, b])
# Compile expressions to functions # Compile expressions to functions
train = function( train = function(
inputs=[x, y], inputs=[x, y],
outputs=[prediction, xent], outputs=[prediction, xent],
updates=[(w, w-0.1*gw), (b, b-0.1*gb)]) updates=[(w, w - 0.1 * gw), (b, b - 0.1 * gb)])
predict = function(inputs=[x], outputs=prediction) function(inputs=[x], outputs=prediction)
N = 4 N = 4
feats = 100 feats = 100
......
from __future__ import print_function from __future__ import print_function
from theano.compile import Mode
import theano
from theano.printing import hex_digest
__authors__ = "Ian Goodfellow" __authors__ = "Ian Goodfellow"
__credits__ = ["Ian Goodfellow"] __credits__ = ["Ian Goodfellow"]
__license__ = "3-clause BSD" __license__ = "3-clause BSD"
__maintainer__ = "Ian Goodfellow" __maintainer__ = "Ian Goodfellow"
__email__ = "goodfeli@iro" __email__ = "goodfeli@iro"
from theano.compile import Mode
import theano
from theano.printing import hex_digest
class MismatchError(Exception): class MismatchError(Exception):
""" """
...@@ -96,12 +96,12 @@ class Record(object): ...@@ -96,12 +96,12 @@ class Record(object):
msg = 'Replay detected mismatch.\n' msg = 'Replay detected mismatch.\n'
msg += ' I wanted to write:\n' msg += ' I wanted to write:\n'
if len(line) > 100: if len(line) > 100:
msg += line[0:100]+'...' msg += line[0:100] + '...'
else: else:
msg += line msg += line
msg += '\nwhen previous job wrote:\n' msg += '\nwhen previous job wrote:\n'
if len(old_line) > 100: if len(old_line) > 100:
msg += old_line[0:100]+'...' msg += old_line[0:100] + '...'
else: else:
msg += old_line msg += old_line
raise MismatchError(msg) raise MismatchError(msg)
...@@ -198,7 +198,7 @@ class RecordMode(Mode): ...@@ -198,7 +198,7 @@ class RecordMode(Mode):
except MismatchError as e: except MismatchError as e:
print('Got this MismatchError:') print('Got this MismatchError:')
print(e) print(e)
print('while processing node i='+str(i)+':') print('while processing node i=' + str(i) + ':')
print('str(node):', str(node)) print('str(node):', str(node))
print('Symbolic inputs: ') print('Symbolic inputs: ')
for elem in node.inputs: for elem in node.inputs:
...@@ -208,7 +208,7 @@ class RecordMode(Mode): ...@@ -208,7 +208,7 @@ class RecordMode(Mode):
assert isinstance(elem, list) assert isinstance(elem, list)
elem, = elem elem, = elem
print(str(elem)) print(str(elem))
print('function name: '+node.fgraph.name) print('function name: ' + node.fgraph.name)
raise MismatchError("Non-determinism detected by WrapLinker") raise MismatchError("Non-determinism detected by WrapLinker")
def callback(i, node, fn): def callback(i, node, fn):
...@@ -228,16 +228,15 @@ class RecordMode(Mode): ...@@ -228,16 +228,15 @@ class RecordMode(Mode):
for elem in self.known_fgraphs]) for elem in self.known_fgraphs])
self.known_fgraphs.add(fgraph) self.known_fgraphs.add(fgraph)
num_app = len(fgraph.apply_nodes) num_app = len(fgraph.apply_nodes)
line = 'Function '+fgraph.name+' has ' + str(num_app) \ line = 'Function ' + fgraph.name + ' has ' + str(num_app) \
+ ' apply nodes.\n' + ' apply nodes.\n'
handle_line(line, i, node, fn) handle_line(line, i, node, fn)
line = 'Function name: '+fgraph.name + '\n' line = 'Function name: ' + fgraph.name + '\n'
handle_line(line, i, node, fn) handle_line(line, i, node, fn)
line = 'Node '+str(i)+':'+str(node)+'\n' line = 'Node ' + str(i) + ':' + str(node) + '\n'
handle_line(line, i, node, fn) handle_line(line, i, node, fn)
assert all([isinstance(x, list) assert all([isinstance(x, list) and len(x) == 1 for x in fn.inputs])
and len(x) == 1 for x in fn.inputs])
def digest(x): def digest(x):
x = x[0] x = x[0]
......
#!/usr/bin/env python #!/usr/bin/env python
from __future__ import print_function from __future__ import print_function
import datetime
import os
import subprocess
import sys
import time
from six.moves import xrange
import six.moves.cPickle as pickle
import theano
from theano.misc.windows import output_subprocess_Popen
__authors__ = "Olivier Delalleau, Eric Larsen" __authors__ = "Olivier Delalleau, Eric Larsen"
__contact__ = "delallea@iro" __contact__ = "delallea@iro"
...@@ -55,18 +66,6 @@ nosetests. ...@@ -55,18 +66,6 @@ nosetests.
""" """
import datetime
import os
import subprocess
import sys
import time
from six.moves import xrange
import six.moves.cPickle as pickle
import theano
from theano.misc.windows import output_subprocess_Popen
def main(stdout=None, stderr=None, argv=None, theano_nose=None, def main(stdout=None, stderr=None, argv=None, theano_nose=None,
batch_size=None, time_profile=False, display_batch_output=False): batch_size=None, time_profile=False, display_batch_output=False):
""" """
...@@ -94,8 +93,8 @@ def main(stdout=None, stderr=None, argv=None, theano_nose=None, ...@@ -94,8 +93,8 @@ def main(stdout=None, stderr=None, argv=None, theano_nose=None,
if argv is None: if argv is None:
argv = sys.argv argv = sys.argv
if theano_nose is None: if theano_nose is None:
# If Theano is installed with pip/easy_install, it can be in the # If Theano is installed with pip/easy_install, it can be in the
#*/lib/python2.7/site-packages/theano, but theano-nose in */bin # */lib/python2.7/site-packages/theano, but theano-nose in */bin
for i in range(1, 5): for i in range(1, 5):
path = theano.__path__[0] path = theano.__path__[0]
for _ in range(i): for _ in range(i):
...@@ -145,8 +144,7 @@ def run(stdout, stderr, argv, theano_nose, batch_size, time_profile, ...@@ -145,8 +144,7 @@ def run(stdout, stderr, argv, theano_nose, batch_size, time_profile,
# Using sys.executable, so that the same Python version is used. # Using sys.executable, so that the same Python version is used.
python = sys.executable python = sys.executable
rval = subprocess.call( rval = subprocess.call(
([python, theano_nose, '--collect-only', '--with-id'] ([python, theano_nose, '--collect-only', '--with-id'] + argv),
+ argv),
stdin=dummy_in.fileno(), stdin=dummy_in.fileno(),
stdout=stdout.fileno(), stdout=stdout.fileno(),
stderr=stderr.fileno()) stderr=stderr.fileno())
...@@ -215,9 +213,7 @@ def run(stdout, stderr, argv, theano_nose, batch_size, time_profile, ...@@ -215,9 +213,7 @@ def run(stdout, stderr, argv, theano_nose, batch_size, time_profile,
stdout.flush() stdout.flush()
stderr.flush() stderr.flush()
subprocess.call( subprocess.call(
([python, theano_nose, '-v', '--with-id'] ([python, theano_nose, '-v', '--with-id'] + failed + argv),
+ failed
+ argv),
stdin=dummy_in.fileno(), stdin=dummy_in.fileno(),
stdout=stdout.fileno(), stdout=stdout.fileno(),
stderr=stderr.fileno()) stderr=stderr.fileno())
...@@ -252,7 +248,6 @@ def run(stdout, stderr, argv, theano_nose, batch_size, time_profile, ...@@ -252,7 +248,6 @@ def run(stdout, stderr, argv, theano_nose, batch_size, time_profile,
# iterating through tests # iterating through tests
# initializing master profiling list and raw log # initializing master profiling list and raw log
prof_master_nosort = [] prof_master_nosort = []
prof_rawlog = []
dummy_out = open(os.devnull, 'w') dummy_out = open(os.devnull, 'w')
path_rawlog = os.path.join(sav_dir, 'timeprof_rawlog') path_rawlog = os.path.join(sav_dir, 'timeprof_rawlog')
stamp = str(datetime.datetime.now()) + '\n\n' stamp = str(datetime.datetime.now()) + '\n\n'
...@@ -273,21 +268,21 @@ def run(stdout, stderr, argv, theano_nose, batch_size, time_profile, ...@@ -273,21 +268,21 @@ def run(stdout, stderr, argv, theano_nose, batch_size, time_profile,
f_nosort.flush() f_nosort.flush()
for test_floor in xrange(1, n_tests + 1, batch_size): for test_floor in xrange(1, n_tests + 1, batch_size):
for test_id in xrange(test_floor, min(test_floor + batch_size, for test_id in xrange(test_floor, min(test_floor + batch_size,
n_tests + 1)): n_tests + 1)):
# Print the test we will start in the raw log to help # Print the test we will start in the raw log to help
# debug tests that are too long. # debug tests that are too long.
f_rawlog.write("\n%s Will run test #%d %s\n" % ( f_rawlog.write("\n%s Will run test #%d %s\n" % (
time.ctime(), test_id, data["ids"][test_id])) time.ctime(), test_id, data["ids"][test_id]))
f_rawlog.flush() f_rawlog.flush()
p_out = output_subprocess_Popen( p_out = output_subprocess_Popen(([python, theano_nose, '-v', '--with-id'] +
([python, theano_nose, '-v', '--with-id'] [str(test_id)] +
+ [str(test_id)] + argv + argv +
['--disabdocstring'])) ['--disabdocstring']))
# the previous option calls a custom Nosetests plugin # the previous option calls a custom Nosetests plugin
# precluding automatic sustitution of doc. string for # precluding automatic sustitution of doc. string for
# test name in display # test name in display
# (see class 'DisabDocString' in file theano-nose) # (see class 'DisabDocString' in file theano-nose)
# recovering and processing data from pipe # recovering and processing data from pipe
err = p_out[1] err = p_out[1]
...@@ -334,9 +329,9 @@ def run(stdout, stderr, argv, theano_nose, batch_size, time_profile, ...@@ -334,9 +329,9 @@ def run(stdout, stderr, argv, theano_nose, batch_size, time_profile,
# write the no sort file # write the no sort file
s_nosort = ((str(prof_tuple[0]) + 's').ljust(10) + s_nosort = ((str(prof_tuple[0]) + 's').ljust(10) +
" " + prof_tuple[1].ljust(7) + " " + " " + prof_tuple[1].ljust(7) + " " +
prof_tuple[2] + prof_tuple[3] + prof_tuple[2] + prof_tuple[3] +
"\n") "\n")
f_nosort.write(s_nosort) f_nosort.write(s_nosort)
f_nosort.flush() f_nosort.flush()
...@@ -354,12 +349,13 @@ def run(stdout, stderr, argv, theano_nose, batch_size, time_profile, ...@@ -354,12 +349,13 @@ def run(stdout, stderr, argv, theano_nose, batch_size, time_profile,
' (sorted by computation time)\n\n' + stamp + fields) ' (sorted by computation time)\n\n' + stamp + fields)
for i in xrange(len(prof_master_nosort)): for i in xrange(len(prof_master_nosort)):
s_sort = ((str(prof_master_sort[i][0]) + 's').ljust(10) + s_sort = ((str(prof_master_sort[i][0]) + 's').ljust(10) +
" " + prof_master_sort[i][1].ljust(7) + " " + " " + prof_master_sort[i][1].ljust(7) + " " +
prof_master_sort[i][2] + prof_master_sort[i][3] + prof_master_sort[i][2] + prof_master_sort[i][3] +
"\n") "\n")
f_sort.write(s_sort) f_sort.write(s_sort)
# end of saving nosort # end of saving nosort
if __name__ == '__main__': if __name__ == '__main__':
sys.exit(main()) sys.exit(main())
""" """
Test for jacobian/hessian functions in Theano Test for jacobian/hessian functions in Theano
""" """
import unittest
from six.moves import xrange from six.moves import xrange
from theano.tests import unittest_tools as utt from theano.tests import unittest_tools as utt
from theano import function
import theano import theano
from theano import tensor from theano import tensor
import numpy import numpy
...@@ -151,17 +149,14 @@ def test_jacobian_disconnected_inputs(): ...@@ -151,17 +149,14 @@ def test_jacobian_disconnected_inputs():
""" """
v1 = tensor.vector() v1 = tensor.vector()
v2 = tensor.vector() v2 = tensor.vector()
jacobian_v = theano.gradient.jacobian(1 + v1, v2, jacobian_v = theano.gradient.jacobian(1 + v1, v2, disconnected_inputs='ignore')
disconnected_inputs='ignore')
func_v = theano.function([v1, v2], jacobian_v) func_v = theano.function([v1, v2], jacobian_v)
val = numpy.arange(4.0).astype(theano.config.floatX) val = numpy.arange(4.0).astype(theano.config.floatX)
assert numpy.allclose(func_v(val, val), numpy.zeros((4, 4))) assert numpy.allclose(func_v(val, val), numpy.zeros((4, 4)))
s1 = tensor.scalar() s1 = tensor.scalar()
s2 = tensor.scalar() s2 = tensor.scalar()
jacobian_s = theano.gradient.jacobian(1 + s1, s2, jacobian_s = theano.gradient.jacobian(1 + s1, s2, disconnected_inputs='ignore')
disconnected_inputs='ignore')
func_s = theano.function([s2], jacobian_s) func_s = theano.function([s2], jacobian_s)
val = numpy.array(1.0).astype(theano.config.floatX) val = numpy.array(1.0).astype(theano.config.floatX)
assert numpy.allclose(func_s(val), numpy.zeros(1)) assert numpy.allclose(func_s(val), numpy.zeros(1))
...@@ -33,11 +33,10 @@ class T_config(unittest.TestCase): ...@@ -33,11 +33,10 @@ class T_config(unittest.TestCase):
THEANO_FLAGS_DICT['T_config.test_invalid_default_b'] = 'ok' THEANO_FLAGS_DICT['T_config.test_invalid_default_b'] = 'ok'
# This should succeed since we defined a proper value, even # This should succeed since we defined a proper value, even
# though the default was invalid. # though the default was invalid.
AddConfigVar( AddConfigVar('T_config.test_invalid_default_b',
'T_config.test_invalid_default_b', doc='unittest',
doc='unittest', configparam=ConfigParam('invalid', filter=filter),
configparam=ConfigParam('invalid', filter=filter), in_c_key=False)
in_c_key=False)
# Check that the flag has been removed # Check that the flag has been removed
assert 'T_config.test_invalid_default_b' not in THEANO_FLAGS_DICT assert 'T_config.test_invalid_default_b' not in THEANO_FLAGS_DICT
......
...@@ -4,7 +4,6 @@ from theano.tests import disturb_mem ...@@ -4,7 +4,6 @@ from theano.tests import disturb_mem
import numpy as np import numpy as np
import theano import theano
from theano.printing import var_descriptor from theano.printing import var_descriptor
from nose.plugins.skip import SkipTest
from theano import config, shared from theano import config, shared
from six import StringIO from six import StringIO
...@@ -24,10 +23,6 @@ def test_determinism_1(): ...@@ -24,10 +23,6 @@ def test_determinism_1():
# change. # change.
# This specific script is capable of catching a bug where # This specific script is capable of catching a bug where
# FunctionGraph.toposort was non-deterministic. # FunctionGraph.toposort was non-deterministic.
try:
import hashlib
except ImportError:
raise SkipTest('python version too old to do this test')
def run(replay, log=None): def run(replay, log=None):
......
...@@ -27,19 +27,7 @@ ignore = ('E501', 'E123', 'E133') ...@@ -27,19 +27,7 @@ ignore = ('E501', 'E123', 'E133')
whitelist_flake8 = [ whitelist_flake8 = [
"compat/six.py", # This is bundled code that will be deleted, don't fix it "compat/six.py", # This is bundled code that will be deleted, don't fix it
"__init__.py", "__init__.py",
"tests/test_gradient.py",
"tests/test_config.py",
"tests/diverse_tests.py",
"tests/test_rop.py",
"tests/test_2nd_order_grads.py",
"tests/run_tests_in_batch.py",
"tests/test_record.py",
"tests/__init__.py", "tests/__init__.py",
"tests/test_updates.py",
"tests/test_pickle_unpickle_theano_fn.py",
"tests/test_determinism.py",
"tests/record.py",
"tests/unittest_tools.py",
"compile/__init__.py", "compile/__init__.py",
"compile/profiling.py", "compile/profiling.py",
"compile/tests/test_builders.py", "compile/tests/test_builders.py",
......
...@@ -12,15 +12,15 @@ The config option is in configdefaults.py ...@@ -12,15 +12,15 @@ The config option is in configdefaults.py
This note is written by Li Yao. This note is written by Li Yao.
""" """
import unittest
import numpy import numpy
import six.moves.cPickle as pickle import six.moves.cPickle as pickle
from theano.compat import DictMixin, OrderedDict from theano.compat import OrderedDict
import theano import theano
import theano.tensor as T import theano.tensor as T
floatX = 'float32' floatX = 'float32'
def test_pickle_unpickle_with_reoptimization(): def test_pickle_unpickle_with_reoptimization():
mode = theano.config.mode mode = theano.config.mode
if mode in ["DEBUG_MODE", "DebugMode"]: if mode in ["DEBUG_MODE", "DebugMode"]:
......
from theano.tests.record import * from theano.tests.record import Record, MismatchError, RecordMode
from theano import function from theano import function
from six.moves import xrange, StringIO from six.moves import xrange, StringIO
from theano.tensor import iscalar from theano.tensor import iscalar
...@@ -20,21 +20,21 @@ def test_record_good(): ...@@ -20,21 +20,21 @@ def test_record_good():
num_lines = 10 num_lines = 10
for i in xrange(num_lines): for i in xrange(num_lines):
recorder.handle_line(str(i)+'\n') recorder.handle_line(str(i) + '\n')
# Make sure they were recorded correctly # Make sure they were recorded correctly
output_value = output.getvalue() output_value = output.getvalue()
assert output_value == ''.join(str(i)+'\n' for i in xrange(num_lines)) assert output_value == ''.join(str(i) + '\n' for i in xrange(num_lines))
# Make sure that the playback functionality doesn't raise any errors # Make sure that the playback functionality doesn't raise any errors
# when we repeat them # when we repeat them
output = StringIO(output_value) output = StringIO(output_value)
playback_checker = Record(file_object=output, replay=True) playback_checker = Record(file_object=output, replay=True)
for i in xrange(num_lines): for i in xrange(num_lines):
playback_checker.handle_line(str(i)+'\n') playback_checker.handle_line(str(i) + '\n')
def test_record_bad(): def test_record_bad():
...@@ -51,17 +51,17 @@ def test_record_bad(): ...@@ -51,17 +51,17 @@ def test_record_bad():
num_lines = 10 num_lines = 10
for i in xrange(num_lines): for i in xrange(num_lines):
recorder.handle_line(str(i)+'\n') recorder.handle_line(str(i) + '\n')
# Make sure that the playback functionality doesn't raise any errors # Make sure that the playback functionality doesn't raise any errors
# when we repeat some of them # when we repeat some of them
output_value = output.getvalue() output_value = output.getvalue()
output = StringIO(output_value) output = StringIO(output_value)
playback_checker = Record(file_object=output, replay=True) playback_checker = Record(file_object=output, replay=True)
for i in xrange(num_lines // 2): for i in xrange(num_lines // 2):
playback_checker.handle_line(str(i)+'\n') playback_checker.handle_line(str(i) + '\n')
# Make sure it raises an error when we deviate from the recorded sequence # Make sure it raises an error when we deviate from the recorded sequence
try: try:
...@@ -92,7 +92,7 @@ def test_record_mode_good(): ...@@ -92,7 +92,7 @@ def test_record_mode_good():
num_lines = 10 num_lines = 10
for i in xrange(num_lines): for i in xrange(num_lines):
recorder.handle_line(str(i)+'\n') recorder.handle_line(str(i) + '\n')
f(i) f(i)
# Make sure that the playback functionality doesn't raise any errors # Make sure that the playback functionality doesn't raise any errors
...@@ -100,7 +100,7 @@ def test_record_mode_good(): ...@@ -100,7 +100,7 @@ def test_record_mode_good():
output_value = output.getvalue() output_value = output.getvalue()
output = StringIO(output_value) output = StringIO(output_value)
playback_checker = Record(file_object=output, replay=True) playback_checker = Record(file_object=output, replay=True)
playback_mode = RecordMode(playback_checker) playback_mode = RecordMode(playback_checker)
...@@ -108,7 +108,7 @@ def test_record_mode_good(): ...@@ -108,7 +108,7 @@ def test_record_mode_good():
f = function([i], i, mode=playback_mode, name='f') f = function([i], i, mode=playback_mode, name='f')
for i in xrange(num_lines): for i in xrange(num_lines):
playback_checker.handle_line(str(i)+'\n') playback_checker.handle_line(str(i) + '\n')
f(i) f(i)
...@@ -132,7 +132,7 @@ def test_record_mode_bad(): ...@@ -132,7 +132,7 @@ def test_record_mode_bad():
num_lines = 10 num_lines = 10
for i in xrange(num_lines): for i in xrange(num_lines):
recorder.handle_line(str(i)+'\n') recorder.handle_line(str(i) + '\n')
f(i) f(i)
# Make sure that the playback functionality doesn't raise any errors # Make sure that the playback functionality doesn't raise any errors
...@@ -140,7 +140,7 @@ def test_record_mode_bad(): ...@@ -140,7 +140,7 @@ def test_record_mode_bad():
output_value = output.getvalue() output_value = output.getvalue()
output = StringIO(output_value) output = StringIO(output_value)
playback_checker = Record(file_object=output, replay=True) playback_checker = Record(file_object=output, replay=True)
playback_mode = RecordMode(playback_checker) playback_mode = RecordMode(playback_checker)
...@@ -148,7 +148,7 @@ def test_record_mode_bad(): ...@@ -148,7 +148,7 @@ def test_record_mode_bad():
f = function([i], i, mode=playback_mode, name='f') f = function([i], i, mode=playback_mode, name='f')
for i in xrange(num_lines // 2): for i in xrange(num_lines // 2):
playback_checker.handle_line(str(i)+'\n') playback_checker.handle_line(str(i) + '\n')
f(i) f(i)
# Make sure a wrong event causes a MismatchError # Make sure a wrong event causes a MismatchError
......
...@@ -13,7 +13,7 @@ ...@@ -13,7 +13,7 @@
""" """
import unittest import unittest
from theano.tests import unittest_tools as utt from theano.tests import unittest_tools as utt
from theano import function from theano import function
import theano import theano
from theano import tensor from theano import tensor
...@@ -21,7 +21,6 @@ import numpy ...@@ -21,7 +21,6 @@ import numpy
from theano.gof import Op, Apply from theano.gof import Op, Apply
from theano.gradient import grad_undefined from theano.gradient import grad_undefined
from theano.tests.unittest_tools import SkipTest from theano.tests.unittest_tools import SkipTest
from theano.tensor.signal.pool import Pool
from theano.tensor.nnet import conv, conv2d from theano.tensor.nnet import conv, conv2d
''' '''
...@@ -75,7 +74,7 @@ class RopLop_checker(unittest.TestCase): ...@@ -75,7 +74,7 @@ class RopLop_checker(unittest.TestCase):
test that an error is raised.""" test that an error is raised."""
raised = False raised = False
try: try:
tmp = tensor.Rop(y, self.x, self.v) tensor.Rop(y, self.x, self.v)
except ValueError: except ValueError:
raised = True raised = True
if not raised: if not raised:
...@@ -108,10 +107,10 @@ class RopLop_checker(unittest.TestCase): ...@@ -108,10 +107,10 @@ class RopLop_checker(unittest.TestCase):
theano.config.floatX) theano.config.floatX)
yv = tensor.Rop(y, self.mx, self.mv) yv = tensor.Rop(y, self.mx, self.mv)
rop_f = function([self.mx, self.mv], yv, on_unused_input='ignore') rop_f = function([self.mx, self.mv], yv, on_unused_input='ignore')
sy, _ = theano.scan(lambda i, y, x, v: \ sy, _ = theano.scan(lambda i, y, x, v:
(tensor.grad(y[i], x) * v).sum(), (tensor.grad(y[i], x) * v).sum(),
sequences=tensor.arange(y.shape[0]), sequences=tensor.arange(y.shape[0]),
non_sequences=[y, self.mx, self.mv]) non_sequences=[y, self.mx, self.mv])
scan_f = function([self.mx, self.mv], sy, on_unused_input='ignore') scan_f = function([self.mx, self.mv], sy, on_unused_input='ignore')
v1 = rop_f(vx, vv) v1 = rop_f(vx, vv)
...@@ -119,11 +118,9 @@ class RopLop_checker(unittest.TestCase): ...@@ -119,11 +118,9 @@ class RopLop_checker(unittest.TestCase):
assert numpy.allclose(v1, v2), ('ROP mismatch: %s %s' % (v1, v2)) assert numpy.allclose(v1, v2), ('ROP mismatch: %s %s' % (v1, v2))
self.check_nondiff_rop(theano.clone(y, self.check_nondiff_rop(theano.clone(y, replace={self.mx: break_op(self.mx)}))
replace={self.mx: break_op(self.mx)}))
vv = numpy.asarray(self.rng.uniform(size=out_shape), vv = numpy.asarray(self.rng.uniform(size=out_shape), theano.config.floatX)
theano.config.floatX)
yv = tensor.Lop(y, self.mx, self.v) yv = tensor.Lop(y, self.mx, self.v)
lop_f = function([self.mx, self.v], yv) lop_f = function([self.mx, self.v], yv)
...@@ -160,8 +157,7 @@ class RopLop_checker(unittest.TestCase): ...@@ -160,8 +157,7 @@ class RopLop_checker(unittest.TestCase):
assert numpy.allclose(v1, v2), ('ROP mismatch: %s %s' % (v1, v2)) assert numpy.allclose(v1, v2), ('ROP mismatch: %s %s' % (v1, v2))
known_fail = False known_fail = False
try: try:
self.check_nondiff_rop(theano.clone(y, self.check_nondiff_rop(theano.clone(y, replace={self.x: break_op(self.x)}))
replace={self.x: break_op(self.x)}))
except AssertionError: except AssertionError:
known_fail = True known_fail = True
...@@ -266,13 +262,13 @@ class test_RopLop(RopLop_checker): ...@@ -266,13 +262,13 @@ class test_RopLop(RopLop_checker):
filter_shape = (2, 2, 2, 3) filter_shape = (2, 2, 2, 3)
image_dim = len(image_shape) image_dim = len(image_shape)
filter_dim = len(filter_shape) filter_dim = len(filter_shape)
input = tensor.TensorType( input = tensor.TensorType(
theano.config.floatX, theano.config.floatX,
[False] * image_dim)(name='input') [False] * image_dim)(name='input')
filters = tensor.TensorType( filters = tensor.TensorType(
theano.config.floatX, theano.config.floatX,
[False] * filter_dim)(name='filter') [False] * filter_dim)(name='filter')
ev_input = tensor.TensorType( ev_input = tensor.TensorType(
theano.config.floatX, theano.config.floatX,
[False] * image_dim)(name='ev_input') [False] * image_dim)(name='ev_input')
ev_filters = tensor.TensorType( ev_filters = tensor.TensorType(
...@@ -284,28 +280,23 @@ class test_RopLop(RopLop_checker): ...@@ -284,28 +280,23 @@ class test_RopLop(RopLop_checker):
output = sym_conv2d(input, filters).flatten() output = sym_conv2d(input, filters).flatten()
yv = tensor.Rop(output, [input, filters], [ev_input, ev_filters]) yv = tensor.Rop(output, [input, filters], [ev_input, ev_filters])
rop_f = function([input, filters, ev_input, ev_filters], rop_f = function([input, filters, ev_input, ev_filters],
yv, on_unused_input='ignore') yv, on_unused_input='ignore')
sy, _ = theano.scan( sy, _ = theano.scan(lambda i, y, x1, x2, v1, v2:
lambda i, y, x1, x2, v1, v2: (tensor.grad(y[i], x1) * v1).sum() +
(tensor.grad(y[i], x1) * v1).sum() + \ (tensor.grad(y[i], x2) * v2).sum(),
(tensor.grad(y[i], x2) * v2).sum(),
sequences=tensor.arange(output.shape[0]), sequences=tensor.arange(output.shape[0]),
non_sequences=[output, input, filters, non_sequences=[output, input, filters,
ev_input, ev_filters]) ev_input, ev_filters])
scan_f = function([input, filters, ev_input, ev_filters], sy, scan_f = function([input, filters, ev_input, ev_filters], sy,
on_unused_input='ignore') on_unused_input='ignore')
dtype = theano.config.floatX dtype = theano.config.floatX
image_data = numpy.random.random(image_shape).astype(dtype) image_data = numpy.random.random(image_shape).astype(dtype)
filter_data = numpy.random.random(filter_shape).astype(dtype) filter_data = numpy.random.random(filter_shape).astype(dtype)
ev_image_data = numpy.random.random(image_shape).astype(dtype) ev_image_data = numpy.random.random(image_shape).astype(dtype)
ev_filter_data = numpy.random.random(filter_shape).astype(dtype) ev_filter_data = numpy.random.random(filter_shape).astype(dtype)
v1 = rop_f(image_data, filter_data, ev_image_data, v1 = rop_f(image_data, filter_data, ev_image_data, ev_filter_data)
ev_filter_data) v2 = scan_f(image_data, filter_data, ev_image_data, ev_filter_data)
v2 = scan_f(image_data, filter_data, ev_image_data, assert numpy.allclose(v1, v2), ("Rop mismatch: %s %s" % (v1, v2))
ev_filter_data)
assert numpy.allclose(v1, v2), ("Rop mismatch: %s %s" %
(v1, v2))
def test_join(self): def test_join(self):
tv = numpy.asarray(self.rng.uniform(size=(10,)), tv = numpy.asarray(self.rng.uniform(size=(10,)),
...@@ -353,10 +344,8 @@ class test_RopLop(RopLop_checker): ...@@ -353,10 +344,8 @@ class test_RopLop(RopLop_checker):
self.check_rop_lop(out1d, self.in_shape[0]) self.check_rop_lop(out1d, self.in_shape[0])
# Alloc of x into a 3-D tensor, flattened # Alloc of x into a 3-D tensor, flattened
out3d = tensor.alloc(self.x, out3d = tensor.alloc(self.x, self.mat_in_shape[0], self.mat_in_shape[1], self.in_shape[0])
self.mat_in_shape[0], self.mat_in_shape[1], self.in_shape[0]) self.check_rop_lop(out3d.flatten(), self.mat_in_shape[0] * self.mat_in_shape[1] * self.in_shape[0])
self.check_rop_lop(out3d.flatten(),
self.mat_in_shape[0] * self.mat_in_shape[1] * self.in_shape[0])
def test_invalid_input(self): def test_invalid_input(self):
success = False success = False
......
...@@ -14,10 +14,8 @@ class test_ifelse(unittest.TestCase): ...@@ -14,10 +14,8 @@ class test_ifelse(unittest.TestCase):
OrderedUpdates({sv: 3}) OrderedUpdates({sv: 3})
def test_updates_setitem(self): def test_updates_setitem(self):
ok = True
up = OrderedUpdates() up = OrderedUpdates()
sv = theano.shared('asdf')
# keys have to be SharedVariables # keys have to be SharedVariables
self.assertRaises(TypeError, up.__setitem__, 5, 7) self.assertRaises(TypeError, up.__setitem__, 5, 7)
......
...@@ -51,8 +51,7 @@ def fetch_seed(pseed=None): ...@@ -51,8 +51,7 @@ def fetch_seed(pseed=None):
else: else:
seed = None seed = None
except ValueError: except ValueError:
print(('Error: config.unittests.rseed contains ' print(('Error: config.unittests.rseed contains ' 'invalid seed, using None instead'), file=sys.stderr)
'invalid seed, using None instead'), file=sys.stderr)
seed = None seed = None
return seed return seed
...@@ -66,8 +65,7 @@ def seed_rng(pseed=None): ...@@ -66,8 +65,7 @@ def seed_rng(pseed=None):
seed = fetch_seed(pseed) seed = fetch_seed(pseed)
if pseed and pseed != seed: if pseed and pseed != seed:
print('Warning: using seed given by config.unittests.rseed=%i'\ print('Warning: using seed given by config.unittests.rseed=%i' 'instead of seed %i given as parameter' % (seed, pseed), file=sys.stderr)
'instead of seed %i given as parameter' % (seed, pseed), file=sys.stderr)
numpy.random.seed(seed) numpy.random.seed(seed)
return seed return seed
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论