提交 d50d1e2c authored 作者: AlexLamb's avatar AlexLamb

Removed deadcode. Added check that forces input keys to be strings.

上级 3f9f9cfa
...@@ -193,7 +193,7 @@ def function(inputs, outputs=None, mode=None, updates=None, givens=None, ...@@ -193,7 +193,7 @@ def function(inputs, outputs=None, mode=None, updates=None, givens=None,
output_keys = [] output_keys = []
outputs = [] outputs = []
for pair in output_items_sorted: for pair in output_items_sorted:
assert isinstance(pair[0], str)
output_keys.append(pair[0]) output_keys.append(pair[0])
outputs.append(pair[1]) outputs.append(pair[1])
......
...@@ -677,15 +677,10 @@ class Function(object): ...@@ -677,15 +677,10 @@ class Function(object):
return outputs[0] return outputs[0]
else: else:
if self.output_keys is not None: if self.output_keys is not None:
outputDict = {}
assert len(self.output_keys) == len(outputs) assert len(self.output_keys) == len(outputs)
#for i in range(0, len(self.output_keys)):
# outputDict[self.output_keys[i]] = outputs[i]
return dict(itertools.izip(self.output_keys, outputs)) return dict(itertools.izip(self.output_keys, outputs))
return outputs return outputs
......
...@@ -6,17 +6,6 @@ import sys ...@@ -6,17 +6,6 @@ import sys
class dictionary_output_checker(unittest.TestCase): class dictionary_output_checker(unittest.TestCase):
def test_output_list(self):
x = T.scalar()
f = theano.function([x], outputs=[x, x*2, x*3])
outputs = f(10.0)
assert outputs[0] == 10.0
assert outputs[1] == 20.0
assert outputs[2] == 30.0
def test_output_dictionary(self): def test_output_dictionary(self):
x = T.scalar() x = T.scalar()
f = theano.function([x], outputs={'a': x, 'c': x*2, f = theano.function([x], outputs={'a': x, 'c': x*2,
...@@ -29,7 +18,7 @@ class dictionary_output_checker(unittest.TestCase): ...@@ -29,7 +18,7 @@ class dictionary_output_checker(unittest.TestCase):
assert outputs['1'] == 40.0 assert outputs['1'] == 40.0
assert outputs['c'] == 20.0 assert outputs['c'] == 20.0
def test_input_dictionary(self): def test_input_named_variables(self):
x = T.scalar('x') x = T.scalar('x')
y = T.scalar('y') y = T.scalar('y')
...@@ -39,7 +28,7 @@ class dictionary_output_checker(unittest.TestCase): ...@@ -39,7 +28,7 @@ class dictionary_output_checker(unittest.TestCase):
assert f(2, y=4) == f(2, 4) assert f(2, y=4) == f(2, 4)
assert f(x=2, y=4) == f(2, 4) assert f(x=2, y=4) == f(2, 4)
def test_output_order(self): def test_output_order_sorted(self):
x = T.scalar('x') x = T.scalar('x')
y = T.scalar('y') y = T.scalar('y')
z = T.scalar('z') z = T.scalar('z')
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论