提交 fe312cee authored 作者: Olivier Delalleau's avatar Olivier Delalleau

Prevent replacing function inputs with `givens`

This is just confusing as it does not actually modify the function being computed, as mentioned on the mailing list.
上级 86343762
......@@ -452,6 +452,24 @@ def pfunc(params, outputs=None, mode=None, updates=None, givens=None,
"provided for it being ignored. Please do not duplicate "
"variables in the inputs list." % (v, i, dup_v_i)))
# Check that we are not using `givens` to replace input variables, because
# this typically does nothing, contrary to what one may expect.
in_var_set = set(in_variables)
try:
givens_pairs = givens.items()
except AttributeError:
givens_pairs = givens
for x, y in givens_pairs:
if x in in_var_set:
raise RuntimeError(
'You cannot replace variable \'%s\' (found in the '
'`givens` argument), because it is an input to the '
'function. If your goal is to replace some input '
'x to your output f(x) by an expression g(y), so as to '
'compute f(g(y)), then you can achieve this by defining '
'a new variable y and compiling your function '
'as: function([y], f(x), givens={x: g(y)}).' % x)
output_vars = rebuild_collect_shared(outputs,
in_variables,
replace=givens,
......
......@@ -386,6 +386,14 @@ class T_function(unittest.TestCase):
self.assertRaises(UnusedInputError, function, [m, mt], mt*2)
f = function([m, mt], mt*2, on_unused_input='ignore')
def test_givens_input_var(self):
"""
Ensure error is raised when trying to replace an input variable.
"""
x = T.scalar('x')
y = x * 2
self.assertRaises(RuntimeError, function, [x], y, givens={x: x + 1})
class T_picklefunction(unittest.TestCase):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论