提交 a19013dd authored 作者: Arnaud Bergeron's avatar Arnaud Bergeron

Flake8 for compile/function_module.py

上级 e710ee08
"""Driver of graph construction, optimization, and linking.
"""
from __future__ import print_function
......@@ -35,7 +34,7 @@ class UnusedInputError(Exception):
def alias_root(v):
"""Return the variable to which v is aliased by view_maps and destroy_maps"""
"Return the variable to which v is aliased by view_maps and destroy_maps"
if v.owner is None:
return v
vmap = getattr(v.owner.op, 'view_map', {})
......@@ -54,7 +53,8 @@ def alias_root(v):
def view_tree_set(v, treeset):
"""Add to `treeset` all variables that are views of v, given that v is not a view"""
"""Add to `treeset` all variables that are views of v, given that v is
not a view"""
treeset.add(v)
for cl, v_input_pos_to_cl in v.clients:
if cl == 'output':
......@@ -69,11 +69,13 @@ def view_tree_set(v, treeset):
def infer_reuse_pattern(fgraph, outputs_to_disown):
"""
Given an fgraph and a list of variables, returns the list or set of all variables which may
share the same underlying data storage as any of the specified variables. Used internally
by function, FunctionMaker.
Given an fgraph and a list of variables, returns the list or set
of all variables which may share the same underlying data storage
as any of the specified variables. Used internally by function,
FunctionMaker.
This list (or set) is also referred to as no_recycling sometimes, especially by linker code.
This list (or set) is also referred to as no_recycling sometimes,
especially by linker code.
"""
rval = set()
for o in outputs_to_disown:
......@@ -103,10 +105,10 @@ def fgraph_updated_vars(fgraph, expanded_inputs):
class Supervisor:
"""
Listener for FunctionGraph events which makes sure that no operation overwrites the
contents of protected Variables. The outputs of the FunctionGraph are protected by default.
Listener for FunctionGraph events which makes sure that no
operation overwrites the contents of protected Variables. The
outputs of the FunctionGraph are protected by default.
"""
def __init__(self, protected):
self.protected = list(protected)
......@@ -176,33 +178,38 @@ class AliasedMemoryError(Exception):
# Function
###
DUPLICATE = ['DUPLICATE'] # unique id object used as a placeholder for duplicate entries
# unique id object used as a placeholder for duplicate entries
DUPLICATE = ['DUPLICATE']
class Function(object):
"""
Type of the functions returned by theano.function or theano.FunctionMaker.create.
`Function` is the callable object that does computation. It has the storage of inputs and
outputs, performs the packing and unpacking of inputs and return values. It implements the
square-bracket indexing so that you can look up the value of a symbolic node.
Type of the functions returned by theano.function or
theano.FunctionMaker.create.
Functions are copyable via {{{fn.copy()}}} and {{{copy.copy(fn)}}}.
When a function is copied, this instance is duplicated. Contrast with self.maker
(instance of `FunctionMaker`) that is shared between copies.
The meaning of copying a function is that the containers and their current values will all be duplicated.
This requires that mutable inputs be copied, whereas immutable inputs may be shared between copies.
`Function` is the callable object that does computation. It has
the storage of inputs and outputs, performs the packing and
unpacking of inputs and return values. It implements the
square-bracket indexing so that you can look up the value of a
symbolic node.
Functions are copyable via {{{fn.copy()}}} and
{{{copy.copy(fn)}}}. When a function is copied, this instance is
duplicated. Contrast with self.maker (instance of
`FunctionMaker`) that is shared between copies. The meaning of
copying a function is that the containers and their current values
will all be duplicated. This requires that mutable inputs be
copied, whereas immutable inputs may be shared between copies.
A Function instance is hashable, on the basis of its memory address (its id).
A Function instance is hashable, on the basis of its memory
address (its id).
A Function instance is only equal to itself.
A Function instance may be serialized using the `pickle` or `cPickle` modules.
This will save all default inputs, the graph, and *** to the pickle file (WRITEME).
A Function instance may be serialized using the `pickle` or
`cPickle` modules. This will save all default inputs, the graph,
and *** to the pickle file (WRITEME).
A Function instance has a ``trust_input`` field that defaults to
False. When True, we don't do extra checks of the input to give
......@@ -210,7 +217,6 @@ class Function(object):
the good results if you pass a python or numpy scalar instead of a
numpy tensor. C code should raise an error if you pass an object
of the wrong type.
"""
pickle_aliased_memory_strategy = 'warn'
......@@ -218,12 +224,11 @@ class Function(object):
Meaningful settings are: 'ignore', 'warn', 'raise'
If the value is 'warn', then a message will be printed to stderr if aliased storage is
detected during pickle.dump.
If the value is 'raise', then an AliasedMemoryError will be raised if aliased storage is
detected during pickle.dump.
If the value is 'warn', then a message will be printed to stderr
if aliased storage is detected during pickle.dump.
If the value is 'raise', then an AliasedMemoryError will be raised
if aliased storage is detected during pickle.dump.
"""
input_storage = None
......@@ -233,24 +238,28 @@ class Function(object):
"""list of Container instances"""
indices = None
"""list of (SymbolicInput|SymbolicInputKit, indices, [SymbolicInput,...]), one tuple for
each input
"""list of (SymbolicInput|SymbolicInputKit, indices,
[SymbolicInput,...]), one tuple for each input
The first tuple element is the SymbolicInput object for the corresponding function input.
The first tuple element is the SymbolicInput object for the
corresponding function input.
The second and third tuple elements are used only by Kits, which are deprecated.
The second and third tuple elements are used only by Kits, which
are deprecated.
"""
defaults = None
""" list of 3-tuples, one 3-tuple for each input.
Tuple element 0: Bool: Is this input required at each function call?
Tuple element 1: Bool: Should this inputs value be reverted after each call?
Tuple element 1: Bool: Should this inputs value be reverted after
each call?
Tuple element 2: Any: The value associated with this input.
"""
unpack_single = None
"""Bool: for outputs lists of length 1, should the 0'th element be returned directly?"""
"""Bool: for outputs lists of length 1, should the 0'th element be
returned directly?"""
return_none = None
"""Bool: whether the function should return None or not"""
......@@ -259,8 +268,8 @@ class Function(object):
"""FunctionMaker instance"""
fn = None
"""a function that evaluates the graph. Typically a linker's make_thunk method created this
function."""
"""a function that evaluates the graph. Typically a linker's
make_thunk method created this function."""
finder = None
"""Dictionary mapping several kinds of things to containers.
......@@ -273,7 +282,8 @@ class Function(object):
- the name of the input
All entries map to the container or to DUPLICATE if an ambiguity is detected
All entries map to the container or to DUPLICATE if an ambiguity
is detected
"""
inv_finder = None
......@@ -312,20 +322,22 @@ class Function(object):
input.distribute(value, indices, cs)
for c in cs:
c.provided += 1
# def assign(c, v):
#c.data = v
# Store the list of names of named inputs.
named_inputs = []
# Count the number of un-named inputs.
n_unnamed_inputs = 0
#setters = []
# Initialize the storage
# this loop works by modifying the elements (as variable c) of self.input_storage inplace.
for i, ((input, indices, sinputs), (required, refeed, value)) in enumerate(zip(self.indices, defaults)):
if indices is None: # this is true iff input is not a SymbolicInputKit
c = containers[0] #containers is being used as a stack. Here we pop off the next one.
# this loop works by modifying the elements (as variable c) of
# self.input_storage inplace.
for i, ((input, indices, sinputs), (required, refeed, value)) in \
enumerate(zip(self.indices, defaults)):
# this is true iff input is not a SymbolicInputKit
if indices is None:
# containers is being used as a stack. Here we pop off
# the next one.
c = containers[0]
c.strict = getattr(input, 'strict', False)
c.allow_downcast = getattr(input, 'allow_downcast', None)
......@@ -342,7 +354,9 @@ class Function(object):
c.value = value
c.required = required
c.implicit = input.implicit
c.provided = 0 # this is a count of how many times the input has been provided (reinitialized to 0 on __call__)
# this is a count of how many times the input has been
# provided (reinitialized to 0 on __call__)
c.provided = 0
finder[i] = c
finder[input.variable] = c
if input.name not in finder:
......@@ -353,17 +367,14 @@ class Function(object):
n_unnamed_inputs += 1
else:
named_inputs.append(input.name)
# backport
#finder[input.name] = c if input.name not in finder else DUPLICATE
# inv_finder maps the container to the input (useful for one error message)
inv_finder[c] = input
#setters.append(partial(assign, c))
containers[:1] = []
else:
# TODO The following code may need to do something to handle
# implicit inputs.
# The input is a SymbolicInputKit, so we take as many containers as the Kit provides inputs
# The input is a SymbolicInputKit, so we take as many
# containers as the Kit provides inputs
cs = containers[:len(indices)]
# distribute does the initialization of the containers
input.distribute(value, indices, cs)
......@@ -377,12 +388,11 @@ class Function(object):
finder[input.name] = f
else:
finder[input.name] = DUPLICATE
# backport
#finder[input.name] = f if input.name not in finder else DUPLICATE
# setters.append(f)
# For each input in the kit and its corresponding container, we put an entry in finder.
# This allows the user to micro-manage elements of the kit if need be.
# All containers inherit the required field and have their own "provided" counter
# For each input in the kit and its corresponding
# container, we put an entry in finder. This allows
# the user to micro-manage elements of the kit if need
# be. All containers inherit the required field and
# have their own "provided" counter
for c, sin in zip(cs, sinputs):
finder[sin.variable] = c
finder[sin.name] = c
......@@ -390,8 +400,6 @@ class Function(object):
finder[sin.name] = c
else:
finder[sin.name] = DUPLICATE
# backport
#finder[sin.name] = c if sin.name not in finder else DUPLICATE
inv_finder[c] = input
c.required = required
c.provided = 0
......@@ -410,12 +418,14 @@ class Function(object):
except KeyError:
raise TypeError("Unknown input or state: %s" % str(item))
if s is DUPLICATE:
raise TypeError("Ambiguous name: %s - please check the names "\
"of the inputs of your function for duplicates." % str(item))
raise TypeError("Ambiguous name: %s - please check the "
"names of the inputs of your function "
"for duplicates." % str(item))
if isinstance(s, gof.Container):
return s.value
else:
raise NotImplementedError
def __setitem__(self, item, value):
try:
s = finder[item]
......@@ -425,13 +435,15 @@ class Function(object):
raise TypeError("Unknown input or state: %s. %s" %
(str(item), msg))
if s is DUPLICATE:
raise TypeError("Ambiguous name: %s - please check the names "\
"of the inputs of your function for duplicates." % str(item))
raise TypeError("Ambiguous name: %s - please check the "
"names of the inputs of your function "
"for duplicates." % str(item))
if isinstance(s, gof.Container):
s.value = value
s.provided += 1
else:
s(value)
def __contains__(self, item):
return finder.__contains__(item)
......@@ -441,6 +453,7 @@ class Function(object):
class ContainerAttribute(object):
def __getitem__(self, item):
return finder[item]
def __contains__(self, item):
return finder.__contains__(item)
# You cannot set the container
......@@ -513,20 +526,17 @@ class Function(object):
s.storage[0] = arg
else:
try:
s.storage[0] = s.type.filter(arg, strict=s.strict,
allow_downcast=s.allow_downcast)
s.storage[0] = s.type.filter(
arg, strict=s.strict,
allow_downcast=s.allow_downcast)
except Exception as e:
function_name = "theano function"
if self.name:
function_name += ' with name "' + self.name + '" '
# end if
e.args = tuple(["Bad input argument to " + function_name +
" at index %d(0-based)" % i] +
list(e.args))
e.args = ("Bad input argument to " + function_name +
" at index %d(0-based)" % i,) + e.args
raise
# end except
# end if
s.provided += 1
i += 1
......@@ -535,9 +545,8 @@ class Function(object):
for k, arg in kwargs.iteritems():
self[k] = arg
if not self.trust_input and (
not hasattr(self, '_check_for_aliased_inputs') or
self._check_for_aliased_inputs):
if (not self.trust_input and
getattr(self, '_check_for_aliased_inputs', True)):
# Collect aliased inputs among the storage space
args_share_memory = []
for i in xrange(len(self.input_storage)):
......@@ -553,8 +562,8 @@ class Function(object):
[self.input_storage[k].storage[0] for k
in args_share_memory[j]])
if numpy.any([(var.type is i_var.type and
var.type.may_share_memory(val, i_val))
for (var, val) in group_j]):
var.type.may_share_memory(val, i_val))
for (var, val) in group_j]):
is_aliased = True
args_share_memory[j].append(i)
......@@ -566,10 +575,6 @@ class Function(object):
# Check for groups of more than one argument that share memory
for group in args_share_memory:
if len(group) > 1:
# see if any of these arguments are mutable
mutable = numpy.any([(self.maker.inputs[idx].mutable or
self.maker.inputs[idx].borrow)
for idx in group])
# copy all but the first
for idx in group[1:]:
self.input_storage[i].storage[0] = copy.copy(
......@@ -696,13 +701,15 @@ class Function(object):
container = property(
lambda self: self._container,
None, # this property itself is not settable
doc="""dictionary-like access to the containers associated with Variables""")
doc=("dictionary-like access to the containers associated with "
"Variables"))
def free(self):
"""
When allow_gc = False, clear the Variables in storage_map
"""
# 1.no allow_gc return False 2.has allow_gc, if allow_gc is False, return True
# 1.no allow_gc return False
# 2.has allow_gc, if allow_gc is False, return True
if not getattr(self.fn, 'allow_gc', True):
for key in self.fn.storage_map.keys():
if not isinstance(key, theano.gof.Constant):
......@@ -719,7 +726,8 @@ def _pickle_Function(f):
ins = list(f.input_storage)
input_storage = []
for (input, indices, inputs), (required, refeed, default) in zip(f.indices, f.defaults):
for (input, indices, inputs), (required, refeed, default) in \
zip(f.indices, f.defaults):
if isinstance(input, SymbolicInputKit):
li = len(indices)
if not default:
......@@ -734,18 +742,21 @@ def _pickle_Function(f):
inputs_data = [x.data for x in f.input_storage]
# HACK to detect aliased storage.
# This is here because aliased relationships are not [currently] preserved across the pickle operation
# This is here because aliased relationships are not [currently]
# preserved across the pickle operation
if not (f.pickle_aliased_memory_strategy == 'ignore'):
all_data = input_storage + inputs_data # addition here means list append
all_data = input_storage + inputs_data
for i, d_i in enumerate(all_data):
for j, d_j in enumerate(all_data):
if (i < j) and isinstance(d_i, numpy.ndarray) and isinstance(d_j, numpy.ndarray):
if ((i < j) and isinstance(d_i, numpy.ndarray) and
isinstance(d_j, numpy.ndarray)):
if numpy.may_share_memory(d_i, d_j):
if f.pickle_aliased_memory_strategy == 'warn':
_logger.warning(('aliased relationship between'
' Function arguments %s, %s'
' will not be preserved by un-pickling'
' operation') % (str(d_i), str(d_j)))
_logger.warning('aliased relationship between '
'Function arguments %s, %s '
'will not be preserved by '
'un-pickling operation' %
(str(d_i), str(d_j)))
else:
raise AliasedMemoryError(d_i, d_j)
rval = (_constructor_Function, (f.maker, input_storage, inputs_data))
......@@ -774,20 +785,25 @@ def insert_deepcopy(fgraph, wrapped_inputs, wrapped_outputs):
"""
Insert deepcopy in the fgraph to break aliasing of outputs
"""
# This loop was inserted to remove aliasing between outputs when they all
# evaluate to the same value. Originally it was OK for outputs to be aliased,
# but some of the outputs can be shared variables, and is not good for shared
# variables to be aliased. It might be possible to optimize this by making sure
# This loop was inserted to remove aliasing between outputs when
# they all evaluate to the same value. Originally it was OK for
# outputs to be aliased, but some of the outputs can be shared
# variables, and is not good for shared variables to be
# aliased. It might be possible to optimize this by making sure
# there is no aliasing only between shared variables.
# If some outputs are constant, we add deep copy to respect the memory contract
# If some outputs are constant, we add deep copy to respect the
# memory contract
# We don't insert deep copy when the output.borrow is True for all concerned outputs.
# We don't insert deep copy when the output.borrow is True for all
# concerned outputs.
assert len(wrapped_inputs) == len(fgraph.inputs)
assert len(wrapped_outputs) == len(fgraph.outputs)
reason = "insert_deepcopy"
updated_fgraph_inputs = [fgraph_i for i, fgraph_i in zip(wrapped_inputs, fgraph.inputs) if getattr(i, 'update', False)]
updated_fgraph_inputs = [fgraph_i for i, fgraph_i in
zip(wrapped_inputs, fgraph.inputs)
if getattr(i, 'update', False)]
# We can't use fgraph.inputs as this don't include Constant Value.
all_graph_inputs = gof.graph.inputs(fgraph.outputs)
......@@ -802,43 +818,54 @@ def insert_deepcopy(fgraph, wrapped_inputs, wrapped_outputs):
# and not(wrapped_outputs[i].borrow and wrapped_outputs[j].borrow):
if fgraph.outputs[j] in views_of_output_i:
if wrapped_outputs[i].borrow and wrapped_outputs[j].borrow:
fgraph.change_input('output', i, view_op(fgraph.outputs[i]),
reason=reason)
fgraph.change_input('output', i,
view_op(fgraph.outputs[i]),
reason=reason)
else:
fgraph.change_input('output', i, deep_copy_op(fgraph.outputs[i]),
reason=reason)
fgraph.change_input('output', i,
deep_copy_op(fgraph.outputs[i]),
reason=reason)
copied = True
break
if not copied:
for input_j in all_graph_inputs:
# do not allow outputs to be aliased to an inputs (j), unless
# a) that j'th input has been 'destroyed' by e.g. in-place computations
# b) that j'th input is a shared variable that is also being updated
# a) that j'th input has been 'destroyed' by
# e.g. in-place computations
# b) that j'th input is a shared variable that is also
# being updated
if (hasattr(fgraph, 'get_destroyers_of') and
fgraph.get_destroyers_of(input_j)):
fgraph.get_destroyers_of(input_j)):
continue
if input_j in updated_fgraph_inputs:
continue
if input_j in views_of_output_i:
# We don't put deep_copy_op if the input and the output have borrow==True
# We don't put deep_copy_op if the input and the
# output have borrow==True
if input_j in fgraph.inputs:
j = fgraph.inputs.index(input_j)
if wrapped_outputs[i].borrow and wrapped_inputs[j].borrow:
fgraph.change_input('output', i, view_op(fgraph.outputs[i]),
reason="insert_deepcopy")
if (wrapped_outputs[i].borrow and
wrapped_inputs[j].borrow):
fgraph.change_input('output', i,
view_op(fgraph.outputs[i]),
reason="insert_deepcopy")
break
else:
fgraph.change_input('output', i, deep_copy_op(fgraph.outputs[i]),
reason="insert_deepcopy")
fgraph.change_input(
'output', i,
deep_copy_op(fgraph.outputs[i]),
reason="insert_deepcopy")
break
elif wrapped_outputs[i].borrow:
fgraph.change_input('output', i, view_op(fgraph.outputs[i]),
reason="insert_deepcopy")
fgraph.change_input('output', i,
view_op(fgraph.outputs[i]),
reason="insert_deepcopy")
break
else:
fgraph.change_input('output', i, deep_copy_op(fgraph.outputs[i]),
reason="insert_deepcopy")
fgraph.change_input('output', i,
deep_copy_op(fgraph.outputs[i]),
reason="insert_deepcopy")
break
NODEFAULT = ['NODEFAULT']
......@@ -866,17 +893,20 @@ class FunctionMaker(object):
if len(input) == 2:
return SymbolicInput(input[0], update=input[1])
else:
raise TypeError("Expected two elements in the list or tuple.", input)
raise TypeError("Expected two elements in the list or tuple.",
input)
else:
raise TypeError("Unknown input type: %s (%s), expected Variable instance", type(input), input)
raise TypeError("Unknown input type: %s (%s), expected Variable "
"instance", type(input), input)
@staticmethod
def expand_in(sinput, rinputs):
# For SymbolicInputKits, this extracts a list of SymbolicInput instances
# and corresponding indices such that these SymbolicInputs are representative
# of some of the Variable instances in inputs.
# For SymbolicInput, this returns None as the list of indices and a list with
# just the SymbolicInput.
# For SymbolicInputKits, this extracts a list of SymbolicInput
# instances and corresponding indices such that these
# SymbolicInputs are representative of some of the Variable
# instances in inputs. For SymbolicInput, this returns None
# as the list of indices and a list with just the
# SymbolicInput.
if isinstance(sinput, SymbolicInputKit):
return sinput.complete(rinputs)
elif isinstance(sinput, SymbolicInput):
......@@ -889,24 +919,25 @@ class FunctionMaker(object):
elif isinstance(output, gof.Variable):
return SymbolicOutput(output)
else:
raise TypeError("Unknown output type: %s (%s)", type(output), output)
raise TypeError("Unknown output type: %s (%s)", type(output),
output)
def optimize_graph_with_cache(self, optimizer, inputs, outputs):
# This function is not finished
from theano.gof.compilelock import get_lock, release_lock
import os.path
graph_db_file = os.path.join(theano.config.compiledir, 'optimized_graphs.pkl')
graph_db_file = os.path.join(theano.config.compiledir,
'optimized_graphs.pkl')
# the inputs, outputs, and size of the graph to be optimized
inputs_new = [inp.variable for inp in inputs]
outputs_new = [out.variable for out in outputs]
size_new = len(self.fgraph.apply_nodes)
need_optimize = False
get_lock()
key = None
# Beginning of cache optimizations.
# Could be refactored in different functions.
def load_graph_db():
if os.path.isfile(graph_db_file):
print('graph_db already exists')
......@@ -919,8 +950,9 @@ class FunctionMaker(object):
# load the graph_db dictionary
try:
f = open(graph_db_file, 'rb')
# Temporary hack to allow theano.scan_module.tests.test_scan.T_Scan
# to finish. Should be changed in definitive version.
# Temporary hack to allow
# theano.scan_module.tests.test_scan.T_Scan to
# finish. Should be changed in definitive version.
tmp = theano.config.unpickle_function
theano.config.unpickle_function = False
graph_db = cPickle.load(f)
......@@ -961,16 +993,21 @@ class FunctionMaker(object):
# two graphs are for sure different
print('need to optimize, because output size is different')
continue
elif not all(input_new.type == input_old.type for
input_new, input_old in zip(inputs_new, inputs_old)):
print('need to optimize, because inputs are of different types')
elif not all(input_new.type == input_old.type
for input_new, input_old in
zip(inputs_new, inputs_old)):
print('need to optimize, because inputs are of different '
'types')
continue
elif not all(output_new.type == output_old.type for
output_new, output_old in zip(outputs_new, outputs_old)):
print('need to optimize, because outputs are of different types')
elif not all(output_new.type == output_old.type
for output_new, output_old in
zip(outputs_new, outputs_old)):
print('need to optimize, because outputs are of different '
'types')
continue
elif not size_old == size_new:
print('need to optimize, because numbers of nodes in graph are different')
print('need to optimize, because numbers of nodes in graph'
' are different')
continue
else:
flags = []
......@@ -1007,10 +1044,10 @@ class FunctionMaker(object):
t2 = removeAllFgraph(t2)
givens = dict(zip(gof.graph.inputs([t1]),
gof.graph.inputs([t2])))
gof.graph.inputs([t2])))
temp = dict(zip(gof.graph.inputs([t1]),
gof.graph.inputs([t2])))
gof.graph.inputs([t2])))
# hack to remove inconsistent entries in givens
# seems to work that but source of inconsistency
......@@ -1032,7 +1069,8 @@ class FunctionMaker(object):
return found_graph_in_db
graph_db = load_graph_db()
print('loaded graph_db from %s, size=%d' % (graph_db_file, len(graph_db)))
print('loaded graph_db from %s, size=%d' % (graph_db_file,
len(graph_db)))
found_graph = find_same_graph_in_db(graph_db)
if found_graph:
self.fgraph = found_graph
......@@ -1043,7 +1081,7 @@ class FunctionMaker(object):
self.fgraph.variables = set(gof.graph.variables(
self.fgraph.inputs, self.fgraph.outputs))
# check_integrity parameters was added to ignore
#"excess cached variables" errors. Works that way
# "excess cached variables" errors. Works that way
# but once again the error could be worth
# investigating.
before_opt = self.fgraph.clone(check_integrity=False)
......@@ -1057,22 +1095,24 @@ class FunctionMaker(object):
return optimizer_profile
def __init__(self, inputs, outputs,
mode=None, accept_inplace=False, function_builder=Function,
profile=None, on_unused_input=None, fgraph=None,
output_keys=None):
mode=None, accept_inplace=False, function_builder=Function,
profile=None, on_unused_input=None, fgraph=None,
output_keys=None):
"""
:type inputs: a list of SymbolicInput instances
:type outputs: a list of SymbolicOutput instances
outputs may also be a single Variable (not a list), in which
case the functions produced by FunctionMaker will return
their output value directly
:param mode: a Mode instance telling FunctionMaker how to optimize and link. None
means to use the `config.mode`.
:type outputs: a list of SymbolicOutput instances outputs may
also be a single Variable (not a list), in which case the
functions produced by FunctionMaker will return their
output value directly
:param mode: a Mode instance telling FunctionMaker how to
optimize and link. None means to use the `config.mode`.
:param accept_inplace: True iff it is acceptable to have inplace operations
in the graph from the inputs to the outputs
:param accept_inplace: True iff it is acceptable to have
inplace operations in the graph from the inputs to the
outputs
:param on_unused_input: What to do if a variable in the 'inputs' list
is not used in the graph. Possible values are:
......@@ -1089,18 +1129,19 @@ class FunctionMaker(object):
# using this somewhat awkward mechanism.
mode_profile = getattr(mode, 'profile', None)
if (profile is not None and
profile is not False and
mode_profile is not None):
profile is not False and
mode_profile is not None):
raise TypeError(
'profile passed via both "mode" and "profile" arguments')
'profile passed via both "mode" and "profile" arguments')
self.profile = profile = profile or mode_profile
if profile:
# This is very important:
# 1) We preload the cache here so we don't have its timing
# included in the optimization that compiles the function.
# 2) Do not refresh the cache here by default. It causes too much
# execution time during testing as we compile many more functions
# than the number of compiled c modules.
# 2) Do not refresh the cache here by default. It causes
# too much execution time during testing as we compile
# many more functions than the number of compiled c
# modules.
theano.gof.cc.get_module_cache().refresh()
# Handle the case where inputs and/or outputs is a single
# Variable (not in a list)
......@@ -1117,21 +1158,27 @@ class FunctionMaker(object):
inputs = [inputs]
# Wrap them in In or Out instances if needed.
inputs, outputs = map(self.wrap_in, inputs), map(self.wrap_out, outputs)
_inputs = gof.graph.inputs([o.variable for o in outputs] + [i.update
for i in inputs if getattr(i, 'update', False)])
inputs = map(self.wrap_in, inputs)
outputs = map(self.wrap_out, outputs)
_inputs = gof.graph.inputs([o.variable for o in outputs] +
[i.update for i in inputs
if getattr(i, 'update', False)])
# Check if some input variables are unused
self._check_unused_inputs(inputs, outputs, on_unused_input)
# Make a list of (SymbolicInput|SymblicInputKits, indices, [SymbolicInput,...]), one
# tuple for each input. (See Function.indices for more details)
indices = [[input] + self.expand_in(input, _inputs) for input in inputs]
# Make a list of (SymbolicInput|SymblicInputKits, indices,
# [SymbolicInput,...]), one tuple for each input. (See
# Function.indices for more details)
indices = [[input] + self.expand_in(input, _inputs)
for input in inputs]
if fgraph is None:
need_opt = True
# make the fgraph (copies the graph, creates NEW INPUT AND OUTPUT VARIABLES)
fgraph, additional_outputs = std_fgraph(inputs, outputs, accept_inplace)
# make the fgraph (copies the graph, creates NEW INPUT AND
# OUTPUT VARIABLES)
fgraph, additional_outputs = std_fgraph(inputs, outputs,
accept_inplace)
fgraph.profile = profile
else:
# fgraph is already an optimized one
......@@ -1149,7 +1196,8 @@ class FunctionMaker(object):
# Why we add stack on node when it get done in output var?
try:
# optimize the fgraph
theano.config.compute_test_value = theano.config.compute_test_value_opt
theano.config.compute_test_value = \
theano.config.compute_test_value_opt
theano.config.traceback.limit = 0
start_optimizer = time.time()
......@@ -1165,7 +1213,8 @@ class FunctionMaker(object):
if profile:
profile.optimizer_time += opt_time
if theano.config.profile_optimizer:
profile.optimizer_profile = (optimizer, optimizer_profile)
profile.optimizer_profile = (optimizer,
optimizer_profile)
_logger.debug('Optimizing took %f seconds', opt_time)
# Add deep copy to respect the memory interface
......@@ -1176,21 +1225,26 @@ class FunctionMaker(object):
# initialize the linker
if not hasattr(linker, 'accept'):
raise ValueError("'linker' parameter of FunctionMaker should be a Linker with an accept method " \
"or one of %s" % theano.compile.mode.predefined_linkers.keys())
raise ValueError("'linker' parameter of FunctionMaker should be "
"a Linker with an accept method or one of %s" %
theano.compile.mode.predefined_linkers.keys())
# the 'no_borrow' outputs are the ones for which that we can't return the internal storage pointer.
# the 'no_borrow' outputs are the ones for which that we can't
# return the internal storage pointer.
assert len(fgraph.outputs) == len(outputs + additional_outputs)
no_borrow = [output for output, spec in zip(fgraph.outputs, outputs + additional_outputs) if not spec.borrow]
no_borrow = [output for output, spec in
zip(fgraph.outputs, outputs + additional_outputs)
if not spec.borrow]
if no_borrow:
self.linker = linker.accept(fgraph, no_recycling=infer_reuse_pattern(fgraph, no_borrow))
self.linker = linker.accept(
fgraph, no_recycling=infer_reuse_pattern(fgraph, no_borrow))
else:
self.linker = linker.accept(fgraph)
if hasattr(linker, 'accept_var_updates'):
# hacky thing so VMLinker knows about updates
self.linker.accept_var_updates(
fgraph_updated_vars(fgraph, inputs))
fgraph_updated_vars(fgraph, inputs))
self.indices = indices
self.inputs = inputs
......@@ -1206,11 +1260,10 @@ class FunctionMaker(object):
self.required = [(i.value is None) for i in self.inputs]
self.refeed = [
(i.value is not None and
not isinstance(i.value, gof.Container) and
i.update is None)
for i in self.inputs
]
(i.value is not None and
not isinstance(i.value, gof.Container) and
i.update is None)
for i in self.inputs]
def _check_unused_inputs(self, inputs, outputs, on_unused_input):
if on_unused_input is None:
......@@ -1223,57 +1276,64 @@ class FunctionMaker(object):
# - variables that have to be provided (used_inputs)
# - shared variables that will be updated
used_inputs = gof.graph.ancestors(
([o.variable for o in outputs]
+ [i.update for i in inputs if getattr(i, 'update', False)]),
blockers=[i.variable for i in inputs])
([o.variable for o in outputs]
+ [i.update for i in inputs if getattr(i, 'update', False)]),
blockers=[i.variable for i in inputs])
msg = ("theano.function was asked to create a function computing "
"outputs given certain inputs, but the provided input "
"variable at index %i is not part of the computational graph "
"needed to compute the outputs: %s.\n%s")
"outputs given certain inputs, but the provided input "
"variable at index %i is not part of the computational graph "
"needed to compute the outputs: %s.\n%s")
warn_msg = ("To make this warning into an error, you can pass the "
"parameter on_unused_input='raise' to theano.function. "
"To disable it completely, use on_unused_input='ignore'.")
"parameter on_unused_input='raise' to theano.function. "
"To disable it completely, use on_unused_input='ignore'.")
err_msg = ("To make this error into a warning, you can pass the "
"parameter on_unused_input='warn' to theano.function. "
"To disable it completely, use on_unused_input='ignore'.")
"parameter on_unused_input='warn' to theano.function. "
"To disable it completely, use on_unused_input='ignore'.")
for i in inputs:
if ((i.variable not in used_inputs) and (i.update is None)):
if on_unused_input == 'warn':
warnings.warn(msg % (inputs.index(i), i.variable, warn_msg), stacklevel=6)
warnings.warn(msg % (inputs.index(i), i.variable,
warn_msg), stacklevel=6)
elif on_unused_input == 'raise':
raise UnusedInputError(msg % (inputs.index(i), i.variable, err_msg))
raise UnusedInputError(msg % (inputs.index(i),
i.variable, err_msg))
else:
raise ValueError(("Invalid value for keyword "
"on_unused_input of theano.function: '%s'. "
"valid values are 'raise', 'warn', and 'ignore'."
% on_unused_input))
raise ValueError("Invalid value for keyword "
"on_unused_input of theano.function: "
"'%s'.\nValid values are 'raise', "
"'warn', and 'ignore'." % on_unused_input)
def create(self, input_storage=None, trustme=False):
"""
Create a function.
input_storage -> a list matching the inputs list and providing default values
if the default for an input is None, then that input is a
required input. For an input with an update, the default
acts as initialization.
input_storage -> a list matching the inputs list and providing
default values if the default for an input is
None, then that input is a required input. For an
input with an update, the default acts as
initialization.
trustme -> disables some exceptions, used internally
"""
if input_storage is None:
input_storage = [None] * len(self.inputs)
input_storage_lists = [] # list of independent one-element lists, will be passed to the linker
# list of independent one-element lists, will be passed to the linker
input_storage_lists = []
defaults = []
# The following loop is to fill in the input_storage_lists and defaults lists.
# The following loop is to fill in the input_storage_lists and
# defaults lists.
assert len(self.indices) == len(input_storage)
for i, ((input, indices, subinputs), input_storage_i) in enumerate(zip(self.indices, input_storage)):
# Replace any default value given as a variable by its container.
# Note that this makes sense only in the context of shared variables,
# but for now we avoid dealing directly with them to avoid dependency
# on the shared variables work-in-progress repository.
for i, ((input, indices, subinputs), input_storage_i) in \
enumerate(zip(self.indices, input_storage)):
# Replace any default value given as a variable by its
# container. Note that this makes sense only in the
# context of shared variables, but for now we avoid
# dealing directly with them to avoid dependency on the
# shared variables work-in-progress repository.
if isinstance(input_storage_i, gof.Variable):
input_storage_i = input_storage_i.container
......@@ -1282,7 +1342,8 @@ class FunctionMaker(object):
# share the same storage. This is done by appending
# input_storage_i.storage to input_storage_lists.
if indices is not None:
raise TypeError("Cannot take a Container instance as default for a SymbolicInputKit.")
raise TypeError("Cannot take a Container instance as "
"default for a SymbolicInputKit.")
input_storage_lists.append(input_storage_i.storage)
storage = input_storage[i].storage[0]
......@@ -1295,7 +1356,8 @@ class FunctionMaker(object):
required = self.required[i]
refeed = self.refeed[i]
# sanity check-- if an input is required it should not need to be refed
# sanity check-- if an input is required it should not
# need to be refed
assert not (required and refeed)
# shared variables need neither be input by the user nor refed
......@@ -1312,9 +1374,7 @@ class FunctionMaker(object):
if storage is not None:
assert refeed or not required
defaults.append((required,
refeed,
storage))
defaults.append((required, refeed, storage))
# Get a function instance
start_linker = time.time()
......@@ -1338,22 +1398,23 @@ class FunctionMaker(object):
self.profile.import_time += import_time
fn = self.function_builder(_fn, _i, _o, self.indices, self.outputs,
defaults, self.unpack_single, self.return_none, self.output_keys, self)
defaults, self.unpack_single,
self.return_none, self.output_keys, self)
fn.profile = self.profile
return fn
def _pickle_FunctionMaker(self):
kwargs = dict(
inputs=self.inputs,
outputs=self.orig_outputs,
fgraph=self.fgraph,
mode=self.mode,
accept_inplace=self.accept_inplace,
function_builder=self.function_builder,
profile=self.profile,
on_unused_input=self.on_unused_input,
)
inputs=self.inputs,
outputs=self.orig_outputs,
fgraph=self.fgraph,
mode=self.mode,
accept_inplace=self.accept_inplace,
function_builder=self.function_builder,
profile=self.profile,
on_unused_input=self.on_unused_input,
)
return (_constructor_FunctionMaker, (kwargs,))
......@@ -1367,19 +1428,6 @@ def _constructor_FunctionMaker(kwargs):
copy_reg.pickle(FunctionMaker, _pickle_FunctionMaker)
try:
    # Pickling of slice objects is implemented natively starting with
    # Python 2.6.  To stay compatible with Python 2.4, probe it here and
    # register our own slice pickler only if the native one is missing.
    cPickle.dumps(slice(0, 10, 100))
except TypeError:
    # This slice pickle implementation seems backward and forward
    # compatible with the native one.
    def _pickle_slice(s):
        # A slice is fully described by its (start, stop, step) triple.
        return (slice, (s.start, s.stop, s.step))
    copy_reg.pickle(slice, _pickle_slice)
# Registry of equality-checker callables, tried in order by check_equal;
# presumably populated via register_checker (defined below) — the hunk
# boundaries hide the registration call sites.
__checkers = []
......@@ -1390,7 +1438,6 @@ def check_equal(x, y):
except Exception:
continue
return x == y
#raise Exception('No checker for equality between %s and %s' % (x, y))
def register_checker(checker):
......@@ -1405,10 +1452,10 @@ def orig_function(inputs, outputs, mode=None, accept_inplace=False,
:param inputs: list of `SymbolicInput` or `In` instances
:param outputs: a SymbolicOutput or a list of `SymbolicOutput` or `Out`
instances. The return value of the returned function will match the
format of this argument (either the value itself or a list of one or more
return values)
:param outputs: a SymbolicOutput or a list of `SymbolicOutput` or
`Out` instances. The return value of the returned function
will match the format of this argument (either the value
itself or a list of one or more return values)
:param mode: a descriptive string or a Mode instance. (Default of None
means to use `config.mode` (See below for descriptive string list).
......@@ -1422,7 +1469,8 @@ def orig_function(inputs, outputs, mode=None, accept_inplace=False,
- FAST_COMPILE (minimal optimization)
- ProfileMode(deprecated): allow to print a profile mode with mode.print_summary
- ProfileMode(deprecated): allow to print a profile mode with
mode.print_summary
- DebugMode: verify many internal conditions that are normally assumed
(slow)
......@@ -1471,8 +1519,8 @@ def orig_function(inputs, outputs, mode=None, accept_inplace=False,
accept_inplace=accept_inplace,
profile=profile,
on_unused_input=on_unused_input,
output_keys = output_keys).create(
defaults)
output_keys=output_keys).create(
defaults)
t2 = time.time()
if profile:
......@@ -1552,7 +1600,7 @@ def convert_function_input(input):
raise TypeError("Unknown update type: %s, expected Variable "
"instance" % type(update), update)
if (value is not None and
isinstance(value, (gof.Variable, SymbolicInput))):
isinstance(value, (gof.Variable, SymbolicInput))):
raise TypeError("The value for input %s should not be a Variable "
"or SymbolicInput instance (got: %s)" %
(variable, value))
......@@ -1579,26 +1627,26 @@ def get_info_on_inputs(named_inputs, n_unnamed_inputs):
else:
if n_unnamed_inputs == 1:
msg = ("The function has a single input variable which has no "
"name, and thus cannot be assigned through a keyword"
" argument (use 'name=...' in a Variable's "
"constructor to give it a name).")
"name, and thus cannot be assigned through a keyword"
" argument (use 'name=...' in a Variable's "
"constructor to give it a name).")
else:
# Use plural.
msg = ("The function has %s inputs, but none of them is named,"
" and thus they cannot be assigned through keyword "
"arguments (use 'name=...' in a Variable's "
"constructor to give it a name)." % n_unnamed_inputs)
" and thus they cannot be assigned through keyword "
"arguments (use 'name=...' in a Variable's "
"constructor to give it a name)." % n_unnamed_inputs)
else:
if n_unnamed_inputs == 0:
msg = ("The function has %s named input%s (%s)." % (
n_named_inputs, get_plural(n_named_inputs),
', '.join(named_inputs)))
msg = ("The function has %s named input%s (%s)." %
(n_named_inputs, get_plural(n_named_inputs),
', '.join(named_inputs)))
else:
msg = ("The function has %s named input%s (%s), and %s unnamed "
"input%s which thus cannot be accessed through keyword "
"argument%s (use 'name=...' in a variable's constructor "
"to give it a name)." % (
n_named_inputs, get_plural(n_named_inputs),
"input%s which thus cannot be accessed through keyword "
"argument%s (use 'name=...' in a variable's constructor "
"to give it a name)." %
(n_named_inputs, get_plural(n_named_inputs),
', '.join(named_inputs), n_unnamed_inputs,
get_plural(n_unnamed_inputs),
get_plural(n_unnamed_inputs)))
......
......@@ -40,7 +40,6 @@ whitelist_flake8 = [
"tests/unittest_tools.py",
"compile/__init__.py",
"compile/profiling.py",
"compile/function_module.py",
"compile/monitormode.py",
"compile/tests/test_builders.py",
"compile/tests/test_misc.py",
......
Markdown 格式
0%
您将添加 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册或者登录后发表评论