提交 2720aeed authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Check if provided inputs of function are unused.

Add a mechanism to either raise an error (default), display a warning, or do nothing when unused inputs are passed.
上级 873d41b0
......@@ -1735,7 +1735,8 @@ class _Maker(FunctionMaker): # inheritance buys a few helper functions
def __init__(self, inputs, outputs, optimizer, mode,
accept_inplace = False,
function_builder = Function,
profile=None):
profile=None,
on_unused_input='raise'):
"""
:type inputs: a list of SymbolicInput instances
......@@ -1748,6 +1749,9 @@ class _Maker(FunctionMaker): # inheritance buys a few helper functions
inplace operations in the graph from the inputs to
the outputs
:param on_unused_input: What to do if a variable in the 'inputs' list is
not used in the graph. Possible values are 'raise', 'warn', and 'ignore'.
:note: this function sets TensorType.filter_checks_isfinite
when `mode.check_isfinite` is True
......@@ -1772,6 +1776,9 @@ class _Maker(FunctionMaker): # inheritance buys a few helper functions
[i.update for i in inputs
if getattr(i, 'update', False)])
# Check if some input variables are unused
self._check_unused_inputs(inputs, outputs, on_unused_input)
#TODO: REMOVE THIS CRUFT - it's complicated for SymbolicInputKits
indices = [[input] + self.expand_in(input, _inputs) for input in inputs]
expanded_inputs = reduce(list.__add__, [list(z)
......
......@@ -13,7 +13,8 @@ from numpy import any #for to work in python 2.4
def function(inputs, outputs=None, mode=None, updates=[], givens=[],
no_default_updates=False, accept_inplace=False, name=None,
rebuild_strict=True, allow_input_downcast=None, profile=None):
rebuild_strict=True, allow_input_downcast=None, profile=None,
on_unused_input='raise'):
"""
Return a callable object that will calculate `outputs` from `inputs`.
......@@ -68,6 +69,9 @@ def function(inputs, outputs=None, mode=None, updates=[], givens=[],
instance. If argument is `True` then a new ProfileStats instance will be
used. This profiling object will be available via self.profile.
:param on_unused_input: What to do if a variable in the 'inputs' list is
not used in the graph. Possible values are 'raise', 'warn', and 'ignore'.
:note: Regarding givens: Be careful to make sure that these substitutions are
independent--behaviour when Var1 of one pair appears in the graph leading to Var2 in
another expression is undefined. Replacements specified with givens are different from
......@@ -111,6 +115,7 @@ def function(inputs, outputs=None, mode=None, updates=[], givens=[],
accept_inplace=accept_inplace,name=name,
rebuild_strict=rebuild_strict,
allow_input_downcast=allow_input_downcast,
on_unused_input=on_unused_input,
profile=profile)
# We need to add the flag check_aliased inputs if we have any mutable or
# borrowed used defined inputs
......
......@@ -958,7 +958,7 @@ class FunctionMaker(object):
def __init__(self, inputs, outputs,
mode = None, accept_inplace = False, function_builder = Function,
profile=None):
profile=None, on_unused_input='raise'):
"""
:type inputs: a list of SymbolicInput instances
......@@ -972,6 +972,12 @@ class FunctionMaker(object):
:param accept_inplace: True iff it is acceptable to have inplace operations
in the graph from the inputs to the outputs
:param on_unused_input: What to do if a variable in the 'inputs' list
is not used in the graph. Possible values are:
- 'raise' (default): raise an error
- 'warn': log a warning
- 'ignore': do not do anything
"""
mode = mode_module.get_mode(mode)
......@@ -1005,6 +1011,10 @@ class FunctionMaker(object):
_inputs = gof.graph.inputs([o.variable for o in outputs] + [i.update
for i in inputs if getattr(i, 'update', False)])
# Check if some input variables are unused
self._check_unused_inputs(inputs, outputs, on_unused_input)
#TODO: REMOVE THIS CRUFT - it's complicated for SymbolicInputKits
indices = [[input] + self.expand_in(input, _inputs) for input in inputs]
expanded_inputs = reduce(list.__add__, [list(z) for x, y, z in indices], [])
......@@ -1072,6 +1082,37 @@ class FunctionMaker(object):
(i.value != None and not isinstance(i.value, gof.Container) and i.update == None)
for i in self.inputs]
def _check_unused_inputs(self, inputs, outputs, on_unused_input):
if on_unused_input == 'ignore':
return
# There should be two categories of variables in inputs:
# - variables that have to be provided (used_inputs)
# - shared variables that will be updated
used_inputs = gof.graph.ancestors(
([o.variable for o in outputs]
+ [i.update for i in inputs if getattr(i, 'update', False)]),
blockers=[i.variable for i in inputs])
msg = ("theano.function was asked to create a function computing "
"outputs given certain inputs, but one of the provided "
"input variables is not part of the computational graph "
"needed to compute the outputs: %s. To disable this %s, "
"you can pass the parameter on_unused_input='ignore' to "
"theano.function.")
for i in inputs:
if ((i.variable not in used_inputs) and (i.update is None)):
if on_unused_input == 'warn':
warnings.warn(msg % (i.variable, 'warning'), stacklevel=5)
elif on_unused_input == 'raise':
raise ValueError(msg % (i.variable, 'error'))
else:
raise ValueError(("Invalid value for keyword "
"on_unused_input of theano.function: '%s'. "
"valid values are 'raise', 'warn', and 'ignore'."
% on_unused_input))
def create(self, input_storage=None, trustme=False):
"""
Create a function.
......@@ -1202,7 +1243,8 @@ def check_equal(x, y):
def register_checker(checker):
__checkers.insert(0, checker)
def orig_function(inputs, outputs, mode=None, accept_inplace = False, name=None, profile=None):
def orig_function(inputs, outputs, mode=None, accept_inplace = False,
name=None, profile=None, on_unused_input='raise'):
"""
Return a Function that will calculate the outputs from the inputs.
......@@ -1232,6 +1274,8 @@ def orig_function(inputs, outputs, mode=None, accept_inplace = False, name=None,
:param profile: None or ProfileStats instance
:param on_unused_input: What to do if a variable in the 'inputs' list is
not used in the graph. Possible values are 'raise', 'warn', and 'ignore'.
"""
#Every element of the input list will be upgraded to an `In` instance if necessary,
......@@ -1262,7 +1306,8 @@ def orig_function(inputs, outputs, mode=None, accept_inplace = False, name=None,
outputs,
mode[0],
accept_inplace = accept_inplace,
profile=profile).create(
profile=profile,
on_unused_input=on_unused_input).create(
defaults)
else:
if profile:
......@@ -1292,7 +1337,8 @@ def orig_function(inputs, outputs, mode=None, accept_inplace = False, name=None,
outputs,
mode,
accept_inplace = accept_inplace,
profile=profile).create(
profile=profile,
on_unused_input=on_unused_input).create(
defaults)
t2 = time.time()
......
......@@ -324,7 +324,7 @@ class Param(object):
def pfunc(params, outputs=None, mode=None, updates=[], givens=[],
no_default_updates=False, accept_inplace=False, name=None,
rebuild_strict=True, allow_input_downcast=None,
profile=None):
profile=None, on_unused_input='raise'):
"""Function-constructor for graphs with shared variables.
:type params: list of either Variable or Param instances.
......@@ -372,6 +372,10 @@ def pfunc(params, outputs=None, mode=None, updates=[], givens=[],
with that string as its `message` attribute. This profiling object will be
available via self.profile.
:type profile: str
:param profile: What to do if a variable in the 'inputs' list is not used
in the graph. Possible values are 'raise', 'warn', and 'ignore.
:rtype: theano.compile.Function
:returns: a callable object that will compute the outputs (given the inputs)
......@@ -460,7 +464,8 @@ def pfunc(params, outputs=None, mode=None, updates=[], givens=[],
inputs.append(si)
return orig_function(inputs, cloned_outputs, mode,
accept_inplace=accept_inplace, name=name, profile=profile)
accept_inplace=accept_inplace, name=name, profile=profile,
on_unused_input=on_unused_input)
def _pfunc_param_to_in(param, strict=False, allow_downcast=None):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论