提交 121a72aa authored 作者: nouiz's avatar nouiz

Merge pull request #1213 from delallea/minor

A few improvements to function()
...@@ -164,8 +164,9 @@ def function(inputs, outputs=None, mode=None, updates=None, givens=None, ...@@ -164,8 +164,9 @@ def function(inputs, outputs=None, mode=None, updates=None, givens=None,
if updates is None: if updates is None:
updates = [] updates = []
if isinstance(updates, dict) and \ if (isinstance(updates, dict) and
not isinstance(updates, gof.python25.OrderedDict): not isinstance(updates, gof.python25.OrderedDict) and
len(updates) > 1):
warnings.warn( warnings.warn(
"The parameter 'updates' of theano.function()" "The parameter 'updates' of theano.function()"
" expects an OrderedDict," " expects an OrderedDict,"
...@@ -186,8 +187,8 @@ def function(inputs, outputs=None, mode=None, updates=None, givens=None, ...@@ -186,8 +187,8 @@ def function(inputs, outputs=None, mode=None, updates=None, givens=None,
# compute some features of the arguments: # compute some features of the arguments:
uses_In = any([isinstance(i, In) for i in inputs]) # N.B. the square brackets are ncessary uses_In = any([isinstance(i, In) for i in inputs]) # N.B. the square brackets are ncessary
uses_tuple = any([isinstance(i, (list, tuple)) for i in inputs]) # N.B. the square brackets are ncessary uses_tuple = any([isinstance(i, (list, tuple)) for i in inputs]) # N.B. the square brackets are ncessary
uses_updates = (updates != []) uses_updates = bool(updates)
uses_givens = (givens != []) uses_givens = bool(givens)
# See if we have any mutable / borrow inputs # See if we have any mutable / borrow inputs
check_for_aliased_inputs = False check_for_aliased_inputs = False
...@@ -201,7 +202,9 @@ def function(inputs, outputs=None, mode=None, updates=None, givens=None, ...@@ -201,7 +202,9 @@ def function(inputs, outputs=None, mode=None, updates=None, givens=None,
if profile: if profile:
raise NotImplementedError('profiling not supported in old-style function') raise NotImplementedError('profiling not supported in old-style function')
if uses_updates or uses_givens: if uses_updates or uses_givens:
raise NotImplementedError("In() instances and tuple inputs triggers the old semantics, which disallow using updates and givens") raise NotImplementedError(
"In() instances and tuple inputs trigger the old "
"semantics, which disallow using updates and givens")
fn = orig_function(inputs, outputs, fn = orig_function(inputs, outputs,
mode=mode, mode=mode,
accept_inplace=accept_inplace, name=name) accept_inplace=accept_inplace, name=name)
......
...@@ -233,8 +233,8 @@ def rebuild_collect_shared(outputs, ...@@ -233,8 +233,8 @@ def rebuild_collect_shared(outputs,
cloned_outputs.append(Out(cloned_v, borrow=v.borrow)) cloned_outputs.append(Out(cloned_v, borrow=v.borrow))
else: else:
raise TypeError('Outputs must be theano Variable or ' raise TypeError('Outputs must be theano Variable or '
'Out instances. Received ' + str(v)\ 'Out instances. Received ' + str(v)
+ ' of type '+str(type(v))) + ' of type ' + str(type(v)))
#computed_list.append(cloned_v) #computed_list.append(cloned_v)
else: else:
if isinstance(outputs, Variable): if isinstance(outputs, Variable):
...@@ -278,23 +278,25 @@ class Param(object): ...@@ -278,23 +278,25 @@ class Param(object):
def __init__(self, variable, default=None, name=None, mutable=False, def __init__(self, variable, default=None, name=None, mutable=False,
strict=False, allow_downcast=None, implicit=None, borrow=None): strict=False, allow_downcast=None, implicit=None, borrow=None):
""" """
:param variable: A variable in an expression graph to use as a compiled-function parameter :param variable: A variable in an expression graph to use as a
compiled-function parameter
:param default: The default value to use at call-time (can also be a Container where :param default: The default value to use at call-time (can also be a Container where
the function will find a value at call-time.) the function will find a value at call-time.)
:param name: A string to identify this parameter from function kwargs. :param name: A string to identify this parameter from function kwargs.
:param mutable: True -> function is allowed to modify this argument. :param mutable: True -> function is allowed to modify this argument.
:param borrow: Whether the function is allowed to alias some output to :param borrow: Whether the function is allowed to alias some output to
this input. Using None (default) means we re-use the same value as the this input. Using None (default) means we re-use the same value as the
`mutable` flag. `mutable` flag.
False: do not permit any output to be aliased to the input
False: do not permit any output to be aliased to the input
:param strict: False -> function arguments may be copied or cast to match the :param strict: False -> function arguments may be copied or cast to match the
type required by the parameter `variable`. True -> function arguments must exactly match the type type required by the parameter `variable`.
required by `variable`. True -> function arguments must exactly match the type
required by `variable`.
:param allow_downcast: Only applies if `strict` is False. :param allow_downcast: Only applies if `strict` is False.
True -> allow assigned value to lose precision when cast during assignment. True -> allow assigned value to lose precision when cast during assignment.
...@@ -452,6 +454,27 @@ def pfunc(params, outputs=None, mode=None, updates=None, givens=None, ...@@ -452,6 +454,27 @@ def pfunc(params, outputs=None, mode=None, updates=None, givens=None,
"provided for it being ignored. Please do not duplicate " "provided for it being ignored. Please do not duplicate "
"variables in the inputs list." % (v, i, dup_v_i))) "variables in the inputs list." % (v, i, dup_v_i)))
# Check that we are not using `givens` to replace input variables, because
# this typically does nothing, contrary to what one may expect.
in_var_set = set(in_variables)
try:
givens_pairs = givens.items()
except AttributeError:
givens_pairs = givens
for x, y in givens_pairs:
if x in in_var_set:
raise RuntimeError(
'You are trying to replace variable \'%s\' through the '
'`givens` parameter, but this variable is an input to your '
'function. Replacing inputs is currently forbidden because it '
'has no effect. One way to modify an input `x` to a function '
'evaluating f(x) is to define a new input `y` and use '
'`theano.function([y], f(x), givens={x: g(y)})`. Another '
'solution consists in using `theano.clone`, e.g. like this: '
'`theano.function([x], '
'theano.clone(f(x), replace={x: g(x)}))`.'
% x)
output_vars = rebuild_collect_shared(outputs, output_vars = rebuild_collect_shared(outputs,
in_variables, in_variables,
replace=givens, replace=givens,
......
...@@ -386,6 +386,14 @@ class T_function(unittest.TestCase): ...@@ -386,6 +386,14 @@ class T_function(unittest.TestCase):
self.assertRaises(UnusedInputError, function, [m, mt], mt*2) self.assertRaises(UnusedInputError, function, [m, mt], mt*2)
f = function([m, mt], mt*2, on_unused_input='ignore') f = function([m, mt], mt*2, on_unused_input='ignore')
def test_givens_input_var(self):
"""
Ensure error is raised when trying to replace an input variable.
"""
x = T.scalar('x')
y = x * 2
self.assertRaises(RuntimeError, function, [x], y, givens={x: x + 1})
class T_picklefunction(unittest.TestCase): class T_picklefunction(unittest.TestCase):
...@@ -680,6 +688,18 @@ class SomethingToPickle(object): ...@@ -680,6 +688,18 @@ class SomethingToPickle(object):
self.f2 = function([x, In(a, value=1.0,name='a'), In(s, value=self.f1.container[s], update=s+a*x, mutable=True)], s+a*x) self.f2 = function([x, In(a, value=1.0,name='a'), In(s, value=self.f1.container[s], update=s+a*x, mutable=True)], s+a*x)
def test_empty_givens_updates():
"""
Regression test for bug fixed in 8625e03.
"""
# Empty givens / updates dictionaries were not properly detected before,
# triggering useless crashes at compile time.
x = T.scalar()
y = x * 2
function([theano.In(x)], y, givens={})
function([theano.In(x)], y, updates={})
if __name__ == '__main__': if __name__ == '__main__':
if 1: if 1:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论