提交 244eea74 authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Raise an Exception for duplicate inputs in pfunc

上级 5bc7aa8a
...@@ -422,7 +422,17 @@ def pfunc(params, outputs=None, mode=None, updates=[], givens=[], ...@@ -422,7 +422,17 @@ def pfunc(params, outputs=None, mode=None, updates=[], givens=[],
inputs = [_pfunc_param_to_in(p, allow_downcast=allow_input_downcast) inputs = [_pfunc_param_to_in(p, allow_downcast=allow_input_downcast)
for p in params] for p in params]
# Check if some variable is present more than once in inputs
in_variables = [input.variable for input in inputs] in_variables = [input.variable for input in inputs]
for i, v in enumerate(in_variables):
if v in in_variables[(i + 1):]:
dup_v_i = in_variables.index(v, (i + 1))
raise ValueError(
("Variable %s is used twice in inputs to theano.function, "
"at indices %i and %i. This would result in values "
"provided for it being ignored. Please do not duplicate "
"variables in the inputs list." % (v, i, dup_v_i)))
output_vars = rebuild_collect_shared(outputs, output_vars = rebuild_collect_shared(outputs,
in_variables, in_variables,
replace=givens, replace=givens,
......
...@@ -616,7 +616,9 @@ class Test_pfunc(unittest.TestCase): ...@@ -616,7 +616,9 @@ class Test_pfunc(unittest.TestCase):
assert f() == 21 assert f() == 21
assert f() == 34 assert f() == 34
def test_duplicate_inputs(self):
x = theano.tensor.lscalar('x')
self.assertRaises(ValueError, theano.function, [x, x, x], x)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论