提交 b5263298 authored 作者: Arnaud Bergeron's avatar Arnaud Bergeron

Flake8 for compile/pfunc.py

上级 4d98657a
......@@ -21,7 +21,7 @@ def rebuild_collect_shared(outputs,
rebuild_strict=True,
copy_inputs_over=True,
no_default_updates=False,
):
):
"""
Function that allows replacing subgraphs of a computational
graph.
......@@ -152,12 +152,12 @@ def rebuild_collect_shared(outputs,
if v_orig in clone_d:
raise AssertionError(
"When using 'givens' or 'replace' with several "
"(old_v, new_v) replacement pairs, you can not have a "
"new_v variable depend on an old_v one. For instance, "
"givens = {a:b, b:(a+1)} is not allowed. Here, the old_v "
"%s is used to compute other new_v's, but it is scheduled "
"to be replaced by %s." % (v_orig, v_repl))
"When using 'givens' or 'replace' with several "
"(old_v, new_v) replacement pairs, you can not have a "
"new_v variable depend on an old_v one. For instance, "
"givens = {a:b, b:(a+1)} is not allowed. Here, the old_v "
"%s is used to compute other new_v's, but it is scheduled "
"to be replaced by %s." % (v_orig, v_repl))
clone_d[v_orig] = clone_v_get_shared_updates(v_repl,
copy_inputs_over)
......@@ -199,7 +199,7 @@ def rebuild_collect_shared(outputs,
# filter_variable ensure smooth conversion of cpu/gpu Types
try:
update_val = store_into.type.filter_variable(update_val)
except TypeError as e:
except TypeError:
err_msg = ('An update must have the same type as the'
' original shared variable (shared_var=%s,'
' shared_var.type=%s,'
......@@ -275,35 +275,38 @@ def rebuild_collect_shared(outputs,
class Param(object):
def __init__(self, variable, default=None, name=None, mutable=False,
strict=False, allow_downcast=None, implicit=None, borrow=None):
strict=False, allow_downcast=None, implicit=None,
borrow=None):
"""
:param variable: A variable in an expression graph to use as a
compiled-function parameter
:param default: The default value to use at call-time (can also be a Container where
the function will find a value at call-time.)
:param default: The default value to use at call-time (can
also be a Container where the function will find a value
at call-time.)
:param name: A string to identify this parameter from function kwargs.
:param mutable: True -> function is allowed to modify this argument.
:param borrow: Whether the function is allowed to alias some output to
this input. Using None (default) means we re-use the same value as the
`mutable` flag.
:param borrow: Whether the function is allowed to alias some
output to this input. Using None (default) means we re-use
the same value as the `mutable` flag.
False: do not permit any output to be aliased to the input
:param strict: False -> function arguments may be copied or cast to match the
type required by the parameter `variable`.
:param strict: False -> function arguments may be copied or
cast to match the type required by the parameter
`variable`.
True -> function arguments must exactly match the type
required by `variable`.
:param allow_downcast: Only applies if `strict` is False.
True -> allow assigned value to lose precision when cast during assignment.
True -> allow assigned value to lose precision when cast
during assignment.
False -> never allow precision loss.
None -> only allow downcasting of a Python float to a scalar floatX.
:param implicit: see help(theano.io.In)
"""
self.variable = variable
self.default = default
......@@ -320,12 +323,12 @@ class Param(object):
# aliased to the input. Thus mutable=True should require borrow=True.
if self.mutable and not self.borrow:
raise AssertionError(
"Symbolic input for variable %s (name=%s) has "
"flags mutable=True, borrow=False. This combination is "
"incompatible since mutable=True implies that the "
"input variable may be both aliased (borrow=True) and "
"overwritten.",
variable, name)
"Symbolic input for variable %s (name=%s) has "
"flags mutable=True, borrow=False. This combination is "
"incompatible since mutable=True implies that the "
"input variable may be both aliased (borrow=True) and "
"overwritten.",
variable, name)
self.strict = strict
self.allow_downcast = allow_downcast
......@@ -333,9 +336,9 @@ class Param(object):
def pfunc(params, outputs=None, mode=None, updates=None, givens=None,
no_default_updates=False, accept_inplace=False, name=None,
rebuild_strict=True, allow_input_downcast=None,
profile=None, on_unused_input=None,output_keys=None):
no_default_updates=False, accept_inplace=False, name=None,
rebuild_strict=True, allow_input_downcast=None,
profile=None, on_unused_input=None, output_keys=None):
"""Function-constructor for graphs with shared variables.
:type params: list of either Variable or Param instances.
......@@ -348,30 +351,35 @@ def pfunc(params, outputs=None, mode=None, updates=None, givens=None,
:type mode: string or `theano.compile.Mode` instance.
:param mode: compilation mode
:type updates: iterable over pairs (shared_variable, new_expression). List, tuple or dict.
:param updates: update the values for SharedVariable inputs according to these expressions
:type updates: iterable over pairs (shared_variable,
new_expression). List, tuple or dict.
:param updates: update the values for SharedVariable inputs
according to these expressions
:type givens: iterable over pairs (Var1, Var2) of Variables. List, tuple or dict. The Var1
and Var2 in each pair must have the same Type.
:type givens: iterable over pairs (Var1, Var2) of Variables. List,
tuple or dict. The Var1 and Var2 in each pair must have the
same Type.
:param givens: specific substitutions to make in the computation graph (Var2 replaces
Var1).
:param givens: specific substitutions to make in the computation
graph (Var2 replaces Var1).
:type no_default_updates: either bool or list of Variables
:param no_default_updates: if True, do not perform any automatic update on Variables.
If False (default), perform them all. Else, perform automatic updates on all Variables
that are neither in "updates" nor in "no_default_updates".
:param no_default_updates: if True, do not perform any automatic
update on Variables. If False (default), perform them
all. Else, perform automatic updates on all Variables that are
neither in "updates" nor in "no_default_updates".
:type name: None or string
:param name: attaches a name to the profiling result of this function.
:type allow_input_downcast: Boolean
:param allow_input_downcast: True means that the values passed as
inputs when calling the function can be silently downcasted to fit
the dtype of the corresponding Variable, which may lose precision.
False means that it will only be cast to a more general, or
precise, type. None (default) is almost like False, but allows
downcasting of Python float scalars to floatX.
inputs when calling the function can be silently downcasted to
fit the dtype of the corresponding Variable, which may lose
precision. False means that it will only be cast to a more
general, or precise, type. None (default) is almost like
False, but allows downcasting of Python float scalars to
floatX.
:type profile: None, True, str, or ProfileStats instance
:param profile: accumulate profiling information into a given ProfileStats
......@@ -389,30 +397,32 @@ def pfunc(params, outputs=None, mode=None, updates=None, givens=None,
:rtype: theano.compile.Function
:returns: a callable object that will compute the outputs (given the inputs)
and update the implicit function arguments according to the `updates`.
:returns: a callable object that will compute the outputs (given
the inputs) and update the implicit function arguments
according to the `updates`.
:note: Regarding givens: Be careful to make sure that these substitutions are
independent--behaviour when Var1 of one pair appears in the graph leading to Var2 in
another expression is undefined. Replacements specified with givens are different from
optimizations in that Var2 is not expected to be equivalent to Var1.
:note: Regarding givens: Be careful to make sure that these
substitutions are independent--behaviour when Var1 of one pair
appears in the graph leading to Var2 in another expression is
undefined. Replacements specified with givens are different
from optimizations in that Var2 is not expected to be
equivalent to Var1.
"""
#
# This function works by cloning the graph (except for the inputs), and then shipping it
# off to compile.function
# (There it will be cloned again, unnecessarily, because it doesn't know that we already
# cloned it.)
# This function works by cloning the graph (except for the
# inputs), and then shipping it off to compile.function (There it
# will be cloned again, unnecessarily, because it doesn't know
# that we already cloned it.)
#
# First, it clones the replacements named in the givens argument, and points each Var1 to
# the clone of Var2.
# Then it sets the inputs in the clone dictionary.
# After these steps, we are assuming that the clone dictionary contains all the inputs to
# First, it clones the replacements named in the givens argument,
# and points each Var1 to the clone of Var2. Then it sets the
# inputs in the clone dictionary. After these steps, we are
# assuming that the clone dictionary contains all the inputs to
# the computation graph.
#
# Then it clones the outputs and the update expressions. This rebuilds a computation graph
# from the inputs and the givens.
# Then it clones the outputs and the update expressions. This
# rebuilds a computation graph from the inputs and the givens.
#
if updates is None:
updates = []
......@@ -431,11 +441,13 @@ def pfunc(params, outputs=None, mode=None, updates=None, givens=None,
# useful.
if not isinstance(params, (list, tuple)):
raise Exception("in pfunc() the first argument must be a list or a tuple")
raise Exception("in pfunc() the first argument must be a list or "
"a tuple")
if not isinstance(no_default_updates, bool)\
and not isinstance(no_default_updates, list):
raise TypeError("no_default_update should be either a boolean or a list")
raise TypeError("no_default_update should be either a boolean or "
"a list")
if len(updates) > 0 and any(isinstance(v, Variable)
for v in iter_over_pairs(updates)):
......@@ -453,10 +465,10 @@ def pfunc(params, outputs=None, mode=None, updates=None, givens=None,
if v in in_variables[(i + 1):]:
dup_v_i = in_variables.index(v, (i + 1))
raise UnusedInputError(
("Variable %s is used twice in inputs to theano.function, "
"at indices %i and %i. This would result in values "
"provided for it being ignored. Please do not duplicate "
"variables in the inputs list." % (v, i, dup_v_i)))
("Variable %s is used twice in inputs to theano.function, "
"at indices %i and %i. This would result in values "
"provided for it being ignored. Please do not duplicate "
"variables in the inputs list." % (v, i, dup_v_i)))
# Check that we are not using `givens` to replace input variables, because
# this typically does nothing, contrary to what one may expect.
......@@ -494,9 +506,10 @@ def pfunc(params, outputs=None, mode=None, updates=None, givens=None,
i.variable = iv
for sv in shared_inputs:
# pass value of None here
# value will be stored in the resulting functions' defaults list
# but since the value of shared variables never needs to be refed, it is not needed
# pass value of None
# value will be stored in the resulting functions' defaults
# list but since the value of shared variables never needs to
# be refed, it is not needed
if sv in update_d:
si = In(variable=sv, value=sv.container, mutable=True,
borrow=True, update=update_d[sv], shared=True)
......@@ -506,8 +519,9 @@ def pfunc(params, outputs=None, mode=None, updates=None, givens=None,
inputs.append(si)
return orig_function(inputs, cloned_outputs, mode,
accept_inplace=accept_inplace, name=name, profile=profile,
on_unused_input=on_unused_input, output_keys=output_keys)
accept_inplace=accept_inplace, name=name,
profile=profile, on_unused_input=on_unused_input,
output_keys=output_keys)
def _pfunc_param_to_in(param, strict=False, allow_downcast=None):
......@@ -517,14 +531,14 @@ def _pfunc_param_to_in(param, strict=False, allow_downcast=None):
return In(variable=param, strict=strict, allow_downcast=allow_downcast)
elif isinstance(param, Param):
return In(
variable=param.variable,
name=param.name,
value=param.default,
mutable=param.mutable,
strict=param.strict,
borrow=param.borrow,
allow_downcast=param.allow_downcast,
implicit=param.implicit)
variable=param.variable,
name=param.name,
value=param.default,
mutable=param.mutable,
strict=param.strict,
borrow=param.borrow,
allow_downcast=param.allow_downcast,
implicit=param.implicit)
raise TypeError('Unknown parameter type: %s' % type(param))
......
......@@ -38,7 +38,6 @@ whitelist_flake8 = [
"tests/test_tutorial.py",
"tests/disturb_mem.py",
"tests/unittest_tools.py",
"compile/pfunc.py",
"compile/mode.py",
"compile/profilemode.py",
"compile/builders.py",
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论