提交 b5263298 authored 作者: Arnaud Bergeron's avatar Arnaud Bergeron

Flake8 for compile/pfunc.py

上级 4d98657a
...@@ -21,7 +21,7 @@ def rebuild_collect_shared(outputs, ...@@ -21,7 +21,7 @@ def rebuild_collect_shared(outputs,
rebuild_strict=True, rebuild_strict=True,
copy_inputs_over=True, copy_inputs_over=True,
no_default_updates=False, no_default_updates=False,
): ):
""" """
Function that allows replacing subgraphs of a computational Function that allows replacing subgraphs of a computational
graph. graph.
...@@ -152,12 +152,12 @@ def rebuild_collect_shared(outputs, ...@@ -152,12 +152,12 @@ def rebuild_collect_shared(outputs,
if v_orig in clone_d: if v_orig in clone_d:
raise AssertionError( raise AssertionError(
"When using 'givens' or 'replace' with several " "When using 'givens' or 'replace' with several "
"(old_v, new_v) replacement pairs, you can not have a " "(old_v, new_v) replacement pairs, you can not have a "
"new_v variable depend on an old_v one. For instance, " "new_v variable depend on an old_v one. For instance, "
"givens = {a:b, b:(a+1)} is not allowed. Here, the old_v " "givens = {a:b, b:(a+1)} is not allowed. Here, the old_v "
"%s is used to compute other new_v's, but it is scheduled " "%s is used to compute other new_v's, but it is scheduled "
"to be replaced by %s." % (v_orig, v_repl)) "to be replaced by %s." % (v_orig, v_repl))
clone_d[v_orig] = clone_v_get_shared_updates(v_repl, clone_d[v_orig] = clone_v_get_shared_updates(v_repl,
copy_inputs_over) copy_inputs_over)
...@@ -199,7 +199,7 @@ def rebuild_collect_shared(outputs, ...@@ -199,7 +199,7 @@ def rebuild_collect_shared(outputs,
# filter_variable ensure smooth conversion of cpu/gpu Types # filter_variable ensure smooth conversion of cpu/gpu Types
try: try:
update_val = store_into.type.filter_variable(update_val) update_val = store_into.type.filter_variable(update_val)
except TypeError as e: except TypeError:
err_msg = ('An update must have the same type as the' err_msg = ('An update must have the same type as the'
' original shared variable (shared_var=%s,' ' original shared variable (shared_var=%s,'
' shared_var.type=%s,' ' shared_var.type=%s,'
...@@ -275,35 +275,38 @@ def rebuild_collect_shared(outputs, ...@@ -275,35 +275,38 @@ def rebuild_collect_shared(outputs,
class Param(object): class Param(object):
def __init__(self, variable, default=None, name=None, mutable=False, def __init__(self, variable, default=None, name=None, mutable=False,
strict=False, allow_downcast=None, implicit=None, borrow=None): strict=False, allow_downcast=None, implicit=None,
borrow=None):
""" """
:param variable: A variable in an expression graph to use as a :param variable: A variable in an expression graph to use as a
compiled-function parameter compiled-function parameter
:param default: The default value to use at call-time (can also be a Container where :param default: The default value to use at call-time (can
the function will find a value at call-time.) also be a Container where the function will find a value
at call-time.)
:param name: A string to identify this parameter from function kwargs. :param name: A string to identify this parameter from function kwargs.
:param mutable: True -> function is allowed to modify this argument. :param mutable: True -> function is allowed to modify this argument.
:param borrow: Whether the function is allowed to alias some output to :param borrow: Whether the function is allowed to alias some
this input. Using None (default) means we re-use the same value as the output to this input. Using None (default) means we re-use
`mutable` flag. the same value as the `mutable` flag.
False: do not permit any output to be aliased to the input False: do not permit any output to be aliased to the input
:param strict: False -> function arguments may be copied or cast to match the :param strict: False -> function arguments may be copied or
type required by the parameter `variable`. cast to match the type required by the parameter
`variable`.
True -> function arguments must exactly match the type True -> function arguments must exactly match the type
required by `variable`. required by `variable`.
:param allow_downcast: Only applies if `strict` is False. :param allow_downcast: Only applies if `strict` is False.
True -> allow assigned value to lose precision when cast during assignment. True -> allow assigned value to lose precision when cast
during assignment.
False -> never allow precision loss. False -> never allow precision loss.
None -> only allow downcasting of a Python float to a scalar floatX. None -> only allow downcasting of a Python float to a scalar floatX.
:param implicit: see help(theano.io.In) :param implicit: see help(theano.io.In)
""" """
self.variable = variable self.variable = variable
self.default = default self.default = default
...@@ -320,12 +323,12 @@ class Param(object): ...@@ -320,12 +323,12 @@ class Param(object):
# aliased to the input. Thus mutable=True should require borrow=True. # aliased to the input. Thus mutable=True should require borrow=True.
if self.mutable and not self.borrow: if self.mutable and not self.borrow:
raise AssertionError( raise AssertionError(
"Symbolic input for variable %s (name=%s) has " "Symbolic input for variable %s (name=%s) has "
"flags mutable=True, borrow=False. This combination is " "flags mutable=True, borrow=False. This combination is "
"incompatible since mutable=True implies that the " "incompatible since mutable=True implies that the "
"input variable may be both aliased (borrow=True) and " "input variable may be both aliased (borrow=True) and "
"overwritten.", "overwritten.",
variable, name) variable, name)
self.strict = strict self.strict = strict
self.allow_downcast = allow_downcast self.allow_downcast = allow_downcast
...@@ -333,9 +336,9 @@ class Param(object): ...@@ -333,9 +336,9 @@ class Param(object):
def pfunc(params, outputs=None, mode=None, updates=None, givens=None, def pfunc(params, outputs=None, mode=None, updates=None, givens=None,
no_default_updates=False, accept_inplace=False, name=None, no_default_updates=False, accept_inplace=False, name=None,
rebuild_strict=True, allow_input_downcast=None, rebuild_strict=True, allow_input_downcast=None,
profile=None, on_unused_input=None,output_keys=None): profile=None, on_unused_input=None, output_keys=None):
"""Function-constructor for graphs with shared variables. """Function-constructor for graphs with shared variables.
:type params: list of either Variable or Param instances. :type params: list of either Variable or Param instances.
...@@ -348,30 +351,35 @@ def pfunc(params, outputs=None, mode=None, updates=None, givens=None, ...@@ -348,30 +351,35 @@ def pfunc(params, outputs=None, mode=None, updates=None, givens=None,
:type mode: string or `theano.compile.Mode` instance. :type mode: string or `theano.compile.Mode` instance.
:param mode: compilation mode :param mode: compilation mode
:type updates: iterable over pairs (shared_variable, new_expression). List, tuple or dict. :type updates: iterable over pairs (shared_variable,
:param updates: update the values for SharedVariable inputs according to these expressions new_expression). List, tuple or dict.
:param updates: update the values for SharedVariable inputs
according to these expressions
:type givens: iterable over pairs (Var1, Var2) of Variables. List, tuple or dict. The Var1 :type givens: iterable over pairs (Var1, Var2) of Variables. List,
and Var2 in each pair must have the same Type. tuple or dict. The Var1 and Var2 in each pair must have the
same Type.
:param givens: specific substitutions to make in the computation graph (Var2 replaces :param givens: specific substitutions to make in the computation
Var1). graph (Var2 replaces Var1).
:type no_default_updates: either bool or list of Variables :type no_default_updates: either bool or list of Variables
:param no_default_updates: if True, do not perform any automatic update on Variables. :param no_default_updates: if True, do not perform any automatic
If False (default), perform them all. Else, perform automatic updates on all Variables update on Variables. If False (default), perform them
that are neither in "updates" nor in "no_default_updates". all. Else, perform automatic updates on all Variables that are
neither in "updates" nor in "no_default_updates".
:type name: None or string :type name: None or string
:param name: attaches a name to the profiling result of this function. :param name: attaches a name to the profiling result of this function.
:type allow_input_downcast: Boolean :type allow_input_downcast: Boolean
:param allow_input_downcast: True means that the values passed as :param allow_input_downcast: True means that the values passed as
inputs when calling the function can be silently downcasted to fit inputs when calling the function can be silently downcasted to
the dtype of the corresponding Variable, which may lose precision. fit the dtype of the corresponding Variable, which may lose
False means that it will only be cast to a more general, or precision. False means that it will only be cast to a more
precise, type. None (default) is almost like False, but allows general, or precise, type. None (default) is almost like
downcasting of Python float scalars to floatX. False, but allows downcasting of Python float scalars to
floatX.
:type profile: None, True, str, or ProfileStats instance :type profile: None, True, str, or ProfileStats instance
:param profile: accumulate profiling information into a given ProfileStats :param profile: accumulate profiling information into a given ProfileStats
...@@ -389,30 +397,32 @@ def pfunc(params, outputs=None, mode=None, updates=None, givens=None, ...@@ -389,30 +397,32 @@ def pfunc(params, outputs=None, mode=None, updates=None, givens=None,
:rtype: theano.compile.Function :rtype: theano.compile.Function
:returns: a callable object that will compute the outputs (given the inputs) :returns: a callable object that will compute the outputs (given
and update the implicit function arguments according to the `updates`. the inputs) and update the implicit function arguments
according to the `updates`.
:note: Regarding givens: Be careful to make sure that these substitutions are
independent--behaviour when Var1 of one pair appears in the graph leading to Var2 in
another expression is undefined. Replacements specified with givens are different from
optimizations in that Var2 is not expected to be equivalent to Var1.
:note: Regarding givens: Be careful to make sure that these
substitutions are independent--behaviour when Var1 of one pair
appears in the graph leading to Var2 in another expression is
undefined. Replacements specified with givens are different
from optimizations in that Var2 is not expected to be
equivalent to Var1.
""" """
# #
# This function works by cloning the graph (except for the inputs), and then shipping it # This function works by cloning the graph (except for the
# off to compile.function # inputs), and then shipping it off to compile.function (There it
# (There it will be cloned again, unnecessarily, because it doesn't know that we already # will be cloned again, unnecessarily, because it doesn't know
# cloned it.) # that we already cloned it.)
# #
# First, it clones the replacements named in the givens argument, and points each Var1 to # First, it clones the replacements named in the givens argument,
# the clone of Var2. # and points each Var1 to the clone of Var2. Then it sets the
# Then it sets the inputs in the clone dictionary. # inputs in the clone dictionary. After these steps, we are
# After these steps, we are assuming that the clone dictionary contains all the inputs to # assuming that the clone dictionary contains all the inputs to
# the computation graph. # the computation graph.
# #
# Then it clones the outputs and the update expressions. This rebuilds a computation graph # Then it clones the outputs and the update expressions. This
# from the inputs and the givens. # rebuilds a computation graph from the inputs and the givens.
# #
if updates is None: if updates is None:
updates = [] updates = []
...@@ -431,11 +441,13 @@ def pfunc(params, outputs=None, mode=None, updates=None, givens=None, ...@@ -431,11 +441,13 @@ def pfunc(params, outputs=None, mode=None, updates=None, givens=None,
# useful. # useful.
if not isinstance(params, (list, tuple)): if not isinstance(params, (list, tuple)):
raise Exception("in pfunc() the first argument must be a list or a tuple") raise Exception("in pfunc() the first argument must be a list or "
"a tuple")
if not isinstance(no_default_updates, bool)\ if not isinstance(no_default_updates, bool)\
and not isinstance(no_default_updates, list): and not isinstance(no_default_updates, list):
raise TypeError("no_default_update should be either a boolean or a list") raise TypeError("no_default_update should be either a boolean or "
"a list")
if len(updates) > 0 and any(isinstance(v, Variable) if len(updates) > 0 and any(isinstance(v, Variable)
for v in iter_over_pairs(updates)): for v in iter_over_pairs(updates)):
...@@ -453,10 +465,10 @@ def pfunc(params, outputs=None, mode=None, updates=None, givens=None, ...@@ -453,10 +465,10 @@ def pfunc(params, outputs=None, mode=None, updates=None, givens=None,
if v in in_variables[(i + 1):]: if v in in_variables[(i + 1):]:
dup_v_i = in_variables.index(v, (i + 1)) dup_v_i = in_variables.index(v, (i + 1))
raise UnusedInputError( raise UnusedInputError(
("Variable %s is used twice in inputs to theano.function, " ("Variable %s is used twice in inputs to theano.function, "
"at indices %i and %i. This would result in values " "at indices %i and %i. This would result in values "
"provided for it being ignored. Please do not duplicate " "provided for it being ignored. Please do not duplicate "
"variables in the inputs list." % (v, i, dup_v_i))) "variables in the inputs list." % (v, i, dup_v_i)))
# Check that we are not using `givens` to replace input variables, because # Check that we are not using `givens` to replace input variables, because
# this typically does nothing, contrary to what one may expect. # this typically does nothing, contrary to what one may expect.
...@@ -494,9 +506,10 @@ def pfunc(params, outputs=None, mode=None, updates=None, givens=None, ...@@ -494,9 +506,10 @@ def pfunc(params, outputs=None, mode=None, updates=None, givens=None,
i.variable = iv i.variable = iv
for sv in shared_inputs: for sv in shared_inputs:
# pass value of None here # pass value of None
# value will be stored in the resulting functions' defaults list # value will be stored in the resulting functions' defaults
# but since the value of shared variables never needs to be refed, it is not needed # list but since the value of shared variables never needs to
# be refed, it is not needed
if sv in update_d: if sv in update_d:
si = In(variable=sv, value=sv.container, mutable=True, si = In(variable=sv, value=sv.container, mutable=True,
borrow=True, update=update_d[sv], shared=True) borrow=True, update=update_d[sv], shared=True)
...@@ -506,8 +519,9 @@ def pfunc(params, outputs=None, mode=None, updates=None, givens=None, ...@@ -506,8 +519,9 @@ def pfunc(params, outputs=None, mode=None, updates=None, givens=None,
inputs.append(si) inputs.append(si)
return orig_function(inputs, cloned_outputs, mode, return orig_function(inputs, cloned_outputs, mode,
accept_inplace=accept_inplace, name=name, profile=profile, accept_inplace=accept_inplace, name=name,
on_unused_input=on_unused_input, output_keys=output_keys) profile=profile, on_unused_input=on_unused_input,
output_keys=output_keys)
def _pfunc_param_to_in(param, strict=False, allow_downcast=None): def _pfunc_param_to_in(param, strict=False, allow_downcast=None):
...@@ -517,14 +531,14 @@ def _pfunc_param_to_in(param, strict=False, allow_downcast=None): ...@@ -517,14 +531,14 @@ def _pfunc_param_to_in(param, strict=False, allow_downcast=None):
return In(variable=param, strict=strict, allow_downcast=allow_downcast) return In(variable=param, strict=strict, allow_downcast=allow_downcast)
elif isinstance(param, Param): elif isinstance(param, Param):
return In( return In(
variable=param.variable, variable=param.variable,
name=param.name, name=param.name,
value=param.default, value=param.default,
mutable=param.mutable, mutable=param.mutable,
strict=param.strict, strict=param.strict,
borrow=param.borrow, borrow=param.borrow,
allow_downcast=param.allow_downcast, allow_downcast=param.allow_downcast,
implicit=param.implicit) implicit=param.implicit)
raise TypeError('Unknown parameter type: %s' % type(param)) raise TypeError('Unknown parameter type: %s' % type(param))
......
...@@ -38,7 +38,6 @@ whitelist_flake8 = [ ...@@ -38,7 +38,6 @@ whitelist_flake8 = [
"tests/test_tutorial.py", "tests/test_tutorial.py",
"tests/disturb_mem.py", "tests/disturb_mem.py",
"tests/unittest_tools.py", "tests/unittest_tools.py",
"compile/pfunc.py",
"compile/mode.py", "compile/mode.py",
"compile/profilemode.py", "compile/profilemode.py",
"compile/builders.py", "compile/builders.py",
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论