提交 0f86ecd9 authored 作者: lamblin's avatar lamblin

Merge pull request #424 from delallea/minor

Minor stuff (PEP8 and typos mostly)
...@@ -5,7 +5,8 @@ from theano import gof ...@@ -5,7 +5,8 @@ from theano import gof
from sharedvalue import SharedVariable from sharedvalue import SharedVariable
import logging import logging
_logger=logging.getLogger("theano.compile.io") _logger = logging.getLogger("theano.compile.io")
class SymbolicInput(object): class SymbolicInput(object):
""" """
...@@ -194,8 +195,7 @@ class In(SymbolicInput): ...@@ -194,8 +195,7 @@ class In(SymbolicInput):
# try to keep it synchronized. # try to keep it synchronized.
def __init__(self, variable, name=None, value=None, update=None, def __init__(self, variable, name=None, value=None, update=None,
mutable=None, strict=False, allow_downcast=None, autoname=True, mutable=None, strict=False, allow_downcast=None, autoname=True,
implicit=None, borrow=None, shared = False): implicit=None, borrow=None, shared=False):
#if shared, an input's value comes from its persistent storage, not from a default stored #if shared, an input's value comes from its persistent storage, not from a default stored
#in the function or from the caller #in the function or from the caller
...@@ -206,7 +206,7 @@ class In(SymbolicInput): ...@@ -206,7 +206,7 @@ class In(SymbolicInput):
# mutable=True should require borrow=True. Raise warning when borrow is explicitely set # mutable=True should require borrow=True. Raise warning when borrow is explicitely set
# to False with mutable=True. # to False with mutable=True.
if mutable: if mutable:
if borrow==False: if borrow == False:
_logger.warning("Symbolic input for variable %s (name=%s) has " _logger.warning("Symbolic input for variable %s (name=%s) has "
"flags mutable=True, borrow=False. This combination is " "flags mutable=True, borrow=False. This combination is "
"incompatible since mutable=True implies that the " "incompatible since mutable=True implies that the "
......
"""Provide a simple user friendly API """ """Provide a simple user friendly API """
__docformat__ = 'restructuredtext en' __docformat__ = 'restructuredtext en'
import numpy # for backport to 2.4, to get any().
from profiling import ProfileStats from profiling import ProfileStats
from theano.gof import Container, Variable, generic, graph, Constant, Value
from theano import config
from theano.compile import orig_function, In, Out from theano.compile import orig_function, In, Out
from theano.compile.sharedvalue import SharedVariable, shared from theano.compile.sharedvalue import SharedVariable, shared
from theano import config from theano.gof import Container, Variable, generic, graph, Constant, Value
from theano.gof.python25 import any
import logging import logging
_logger=logging.getLogger("theano.compile.pfunc") _logger = logging.getLogger("theano.compile.pfunc")
def rebuild_collect_shared( outputs
, inputs = None def rebuild_collect_shared(outputs,
, replace = None inputs=None,
, updates = None replace=None,
, rebuild_strict = True updates=None,
, copy_inputs_over = True rebuild_strict=True,
, no_default_updates = False copy_inputs_over=True,
no_default_updates=False,
): ):
""" """
Function that allows replacing subgraphs of a computational Function that allows replacing subgraphs of a computational
...@@ -60,7 +63,7 @@ def rebuild_collect_shared( outputs ...@@ -60,7 +63,7 @@ def rebuild_collect_shared( outputs
""" """
if isinstance(outputs,tuple): if isinstance(outputs, tuple):
outputs = list(outputs) outputs = list(outputs)
## This function implements similar functionality as graph.clone ## This function implements similar functionality as graph.clone
...@@ -71,7 +74,6 @@ def rebuild_collect_shared( outputs ...@@ -71,7 +74,6 @@ def rebuild_collect_shared( outputs
# list of shared inputs that are used as inputs of the graph # list of shared inputs that are used as inputs of the graph
shared_inputs = [] shared_inputs = []
def clone_v_get_shared_updates(v, copy_inputs_over): def clone_v_get_shared_updates(v, copy_inputs_over):
''' '''
Clones a variable and its inputs recursively until all are in Clones a variable and its inputs recursively until all are in
...@@ -88,36 +90,34 @@ def rebuild_collect_shared( outputs ...@@ -88,36 +90,34 @@ def rebuild_collect_shared( outputs
return clone_d[v] return clone_d[v]
if v.owner: if v.owner:
clone_a(v.owner, copy_inputs_over) clone_a(v.owner, copy_inputs_over)
return clone_d.setdefault(v,v) return clone_d.setdefault(v, v)
elif isinstance(v, SharedVariable): elif isinstance(v, SharedVariable):
if v not in shared_inputs: if v not in shared_inputs:
shared_inputs.append(v) shared_inputs.append(v)
if hasattr(v, 'default_update'): if hasattr(v, 'default_update'):
# Check that v should not be excluded from the default # Check that v should not be excluded from the default
# updates list # updates list
if ( no_default_updates is False or if (no_default_updates is False or
( isinstance(no_default_updates, list) and (isinstance(no_default_updates, list) and
v not in no_default_updates v not in no_default_updates)):
)
):
# Do not use default_update if a "real" update was # Do not use default_update if a "real" update was
# provided # provided
if v not in update_d: if v not in update_d:
v_update = v.type.filter_variable(v.default_update) v_update = v.type.filter_variable(v.default_update)
if v_update.type != v.type: if v_update.type != v.type:
raise TypeError( raise TypeError(
( 'an update must have the same type as ' 'an update must have the same type as '
'the original shared variable' ) 'the original shared variable',
, (v, v.type, v_update, v_update.type)) (v, v.type, v_update, v_update.type))
update_d[v] = v_update update_d[v] = v_update
update_expr.append((v, v_update)) update_expr.append((v, v_update))
if not copy_inputs_over or (isinstance(v, Constant) and if not copy_inputs_over or (isinstance(v, Constant) and
hasattr(v,'env')): hasattr(v, 'env')):
### Cloning shared variables implies copying their underlying ### Cloning shared variables implies copying their underlying
### memory buffer ?? No. ### memory buffer ?? No.
return clone_d.setdefault(v,v.clone()) return clone_d.setdefault(v, v.clone())
else: else:
return clone_d.setdefault(v,v) return clone_d.setdefault(v, v)
def clone_a(a, copy_inputs_over): def clone_a(a, copy_inputs_over):
''' '''
...@@ -132,12 +132,11 @@ def rebuild_collect_shared( outputs ...@@ -132,12 +132,11 @@ def rebuild_collect_shared( outputs
clone_d[a] = a.clone_with_new_inputs([clone_d[i] for i in clone_d[a] = a.clone_with_new_inputs([clone_d[i] for i in
a.inputs], a.inputs],
strict = rebuild_strict) strict=rebuild_strict)
for old_o, new_o in zip(a.outputs, clone_d[a].outputs): for old_o, new_o in zip(a.outputs, clone_d[a].outputs):
clone_d.setdefault(old_o,new_o) clone_d.setdefault(old_o, new_o)
return clone_d[a] return clone_d[a]
# intialize the clone_d mapping with the replace dictionary # intialize the clone_d mapping with the replace dictionary
if replace is None: if replace is None:
replace = [] replace = []
...@@ -147,9 +146,9 @@ def rebuild_collect_shared( outputs ...@@ -147,9 +146,9 @@ def rebuild_collect_shared( outputs
replace_pairs = replace replace_pairs = replace
for v_orig, v_repl in replace_pairs: for v_orig, v_repl in replace_pairs:
if not isinstance(v_orig,Variable): if not isinstance(v_orig, Variable):
raise TypeError('given keys must be Variable', v_orig) raise TypeError('given keys must be Variable', v_orig)
if not isinstance(v_repl,Variable): if not isinstance(v_repl, Variable):
v_repl = shared(v_repl) v_repl = shared(v_repl)
assert v_orig not in clone_d assert v_orig not in clone_d
clone_d[v_orig] = clone_v_get_shared_updates(v_repl, clone_d[v_orig] = clone_v_get_shared_updates(v_repl,
...@@ -160,9 +159,9 @@ def rebuild_collect_shared( outputs ...@@ -160,9 +159,9 @@ def rebuild_collect_shared( outputs
def clone_inputs(i): def clone_inputs(i):
if not copy_inputs_over: if not copy_inputs_over:
return clone_d.setdefault(i,i.clone()) return clone_d.setdefault(i, i.clone())
else: else:
return clone_d.setdefault(i,i) return clone_d.setdefault(i, i)
input_variables = [clone_inputs(i) for i in inputs] input_variables = [clone_inputs(i) for i in inputs]
...@@ -171,7 +170,7 @@ def rebuild_collect_shared( outputs ...@@ -171,7 +170,7 @@ def rebuild_collect_shared( outputs
# it is also not clear when/how to use the value of that shared # it is also not clear when/how to use the value of that shared
# variable (is it a default? ignored?, if the shared variable changes, # variable (is it a default? ignored?, if the shared variable changes,
# does that function default also change?). # does that function default also change?).
if numpy.any([isinstance(v, SharedVariable) for v in input_variables]): if any([isinstance(v, SharedVariable) for v in input_variables]):
raise TypeError(('Cannot use a shared variable (%s) as explicit ' raise TypeError(('Cannot use a shared variable (%s) as explicit '
'input. Consider substituting a non-shared' 'input. Consider substituting a non-shared'
' variable via the `givens` parameter') % v) ' variable via the `givens` parameter') % v)
...@@ -181,25 +180,25 @@ def rebuild_collect_shared( outputs ...@@ -181,25 +180,25 @@ def rebuild_collect_shared( outputs
updates = [] updates = []
for (store_into, update_val) in iter_over_pairs(updates): for (store_into, update_val) in iter_over_pairs(updates):
if not isinstance(store_into, SharedVariable): if not isinstance(store_into, SharedVariable):
raise TypeError('update target must be a SharedVariable' raise TypeError('update target must be a SharedVariable',
, store_into) store_into)
if store_into in update_d: if store_into in update_d:
raise ValueError(('this shared variable already has an update ' raise ValueError('this shared variable already has an update '
'expression'), 'expression',
(store_into, update_d[store_into])) (store_into, update_d[store_into]))
# filter_variable ensure smooth conversion of cpu/gpu Types # filter_variable ensure smooth conversion of cpu/gpu Types
update_val = store_into.type.filter_variable(update_val) update_val = store_into.type.filter_variable(update_val)
if update_val.type != store_into.type: if update_val.type != store_into.type:
err_msg = ( 'an update must have the same type as the ' err_msg = ('an update must have the same type as the '
'original shared variable(dest, dest.type, ' 'original shared variable(dest, dest.type, '
'update_val, update_val.type)') 'update_val, update_val.type)')
err_arg = ( store_into err_arg = (store_into,
, store_into.type store_into.type,
, update_val update_val,
, update_val.type) update_val.type)
raise TypeError(err_msg, err_arg ) raise TypeError(err_msg, err_arg)
update_d[store_into] = update_val update_d[store_into] = update_val
update_expr.append((store_into, update_val)) update_expr.append((store_into, update_val))
...@@ -215,8 +214,8 @@ def rebuild_collect_shared( outputs ...@@ -215,8 +214,8 @@ def rebuild_collect_shared( outputs
copy_inputs_over) copy_inputs_over)
cloned_outputs.append(Out(cloned_v, borrow=v.borrow)) cloned_outputs.append(Out(cloned_v, borrow=v.borrow))
else: else:
raise TypeError( ( 'outputs must be theano Variable or ' raise TypeError('outputs must be theano Variable or '
'Out instances'), v) 'Out instances', v)
#computed_list.append(cloned_v) #computed_list.append(cloned_v)
else: else:
if isinstance(outputs, Variable): if isinstance(outputs, Variable):
...@@ -231,10 +230,9 @@ def rebuild_collect_shared( outputs ...@@ -231,10 +230,9 @@ def rebuild_collect_shared( outputs
elif outputs is None: elif outputs is None:
cloned_outputs = [] # TODO: get Function.__call__ to return None cloned_outputs = [] # TODO: get Function.__call__ to return None
else: else:
raise TypeError( ('output must be a theano Variable or Out ' raise TypeError('output must be a theano Variable or Out '
'instance (or list of them)') 'instance (or list of them)',
, outputs) outputs)
# Iterate over update_expr, cloning its elements, and updating # Iterate over update_expr, cloning its elements, and updating
# shared_inputs, update_d and update_expr from the SharedVariables # shared_inputs, update_d and update_expr from the SharedVariables
...@@ -244,7 +242,7 @@ def rebuild_collect_shared( outputs ...@@ -244,7 +242,7 @@ def rebuild_collect_shared( outputs
# Note: we extend update_expr while iterating over it. # Note: we extend update_expr while iterating over it.
i = 0 i = 0
while i<len(update_expr): while i < len(update_expr):
v, v_update = update_expr[i] v, v_update = update_expr[i]
cloned_v_update = clone_v_get_shared_updates(v_update, cloned_v_update = clone_v_get_shared_updates(v_update,
copy_inputs_over) copy_inputs_over)
...@@ -253,12 +251,13 @@ def rebuild_collect_shared( outputs ...@@ -253,12 +251,13 @@ def rebuild_collect_shared( outputs
shared_inputs.append(v) shared_inputs.append(v)
i += 1 i += 1
return ( input_variables, cloned_outputs return (input_variables, cloned_outputs,
, [clone_d, update_d, update_expr, shared_inputs] ) [clone_d, update_d, update_expr, shared_inputs])
class Param(object): class Param(object):
def __init__(self, variable, default=None, name=None, mutable=False, def __init__(self, variable, default=None, name=None, mutable=False,
strict=False, allow_downcast=None, implicit=None, borrow = None): strict=False, allow_downcast=None, implicit=None, borrow=None):
""" """
:param variable: A variable in an expression graph to use as a compiled-function parameter :param variable: A variable in an expression graph to use as a compiled-function parameter
...@@ -295,7 +294,7 @@ class Param(object): ...@@ -295,7 +294,7 @@ class Param(object):
# mutable=True should require borrow=True. Raise warning when borrow is explicitely set # mutable=True should require borrow=True. Raise warning when borrow is explicitely set
# to False with mutable=True. # to False with mutable=True.
if mutable: if mutable:
if borrow==False: if not borrow:
_logger.warning("Symbolic input for variable %s (name=%s) has " _logger.warning("Symbolic input for variable %s (name=%s) has "
"flags mutable=True, borrow=False. This combination is " "flags mutable=True, borrow=False. This combination is "
"incompatible since mutable=True implies that the " "incompatible since mutable=True implies that the "
...@@ -308,6 +307,7 @@ class Param(object): ...@@ -308,6 +307,7 @@ class Param(object):
self.implicit = implicit self.implicit = implicit
self.borrow = borrow self.borrow = borrow
def pfunc(params, outputs=None, mode=None, updates=[], givens=[], def pfunc(params, outputs=None, mode=None, updates=[], givens=[],
no_default_updates=False, accept_inplace=False, name=None, no_default_updates=False, accept_inplace=False, name=None,
rebuild_strict=True, allow_input_downcast=None, rebuild_strict=True, allow_input_downcast=None,
...@@ -398,27 +398,25 @@ def pfunc(params, outputs=None, mode=None, updates=[], givens=[], ...@@ -398,27 +398,25 @@ def pfunc(params, outputs=None, mode=None, updates=[], givens=[],
# No need to block other objects being passed through though. It might be # No need to block other objects being passed through though. It might be
# useful. # useful.
if not isinstance(params,(list,tuple)): if not isinstance(params, (list, tuple)):
raise Exception("in pfunc() the first argument must be a list or a tuple") raise Exception("in pfunc() the first argument must be a list or a tuple")
if not isinstance(no_default_updates, bool)\ if not isinstance(no_default_updates, bool)\
and not isinstance(no_default_updates, list): and not isinstance(no_default_updates, list):
raise TypeError("no_default_update should be either a boolean or a list") raise TypeError("no_default_update should be either a boolean or a list")
# transform params into theano.compile.In objects. # transform params into theano.compile.In objects.
inputs = [_pfunc_param_to_in(p, allow_downcast=allow_input_downcast) inputs = [_pfunc_param_to_in(p, allow_downcast=allow_input_downcast)
for p in params] for p in params]
in_variables = [ input.variable for input in inputs ] in_variables = [input.variable for input in inputs]
output_vars = rebuild_collect_shared( output_vars = rebuild_collect_shared(outputs,
outputs in_variables,
, in_variables replace=givens,
, replace = givens updates=updates,
, updates = updates rebuild_strict=True,
, rebuild_strict = True copy_inputs_over=True,
, copy_inputs_over = True no_default_updates=no_default_updates)
, no_default_updates = no_default_updates )
# extracting the arguments # extracting the arguments
input_variables, cloned_outputs, other_stuff = output_vars input_variables, cloned_outputs, other_stuff = output_vars
clone_d, update_d, update_expr, shared_inputs = other_stuff clone_d, update_d, update_expr, shared_inputs = other_stuff
...@@ -431,14 +429,13 @@ def pfunc(params, outputs=None, mode=None, updates=[], givens=[], ...@@ -431,14 +429,13 @@ def pfunc(params, outputs=None, mode=None, updates=[], givens=[],
#value will be stored in the resulting functions' defaults list #value will be stored in the resulting functions' defaults list
#but since the value of shared variables never needs to be refed, it is not needed #but since the value of shared variables never needs to be refed, it is not needed
if sv in update_d: if sv in update_d:
si = In(variable=sv, value = sv.container, mutable=True, si = In(variable=sv, value=sv.container, mutable=True,
borrow=True, update=update_d[sv], shared = True) borrow=True, update=update_d[sv], shared=True)
else: else:
si = In(variable=sv, value = sv.container, si = In(variable=sv, value=sv.container,
mutable=False, borrow=True, shared = True) mutable=False, borrow=True, shared=True)
inputs.append(si) inputs.append(si)
return orig_function(inputs, cloned_outputs, mode, return orig_function(inputs, cloned_outputs, mode,
accept_inplace=accept_inplace, name=name, profile=profile) accept_inplace=accept_inplace, name=name, profile=profile)
...@@ -449,7 +446,7 @@ def _pfunc_param_to_in(param, strict=False, allow_downcast=None): ...@@ -449,7 +446,7 @@ def _pfunc_param_to_in(param, strict=False, allow_downcast=None):
#if isinstance(param, Value): #if isinstance(param, Value):
#return In(variable=param) #return In(variable=param)
#raise NotImplementedError() #raise NotImplementedError()
if isinstance(param, Variable): #N.B. includes Value and SharedVariable if isinstance(param, Variable): # N.B. includes Value and SharedVariable
return In(variable=param, strict=strict, allow_downcast=allow_downcast) return In(variable=param, strict=strict, allow_downcast=allow_downcast)
elif isinstance(param, Param): elif isinstance(param, Param):
return In( return In(
...@@ -458,9 +455,9 @@ def _pfunc_param_to_in(param, strict=False, allow_downcast=None): ...@@ -458,9 +455,9 @@ def _pfunc_param_to_in(param, strict=False, allow_downcast=None):
value=param.default, value=param.default,
mutable=param.mutable, mutable=param.mutable,
strict=param.strict, strict=param.strict,
borrow = param.borrow, borrow=param.borrow,
allow_downcast=param.allow_downcast, allow_downcast=param.allow_downcast,
implicit = param.implicit) implicit=param.implicit)
raise TypeError('Unknown parameter type: %s' % type(param)) raise TypeError('Unknown parameter type: %s' % type(param))
......
...@@ -15,7 +15,8 @@ compiledir_format_dict = {"platform": platform.platform(), ...@@ -15,7 +15,8 @@ compiledir_format_dict = {"platform": platform.platform(),
"theano_version": theano.__version__, "theano_version": theano.__version__,
} }
compiledir_format_keys = ", ".join(compiledir_format_dict.keys()) compiledir_format_keys = ", ".join(compiledir_format_dict.keys())
default_compiledir_format = "compiledir_%(platform)s-%(processor)s-%(python_version)s" default_compiledir_format =\
"compiledir_%(platform)s-%(processor)s-%(python_version)s"
AddConfigVar("compiledir_format", AddConfigVar("compiledir_format",
textwrap.fill(textwrap.dedent("""\ textwrap.fill(textwrap.dedent("""\
...@@ -53,7 +54,7 @@ def filter_compiledir(path): ...@@ -53,7 +54,7 @@ def filter_compiledir(path):
# the same directory at the same time. # the same directory at the same time.
if e.errno != errno.EEXIST: if e.errno != errno.EEXIST:
raise ValueError( raise ValueError(
"Unable to create to create the compiledir directory" "Unable to create the compiledir directory"
" '%s'. Check the permissions." % path) " '%s'. Check the permissions." % path)
# PROBLEM: sometimes the initial approach based on # PROBLEM: sometimes the initial approach based on
...@@ -118,7 +119,7 @@ def print_compiledir_content(): ...@@ -118,7 +119,7 @@ def print_compiledir_content():
compiledir = theano.config.compiledir compiledir = theano.config.compiledir
table = [] table = []
more_then_one_ops = 0 more_than_one_ops = 0
zeros_op = 0 zeros_op = 0
for dir in os.listdir(compiledir): for dir in os.listdir(compiledir):
file = None file = None
...@@ -131,7 +132,7 @@ def print_compiledir_content(): ...@@ -131,7 +132,7 @@ def print_compiledir_content():
if len(ops) == 0: if len(ops) == 0:
zeros_op += 1 zeros_op += 1
elif len(ops) > 1: elif len(ops) > 1:
more_then_one_ops += 1 more_than_one_ops += 1
else: else:
types = list(set([x for x in flatten(keydata.keys) types = list(set([x for x in flatten(keydata.keys)
if isinstance(x, theano.gof.Type)])) if isinstance(x, theano.gof.Type)]))
...@@ -142,7 +143,7 @@ def print_compiledir_content(): ...@@ -142,7 +143,7 @@ def print_compiledir_content():
if file is not None: if file is not None:
file.close() file.close()
print "List %d compiled individual op in this theano cache %s:" % ( print "List of %d compiled individual ops in this theano cache %s:" % (
len(table), compiledir) len(table), compiledir)
print "sub directory/Op/a set of the different associated Theano type" print "sub directory/Op/a set of the different associated Theano type"
table = sorted(table, key=lambda t: str(t[1])) table = sorted(table, key=lambda t: str(t[1]))
...@@ -153,13 +154,12 @@ def print_compiledir_content(): ...@@ -153,13 +154,12 @@ def print_compiledir_content():
table_op_class[op.__class__] += 1 table_op_class[op.__class__] += 1
print print
print "List %d of individual compiled Op class and" % (len(table_op_class)), print ("List of %d individual compiled Op classes and "
print " the number of time it got compiled" "the number of times they got compiled" % len(table_op_class))
table_op_class = sorted(table_op_class.iteritems(), key=lambda t: t[1]) table_op_class = sorted(table_op_class.iteritems(), key=lambda t: t[1])
for op_class, nb in table_op_class: for op_class, nb in table_op_class:
print op_class, nb print op_class, nb
print ("Skipped %d files that contained more then" print ("Skipped %d files that contained more than"
" 1 op (was compiled with the c linker)" % (more_then_one_ops)) " 1 op (was compiled with the C linker)" % more_than_one_ops)
print ("Skipped %d files that contained 0 op" print ("Skipped %d files that contained 0 op "
"(Are they always theano.scalar ops?)" % ( "(are they always theano.scalar ops?)" % zeros_op)
more_then_one_ops))
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论