提交 b045f9c4 authored 作者: AlexLamb's avatar AlexLamb

Added support for debugmode. Changed unit tests so that keys are always…

Added support for debugmode. Changed unit tests so that keys are always comparable types in python3.3
上级 5fb0c235
......@@ -2179,7 +2179,8 @@ class _Maker(FunctionMaker): # inheritance buys a few helper functions
accept_inplace = False,
function_builder = Function,
profile=None,
on_unused_input=None):
on_unused_input=None,
output_keys=None):
"""
:type inputs: a list of SymbolicInput instances
......@@ -2314,6 +2315,7 @@ class _Maker(FunctionMaker): # inheritance buys a few helper functions
self.accept_inplace = accept_inplace
self.function_builder = function_builder
self.mode = mode
self.output_keys = output_keys
def create(self, defaults=None, trustme=False):
"""
......@@ -2422,7 +2424,7 @@ class _Maker(FunctionMaker): # inheritance buys a few helper functions
_fn, _i, _o = self.linker.make_thunk(input_storage=input_storage)
fn = self.function_builder(_fn, _i, _o, self.indices,
self.outputs, defaults, self.unpack_single,
self.return_none, self)
self.return_none, self.output_keys, self)
return fn
......
......@@ -4,6 +4,7 @@ from theano.tests import unittest_tools as utt
import theano
import theano.tensor as T
import sys
class dictionary_output_checker(unittest.TestCase):
......@@ -24,13 +25,13 @@ class dictionary_output_checker(unittest.TestCase):
x = T.scalar()
f = theano.function([x], outputs = {'a' : x, 'c' : x*2, 'b' : x*3, 1 : x*4})
f = theano.function([x], outputs = {'a' : x, 'c' : x*2, 'b' : x*3, '1' : x*4})
outputs = f(10.0)
assert outputs['a'] == 10.0
assert outputs['b'] == 30.0
assert outputs[1] == 40.0
assert outputs['1'] == 40.0
assert outputs['c'] == 20.0
def test_input_dictionary(self):
......@@ -50,7 +51,7 @@ class dictionary_output_checker(unittest.TestCase):
e1 = T.scalar('1')
e2 = T.scalar('2')
f = theano.function([x,y,z,e1,e2], outputs = {'x':x,'y':y,'z':z,1:e1,2:e2})
f = theano.function([x,y,z,e1,e2], outputs = {'x':x,'y':y,'z':z,'1':e1,'2':e2})
assert '1' in str(f.outputs[0])
assert '2' in str(f.outputs[1])
......@@ -93,3 +94,29 @@ class dictionary_output_checker(unittest.TestCase):
assert result[2] == 20.0
assert result[3] == 5.0
def test_debug_mode_dict(self):
x = T.scalar('x')
print sys.stderr, "Running debug mode with dict"
f = theano.function([x], outputs = {'1' : x, '2' : 2 * x, '3' : 3 * x}, mode = "DEBUG_MODE")
result = f(3.0)
assert result['1'] == 3.0
assert result['2'] == 6.0
assert result['3'] == 9.0
def test_debug_mode_list(self):
x = T.scalar('x')
f = theano.function([x], outputs = [x, 2 * x, 3 * x])
result = f(5.0)
assert result[0] == 5.0
assert result[1] == 10.0
assert result[2] == 15.0
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论