提交 f19bb9fa authored 作者: AlexLamb's avatar AlexLamb

Added comments. Added test that rejects non-string keys

上级 9fc1f281
import unittest
import theano
import theano.tensor as T
import sys
class dictionary_output_checker(unittest.TestCase):
def test_output_dictionary(self):
'''
Tests that theano.function works when outputs is a dictionary
'''
x = T.scalar()
f = theano.function([x], outputs={'a': x, 'c': x*2,
'b': x*3, '1': x*4})
......@@ -19,6 +23,11 @@ class dictionary_output_checker(unittest.TestCase):
assert outputs['c'] == 20.0
def test_input_named_variables(self):
'''
Tests that named variables work when outputs is a dictionary
'''
x = T.scalar('x')
y = T.scalar('y')
......@@ -29,6 +38,11 @@ class dictionary_output_checker(unittest.TestCase):
assert f(x=2, y=4) == f(2, 4)
def test_output_order_sorted(self):
'''
Tests that the output keys are sorted correctly.
'''
x = T.scalar('x')
y = T.scalar('y')
z = T.scalar('z')
......@@ -45,6 +59,12 @@ class dictionary_output_checker(unittest.TestCase):
assert 'z' in str(f.outputs[4])
def test_composing_function(self):
'''
Tests that one can compose two theano functions when the outputs are
provided in a dictionary.
'''
x = T.scalar('x')
y = T.scalar('y')
......@@ -68,6 +88,10 @@ class dictionary_output_checker(unittest.TestCase):
def test_output_list_still_works(self):
'''
Test that theano.function works if outputs is a list.
'''
x = T.scalar('x')
f = theano.function([x], outputs=[x * 3, x * 2, x * 4, x])
......@@ -80,9 +104,12 @@ class dictionary_output_checker(unittest.TestCase):
assert result[3] == 5.0
def test_debug_mode_dict(self):
x = T.scalar('x')
print sys.stderr, "Running debug mode with dict"
'''
Tests that debug mode works where outputs is a dictionary.
'''
x = T.scalar('x')
f = theano.function([x], outputs={'1': x, '2': 2 * x,
'3': 3 * x}, mode="DEBUG_MODE")
......@@ -95,12 +122,42 @@ class dictionary_output_checker(unittest.TestCase):
def test_debug_mode_list(self):
'''
Tests that debug mode works where the outputs argument is a list.
'''
x = T.scalar('x')
f = theano.function([x], outputs=[x, 2 * x, 3 * x])
f = theano.function([x], outputs=[x, 2 * x, 3 * x], mode="DEBUG_MODE")
result = f(5.0)
assert result[0] == 5.0
assert result[1] == 10.0
assert result[2] == 15.0
def test_key_string_requirement(self):
'''
Tests that an exception is thrown if a non-string key is used in
the outputs dictionary.
'''
x = T.scalar('x')
try:
theano.function([x], outputs={1.0: x})
raise Exception("Did not throw exception with 1.0 as only key")
except AssertionError:
pass
try:
theano.function([x], outputs={1.0: x, "a": x**2})
raise Exception("Did not throw exception with 1.0 as one key")
except AssertionError:
pass
try:
theano.function([x], outputs={("a", "b"): x, "a": x**2})
raise Exception("Did not throw exception with tuple as one key")
except AssertionError:
pass
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论