提交 d79dd27b authored 作者: lamblin's avatar lamblin

Merge pull request #517 from delallea/minor

Minor fixes
...@@ -8,6 +8,7 @@ import copy_reg ...@@ -8,6 +8,7 @@ import copy_reg
import cPickle import cPickle
import itertools import itertools
import time import time
import warnings
import numpy import numpy
...@@ -1114,7 +1115,7 @@ class FunctionMaker(object): ...@@ -1114,7 +1115,7 @@ class FunctionMaker(object):
for i in inputs: for i in inputs:
if ((i.variable not in used_inputs) and (i.update is None)): if ((i.variable not in used_inputs) and (i.update is None)):
if on_unused_input == 'warn': if on_unused_input == 'warn':
warnings.warn(msg % (i.variable, warn_msg), stacklevel=5) warnings.warn(msg % (i.variable, warn_msg), stacklevel=6)
elif on_unused_input == 'raise': elif on_unused_input == 'raise':
raise UnusedInputError(msg % (i.variable, err_msg)) raise UnusedInputError(msg % (i.variable, err_msg))
else: else:
...@@ -1253,21 +1254,24 @@ def check_equal(x, y): ...@@ -1253,21 +1254,24 @@ def check_equal(x, y):
def register_checker(checker): def register_checker(checker):
__checkers.insert(0, checker) __checkers.insert(0, checker)
def orig_function(inputs, outputs, mode=None, accept_inplace = False,
def orig_function(inputs, outputs, mode=None, accept_inplace=False,
name=None, profile=None, on_unused_input='raise'): name=None, profile=None, on_unused_input='raise'):
""" """
Return a Function that will calculate the outputs from the inputs. Return a Function that will calculate the outputs from the inputs.
:param inputs: list of `SymbolicInput` or `In` instances :param inputs: list of `SymbolicInput` or `In` instances
:param outputs: a SymbolicOutput or a list of `SymbolicOutput` or `Out` instances. The return :param outputs: a SymbolicOutput or a list of `SymbolicOutput` or `Out`
value of the returned function will match the format of this argument (either the value instances. The return value of the returned function will match the
itself or a list of one or more return values) format of this argument (either the value itself or a list of one or more
return values)
:param mode: a descriptive string or a Mode instance. (Default of None means to use :param mode: a descriptive string or a Mode instance. (Default of None
`config.mode` (See below for descriptive string list). means to use `config.mode` (See below for descriptive string list).
:param name: an optional name for this fct. If used, the profile mode will print the time spent in this fct. :param name: an optional name for this fct. If used, the profile mode will
print the time spent in this fct.
Currently, the library provides the following mode strings: Currently, the library provides the following mode strings:
...@@ -1275,12 +1279,13 @@ def orig_function(inputs, outputs, mode=None, accept_inplace = False, ...@@ -1275,12 +1279,13 @@ def orig_function(inputs, outputs, mode=None, accept_inplace = False,
- FAST_COMPILE (minimal optimization) - FAST_COMPILE (minimal optimization)
- PROFILE_MODE : allow to print a profile mode with mode.print_summary - PROFILE_MODE: allow to print a profile mode with mode.print_summary
- DEBUG_MODE : verify many internal conditions that are normally assumed (SLOW) - DEBUG_MODE: verify many internal conditions that are normally assumed
(slow)
:param accept_inplace: True iff the graph can contain inplace operations prior to the :param accept_inplace: True iff the graph can contain inplace operations
optimization phase (default is False) prior to the optimization phase (default is False)
:param profile: None or ProfileStats instance :param profile: None or ProfileStats instance
...@@ -1288,11 +1293,12 @@ def orig_function(inputs, outputs, mode=None, accept_inplace = False, ...@@ -1288,11 +1293,12 @@ def orig_function(inputs, outputs, mode=None, accept_inplace = False,
not used in the graph. Possible values are 'raise', 'warn', and 'ignore'. not used in the graph. Possible values are 'raise', 'warn', and 'ignore'.
""" """
#Every element of the input list will be upgraded to an `In` instance if necessary, # Every element of the input list will be upgraded to an `In` instance if
#using the rules implemented by the `convert_function_input` function. # necessary, using the rules implemented by the `convert_function_input`
# function.
#Similarly, every element of the output list will be upgraded to an # Similarly, every element of the output list will be upgraded to an `Out`
#`Out` instance if necessary: # instance if necessary:
t1 = time.time() t1 = time.time()
mode = mode_module.get_mode(mode) mode = mode_module.get_mode(mode)
...@@ -1315,51 +1321,58 @@ def orig_function(inputs, outputs, mode=None, accept_inplace = False, ...@@ -1315,51 +1321,58 @@ def orig_function(inputs, outputs, mode=None, accept_inplace = False,
inputs, inputs,
outputs, outputs,
mode[0], mode[0],
accept_inplace = accept_inplace, accept_inplace=accept_inplace,
profile=profile, profile=profile,
on_unused_input=on_unused_input).create( on_unused_input=on_unused_input).create(defaults)
defaults)
else: else:
if profile: if profile:
raise NotImplementedError('profiling not implemented in this kind of mode') raise NotImplementedError('profiling not implemented in this '
'kind of mode')
#return a different kind of function #return a different kind of function
def dup_defaults(): def dup_defaults():
# TODO This may need to be changed to use containers as defaults. # TODO This may need to be changed to use containers as
# defaults.
retval = [] retval = []
for default in defaults: for default in defaults:
if isinstance(default, gof.Container): if isinstance(default, gof.Container):
retval +=[copy.copy(default.value)] retval += [copy.copy(default.value)]
else: else:
retval +=[copy.copy(default)] retval += [copy.copy(default)]
return retval return retval
#backport #backport
#return [copy.copy(default.value) if isinstance(default, gof.Container) else #return [copy.copy(default.value)
# if isinstance(default, gof.Container) else
# copy.copy(default) # copy.copy(default)
# for default in defaults] # for default in defaults]
makers = [FunctionMaker(inputs, outputs, m, accept_inplace = accept_inplace) for m in mode[1:]] makers = [FunctionMaker(inputs, outputs, m,
fns = [maker.create(dup_defaults(), trustme = True) for maker in makers] accept_inplace=accept_inplace)
for m in mode[1:]]
fns = [maker.create(dup_defaults(), trustme=True)
for maker in makers]
builder = partial(SanityCheckFunction, fns, check_equal) builder = partial(SanityCheckFunction, fns, check_equal)
maker1 = FunctionMaker(inputs, outputs, mode[0], accept_inplace = accept_inplace, function_builder = builder) maker1 = FunctionMaker(inputs, outputs, mode[0],
accept_inplace=accept_inplace,
function_builder=builder)
fn = maker1.create(defaults) fn = maker1.create(defaults)
else: else:
Maker = getattr(mode, 'function_maker', FunctionMaker) Maker = getattr(mode, 'function_maker', FunctionMaker)
fn = Maker(inputs, fn = Maker(inputs,
outputs, outputs,
mode, mode,
accept_inplace = accept_inplace, accept_inplace=accept_inplace,
profile=profile, profile=profile,
on_unused_input=on_unused_input).create( on_unused_input=on_unused_input).create(
defaults) defaults)
t2 = time.time() t2 = time.time()
if profile: if profile:
profile.compile_time+=t2-t1 profile.compile_time += t2 - t1
fn.name = name fn.name = name
return fn return fn
def convert_function_input(input): def convert_function_input(input):
""" """
Upgrade a input shortcut to an In instance. Upgrade a input shortcut to an In instance.
...@@ -1372,18 +1385,19 @@ def convert_function_input(input): ...@@ -1372,18 +1385,19 @@ def convert_function_input(input):
- a tuple (r, val) will be `In`(r, value=value, autoname=True) - a tuple (r, val) will be `In`(r, value=value, autoname=True)
- a tuple ((r,up), val) will be `In`(r, value=value, update=up, autoname=True) - a tuple ((r,up), val) will be
`In`(r, value=value, update=up, autoname=True)
- a tuple (name, r, val) will be `In`(r, name=name, value=value) - a tuple (name, r, val) will be `In`(r, name=name, value=value)
- a tuple (name, (r,up), val) will be `In`(r, name=name, value=val, update=up, autoname=True) - a tuple (name, (r,up), val) will be
`In`(r, name=name, value=val, update=up, autoname=True)
""" """
if isinstance(input, (SymbolicInput, SymbolicInputKit)): if isinstance(input, (SymbolicInput, SymbolicInputKit)):
return input return input
elif isinstance(input, gof.Constant): elif isinstance(input, gof.Constant):
raise TypeError('A Constant instance is not a legal function input', input) raise TypeError('A Constant instance is not a legal function input',
input)
elif isinstance(input, gof.Variable): elif isinstance(input, gof.Variable):
return In(input) return In(input)
elif isinstance(input, (list, tuple)): elif isinstance(input, (list, tuple)):
...@@ -1397,7 +1411,8 @@ def convert_function_input(input): ...@@ -1397,7 +1411,8 @@ def convert_function_input(input):
name = None name = None
if isinstance(input[0], (list, tuple)): if isinstance(input[0], (list, tuple)):
if len(input[0]) != 2 or len(input) != 2: if len(input[0]) != 2 or len(input) != 2:
raise TypeError("Invalid input syntax: %s (check documentation or use an In instance)" % orig) raise TypeError("Invalid input syntax: %s (check "
"documentation or use an In instance)" % orig)
(variable, update), value = input (variable, update), value = input
elif isinstance(input[0], gof.Variable): elif isinstance(input[0], gof.Variable):
if len(input) == 1: if len(input) == 1:
...@@ -1405,38 +1420,48 @@ def convert_function_input(input): ...@@ -1405,38 +1420,48 @@ def convert_function_input(input):
elif len(input) == 2: elif len(input) == 2:
(variable, value), update = input, None (variable, value), update = input, None
else: else:
raise TypeError("Invalid input syntax: %s (check documentation or use an In instance)" % orig) raise TypeError("Invalid input syntax: %s (check "
"documentation or use an In instance)" % orig)
elif isinstance(input[0], (SymbolicInput, SymbolicInputKit)): elif isinstance(input[0], (SymbolicInput, SymbolicInputKit)):
if len(input) == 1: if len(input) == 1:
return input[0] return input[0]
elif len(input) == 2: elif len(input) == 2:
input, value = input input, value = input
if name is not None: input.name = name if name is not None:
input.name = name
input.value = value input.value = value
return input return input
else: else:
raise TypeError("The input specification is not valid: %s" % input) raise TypeError("The input specification is not valid: %s" % input)
if not isinstance(variable, gof.Variable): if not isinstance(variable, gof.Variable):
raise TypeError("Unknown input type: %s, expected Variable instance" % type(variable), variable) raise TypeError("Unknown input type: %s, expected Variable "
"instance" % type(variable), variable)
if update is not None and not isinstance(update, gof.Variable): if update is not None and not isinstance(update, gof.Variable):
raise TypeError("Unknown update type: %s, expected Variable instance" % type(update), update) raise TypeError("Unknown update type: %s, expected Variable "
if value is not None and isinstance(value, (gof.Variable, SymbolicInput)): "instance" % type(update), update)
raise TypeError("The value for input %s should not be a Variable or SymbolicInput instance (got: %s)" % (variable, value)) if (value is not None and
isinstance(value, (gof.Variable, SymbolicInput))):
raise TypeError("The value for input %s should not be a Variable "
"or SymbolicInput instance (got: %s)" %
(variable, value))
return In(variable, name=name, value=value, update=update) return In(variable, name=name, value=value, update=update)
else: else:
raise TypeError("Unknown input type: %s, expected Variable instance" % type(input), input) raise TypeError("Unknown input type: %s, expected Variable instance" %
type(input), input)
def get_info_on_inputs(named_inputs, n_unnamed_inputs): def get_info_on_inputs(named_inputs, n_unnamed_inputs):
"""Return a human-readable description of named and un-named inputs.""" """Return a human-readable description of named and un-named inputs."""
n_named_inputs = len(named_inputs) n_named_inputs = len(named_inputs)
def get_plural(n): def get_plural(n):
if n > 1: if n > 1:
return 's' return 's'
else: else:
return '' return ''
if n_named_inputs == 0: if n_named_inputs == 0:
if n_unnamed_inputs == 0: if n_unnamed_inputs == 0:
msg = 'The function is supposed to have no input.' msg = 'The function is supposed to have no input.'
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论