提交 2720aeed authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Check if provided inputs of function are unused.

Add a mechanism to either raise an error (default), display a warning, or do nothing when unused inputs are passed.
上级 873d41b0
...@@ -1735,7 +1735,8 @@ class _Maker(FunctionMaker): # inheritance buys a few helper functions ...@@ -1735,7 +1735,8 @@ class _Maker(FunctionMaker): # inheritance buys a few helper functions
def __init__(self, inputs, outputs, optimizer, mode, def __init__(self, inputs, outputs, optimizer, mode,
accept_inplace = False, accept_inplace = False,
function_builder = Function, function_builder = Function,
profile=None): profile=None,
on_unused_input='raise'):
""" """
:type inputs: a list of SymbolicInput instances :type inputs: a list of SymbolicInput instances
...@@ -1748,6 +1749,9 @@ class _Maker(FunctionMaker): # inheritance buys a few helper functions ...@@ -1748,6 +1749,9 @@ class _Maker(FunctionMaker): # inheritance buys a few helper functions
inplace operations in the graph from the inputs to inplace operations in the graph from the inputs to
the outputs the outputs
:param on_unused_input: What to do if a variable in the 'inputs' list is
not used in the graph. Possible values are 'raise', 'warn', and 'ignore'.
:note: this function sets TensorType.filter_checks_isfinite :note: this function sets TensorType.filter_checks_isfinite
when `mode.check_isfinite` is True when `mode.check_isfinite` is True
...@@ -1772,6 +1776,9 @@ class _Maker(FunctionMaker): # inheritance buys a few helper functions ...@@ -1772,6 +1776,9 @@ class _Maker(FunctionMaker): # inheritance buys a few helper functions
[i.update for i in inputs [i.update for i in inputs
if getattr(i, 'update', False)]) if getattr(i, 'update', False)])
# Check if some input variables are unused
self._check_unused_inputs(inputs, outputs, on_unused_input)
#TODO: REMOVE THIS CRUFT - it's complicated for SymbolicInputKits #TODO: REMOVE THIS CRUFT - it's complicated for SymbolicInputKits
indices = [[input] + self.expand_in(input, _inputs) for input in inputs] indices = [[input] + self.expand_in(input, _inputs) for input in inputs]
expanded_inputs = reduce(list.__add__, [list(z) expanded_inputs = reduce(list.__add__, [list(z)
......
...@@ -13,7 +13,8 @@ from numpy import any #for to work in python 2.4 ...@@ -13,7 +13,8 @@ from numpy import any #for to work in python 2.4
def function(inputs, outputs=None, mode=None, updates=[], givens=[], def function(inputs, outputs=None, mode=None, updates=[], givens=[],
no_default_updates=False, accept_inplace=False, name=None, no_default_updates=False, accept_inplace=False, name=None,
rebuild_strict=True, allow_input_downcast=None, profile=None): rebuild_strict=True, allow_input_downcast=None, profile=None,
on_unused_input='raise'):
""" """
Return a callable object that will calculate `outputs` from `inputs`. Return a callable object that will calculate `outputs` from `inputs`.
...@@ -68,6 +69,9 @@ def function(inputs, outputs=None, mode=None, updates=[], givens=[], ...@@ -68,6 +69,9 @@ def function(inputs, outputs=None, mode=None, updates=[], givens=[],
instance. If argument is `True` then a new ProfileStats instance will be instance. If argument is `True` then a new ProfileStats instance will be
used. This profiling object will be available via self.profile. used. This profiling object will be available via self.profile.
:param on_unused_input: What to do if a variable in the 'inputs' list is
not used in the graph. Possible values are 'raise', 'warn', and 'ignore'.
:note: Regarding givens: Be careful to make sure that these substitutions are :note: Regarding givens: Be careful to make sure that these substitutions are
independent--behaviour when Var1 of one pair appears in the graph leading to Var2 in independent--behaviour when Var1 of one pair appears in the graph leading to Var2 in
another expression is undefined. Replacements specified with givens are different from another expression is undefined. Replacements specified with givens are different from
...@@ -111,6 +115,7 @@ def function(inputs, outputs=None, mode=None, updates=[], givens=[], ...@@ -111,6 +115,7 @@ def function(inputs, outputs=None, mode=None, updates=[], givens=[],
accept_inplace=accept_inplace,name=name, accept_inplace=accept_inplace,name=name,
rebuild_strict=rebuild_strict, rebuild_strict=rebuild_strict,
allow_input_downcast=allow_input_downcast, allow_input_downcast=allow_input_downcast,
on_unused_input=on_unused_input,
profile=profile) profile=profile)
# We need to add the flag check_aliased inputs if we have any mutable or # We need to add the flag check_aliased inputs if we have any mutable or
# borrowed used defined inputs # borrowed used defined inputs
......
...@@ -958,7 +958,7 @@ class FunctionMaker(object): ...@@ -958,7 +958,7 @@ class FunctionMaker(object):
def __init__(self, inputs, outputs, def __init__(self, inputs, outputs,
mode = None, accept_inplace = False, function_builder = Function, mode = None, accept_inplace = False, function_builder = Function,
profile=None): profile=None, on_unused_input='raise'):
""" """
:type inputs: a list of SymbolicInput instances :type inputs: a list of SymbolicInput instances
...@@ -972,6 +972,12 @@ class FunctionMaker(object): ...@@ -972,6 +972,12 @@ class FunctionMaker(object):
:param accept_inplace: True iff it is acceptable to have inplace operations :param accept_inplace: True iff it is acceptable to have inplace operations
in the graph from the inputs to the outputs in the graph from the inputs to the outputs
:param on_unused_input: What to do if a variable in the 'inputs' list
is not used in the graph. Possible values are:
- 'raise' (default): raise an error
- 'warn': log a warning
- 'ignore': do not do anything
""" """
mode = mode_module.get_mode(mode) mode = mode_module.get_mode(mode)
...@@ -1005,6 +1011,10 @@ class FunctionMaker(object): ...@@ -1005,6 +1011,10 @@ class FunctionMaker(object):
_inputs = gof.graph.inputs([o.variable for o in outputs] + [i.update _inputs = gof.graph.inputs([o.variable for o in outputs] + [i.update
for i in inputs if getattr(i, 'update', False)]) for i in inputs if getattr(i, 'update', False)])
# Check if some input variables are unused
self._check_unused_inputs(inputs, outputs, on_unused_input)
#TODO: REMOVE THIS CRUFT - it's complicated for SymbolicInputKits #TODO: REMOVE THIS CRUFT - it's complicated for SymbolicInputKits
indices = [[input] + self.expand_in(input, _inputs) for input in inputs] indices = [[input] + self.expand_in(input, _inputs) for input in inputs]
expanded_inputs = reduce(list.__add__, [list(z) for x, y, z in indices], []) expanded_inputs = reduce(list.__add__, [list(z) for x, y, z in indices], [])
...@@ -1072,6 +1082,37 @@ class FunctionMaker(object): ...@@ -1072,6 +1082,37 @@ class FunctionMaker(object):
(i.value != None and not isinstance(i.value, gof.Container) and i.update == None) (i.value != None and not isinstance(i.value, gof.Container) and i.update == None)
for i in self.inputs] for i in self.inputs]
def _check_unused_inputs(self, inputs, outputs, on_unused_input):
if on_unused_input == 'ignore':
return
# There should be two categories of variables in inputs:
# - variables that have to be provided (used_inputs)
# - shared variables that will be updated
used_inputs = gof.graph.ancestors(
([o.variable for o in outputs]
+ [i.update for i in inputs if getattr(i, 'update', False)]),
blockers=[i.variable for i in inputs])
msg = ("theano.function was asked to create a function computing "
"outputs given certain inputs, but one of the provided "
"input variables is not part of the computational graph "
"needed to compute the outputs: %s. To disable this %s, "
"you can pass the parameter on_unused_input='ignore' to "
"theano.function.")
for i in inputs:
if ((i.variable not in used_inputs) and (i.update is None)):
if on_unused_input == 'warn':
warnings.warn(msg % (i.variable, 'warning'), stacklevel=5)
elif on_unused_input == 'raise':
raise ValueError(msg % (i.variable, 'error'))
else:
raise ValueError(("Invalid value for keyword "
"on_unused_input of theano.function: '%s'. "
"valid values are 'raise', 'warn', and 'ignore'."
% on_unused_input))
def create(self, input_storage=None, trustme=False): def create(self, input_storage=None, trustme=False):
""" """
Create a function. Create a function.
...@@ -1202,7 +1243,8 @@ def check_equal(x, y): ...@@ -1202,7 +1243,8 @@ def check_equal(x, y):
def register_checker(checker): def register_checker(checker):
__checkers.insert(0, checker) __checkers.insert(0, checker)
def orig_function(inputs, outputs, mode=None, accept_inplace = False, name=None, profile=None): def orig_function(inputs, outputs, mode=None, accept_inplace = False,
name=None, profile=None, on_unused_input='raise'):
""" """
Return a Function that will calculate the outputs from the inputs. Return a Function that will calculate the outputs from the inputs.
...@@ -1232,6 +1274,8 @@ def orig_function(inputs, outputs, mode=None, accept_inplace = False, name=None, ...@@ -1232,6 +1274,8 @@ def orig_function(inputs, outputs, mode=None, accept_inplace = False, name=None,
:param profile: None or ProfileStats instance :param profile: None or ProfileStats instance
:param on_unused_input: What to do if a variable in the 'inputs' list is
not used in the graph. Possible values are 'raise', 'warn', and 'ignore'.
""" """
#Every element of the input list will be upgraded to an `In` instance if necessary, #Every element of the input list will be upgraded to an `In` instance if necessary,
...@@ -1262,7 +1306,8 @@ def orig_function(inputs, outputs, mode=None, accept_inplace = False, name=None, ...@@ -1262,7 +1306,8 @@ def orig_function(inputs, outputs, mode=None, accept_inplace = False, name=None,
outputs, outputs,
mode[0], mode[0],
accept_inplace = accept_inplace, accept_inplace = accept_inplace,
profile=profile).create( profile=profile,
on_unused_input=on_unused_input).create(
defaults) defaults)
else: else:
if profile: if profile:
...@@ -1292,7 +1337,8 @@ def orig_function(inputs, outputs, mode=None, accept_inplace = False, name=None, ...@@ -1292,7 +1337,8 @@ def orig_function(inputs, outputs, mode=None, accept_inplace = False, name=None,
outputs, outputs,
mode, mode,
accept_inplace = accept_inplace, accept_inplace = accept_inplace,
profile=profile).create( profile=profile,
on_unused_input=on_unused_input).create(
defaults) defaults)
t2 = time.time() t2 = time.time()
......
...@@ -324,7 +324,7 @@ class Param(object): ...@@ -324,7 +324,7 @@ class Param(object):
def pfunc(params, outputs=None, mode=None, updates=[], givens=[], def pfunc(params, outputs=None, mode=None, updates=[], givens=[],
no_default_updates=False, accept_inplace=False, name=None, no_default_updates=False, accept_inplace=False, name=None,
rebuild_strict=True, allow_input_downcast=None, rebuild_strict=True, allow_input_downcast=None,
profile=None): profile=None, on_unused_input='raise'):
"""Function-constructor for graphs with shared variables. """Function-constructor for graphs with shared variables.
:type params: list of either Variable or Param instances. :type params: list of either Variable or Param instances.
...@@ -372,6 +372,10 @@ def pfunc(params, outputs=None, mode=None, updates=[], givens=[], ...@@ -372,6 +372,10 @@ def pfunc(params, outputs=None, mode=None, updates=[], givens=[],
with that string as its `message` attribute. This profiling object will be with that string as its `message` attribute. This profiling object will be
available via self.profile. available via self.profile.
:type profile: str
:param profile: What to do if a variable in the 'inputs' list is not used
in the graph. Possible values are 'raise', 'warn', and 'ignore.
:rtype: theano.compile.Function :rtype: theano.compile.Function
:returns: a callable object that will compute the outputs (given the inputs) :returns: a callable object that will compute the outputs (given the inputs)
...@@ -460,7 +464,8 @@ def pfunc(params, outputs=None, mode=None, updates=[], givens=[], ...@@ -460,7 +464,8 @@ def pfunc(params, outputs=None, mode=None, updates=[], givens=[],
inputs.append(si) inputs.append(si)
return orig_function(inputs, cloned_outputs, mode, return orig_function(inputs, cloned_outputs, mode,
accept_inplace=accept_inplace, name=name, profile=profile) accept_inplace=accept_inplace, name=name, profile=profile,
on_unused_input=on_unused_input)
def _pfunc_param_to_in(param, strict=False, allow_downcast=None): def _pfunc_param_to_in(param, strict=False, allow_downcast=None):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论