提交 0f86ecd9 authored 作者: lamblin's avatar lamblin

Merge pull request #424 from delallea/minor

Minor stuff (PEP8 and typos mostly)
......@@ -5,7 +5,8 @@ from theano import gof
from sharedvalue import SharedVariable
import logging
_logger=logging.getLogger("theano.compile.io")
_logger = logging.getLogger("theano.compile.io")
class SymbolicInput(object):
"""
......@@ -49,7 +50,7 @@ class SymbolicInput(object):
def __init__(self, variable, name=None, update=None, mutable=None,
strict=False, allow_downcast=None, autoname=True,
implicit=False):
assert implicit is not None # Safety check.
assert implicit is not None # Safety check.
self.variable = variable
if (autoname and name is None):
self.name = variable.name
......@@ -194,8 +195,7 @@ class In(SymbolicInput):
# try to keep it synchronized.
def __init__(self, variable, name=None, value=None, update=None,
mutable=None, strict=False, allow_downcast=None, autoname=True,
implicit=None, borrow=None, shared = False):
implicit=None, borrow=None, shared=False):
#if shared, an input's value comes from its persistent storage, not from a default stored
#in the function or from the caller
......@@ -206,7 +206,7 @@ class In(SymbolicInput):
# mutable=True should require borrow=True. Raise warning when borrow is explicitely set
# to False with mutable=True.
if mutable:
if borrow==False:
if borrow == False:
_logger.warning("Symbolic input for variable %s (name=%s) has "
"flags mutable=True, borrow=False. This combination is "
"incompatible since mutable=True implies that the "
......
"""Provide a simple user friendly API """
__docformat__ = 'restructuredtext en'
import numpy # for backport to 2.4, to get any().
from profiling import ProfileStats
from theano.gof import Container, Variable, generic, graph, Constant, Value
from theano import config
from theano.compile import orig_function, In, Out
from theano.compile.sharedvalue import SharedVariable, shared
from theano import config
from theano.gof import Container, Variable, generic, graph, Constant, Value
from theano.gof.python25 import any
import logging
_logger=logging.getLogger("theano.compile.pfunc")
def rebuild_collect_shared( outputs
, inputs = None
, replace = None
, updates = None
, rebuild_strict = True
, copy_inputs_over = True
, no_default_updates = False
_logger = logging.getLogger("theano.compile.pfunc")
def rebuild_collect_shared(outputs,
inputs=None,
replace=None,
updates=None,
rebuild_strict=True,
copy_inputs_over=True,
no_default_updates=False,
):
"""
Function that allows replacing subgraphs of a computational
......@@ -60,7 +63,7 @@ def rebuild_collect_shared( outputs
"""
if isinstance(outputs,tuple):
if isinstance(outputs, tuple):
outputs = list(outputs)
## This function implements similar functionality as graph.clone
......@@ -71,7 +74,6 @@ def rebuild_collect_shared( outputs
# list of shared inputs that are used as inputs of the graph
shared_inputs = []
def clone_v_get_shared_updates(v, copy_inputs_over):
'''
Clones a variable and its inputs recursively until all are in
......@@ -88,36 +90,34 @@ def rebuild_collect_shared( outputs
return clone_d[v]
if v.owner:
clone_a(v.owner, copy_inputs_over)
return clone_d.setdefault(v,v)
return clone_d.setdefault(v, v)
elif isinstance(v, SharedVariable):
if v not in shared_inputs:
shared_inputs.append(v)
if hasattr(v, 'default_update'):
# Check that v should not be excluded from the default
# updates list
if ( no_default_updates is False or
( isinstance(no_default_updates, list) and
v not in no_default_updates
)
):
if (no_default_updates is False or
(isinstance(no_default_updates, list) and
v not in no_default_updates)):
# Do not use default_update if a "real" update was
# provided
if v not in update_d:
v_update = v.type.filter_variable(v.default_update)
if v_update.type != v.type:
raise TypeError(
( 'an update must have the same type as '
'the original shared variable' )
, (v, v.type, v_update, v_update.type))
'an update must have the same type as '
'the original shared variable',
(v, v.type, v_update, v_update.type))
update_d[v] = v_update
update_expr.append((v, v_update))
if not copy_inputs_over or (isinstance(v, Constant) and
hasattr(v,'env')):
hasattr(v, 'env')):
### Cloning shared variables implies copying their underlying
### memory buffer ?? No.
return clone_d.setdefault(v,v.clone())
return clone_d.setdefault(v, v.clone())
else:
return clone_d.setdefault(v,v)
return clone_d.setdefault(v, v)
def clone_a(a, copy_inputs_over):
'''
......@@ -132,12 +132,11 @@ def rebuild_collect_shared( outputs
clone_d[a] = a.clone_with_new_inputs([clone_d[i] for i in
a.inputs],
strict = rebuild_strict)
strict=rebuild_strict)
for old_o, new_o in zip(a.outputs, clone_d[a].outputs):
clone_d.setdefault(old_o,new_o)
clone_d.setdefault(old_o, new_o)
return clone_d[a]
# intialize the clone_d mapping with the replace dictionary
if replace is None:
replace = []
......@@ -147,9 +146,9 @@ def rebuild_collect_shared( outputs
replace_pairs = replace
for v_orig, v_repl in replace_pairs:
if not isinstance(v_orig,Variable):
if not isinstance(v_orig, Variable):
raise TypeError('given keys must be Variable', v_orig)
if not isinstance(v_repl,Variable):
if not isinstance(v_repl, Variable):
v_repl = shared(v_repl)
assert v_orig not in clone_d
clone_d[v_orig] = clone_v_get_shared_updates(v_repl,
......@@ -160,9 +159,9 @@ def rebuild_collect_shared( outputs
def clone_inputs(i):
if not copy_inputs_over:
return clone_d.setdefault(i,i.clone())
return clone_d.setdefault(i, i.clone())
else:
return clone_d.setdefault(i,i)
return clone_d.setdefault(i, i)
input_variables = [clone_inputs(i) for i in inputs]
......@@ -171,7 +170,7 @@ def rebuild_collect_shared( outputs
# it is also not clear when/how to use the value of that shared
# variable (is it a default? ignored?, if the shared variable changes,
# does that function default also change?).
if numpy.any([isinstance(v, SharedVariable) for v in input_variables]):
if any([isinstance(v, SharedVariable) for v in input_variables]):
raise TypeError(('Cannot use a shared variable (%s) as explicit '
'input. Consider substituting a non-shared'
' variable via the `givens` parameter') % v)
......@@ -181,25 +180,25 @@ def rebuild_collect_shared( outputs
updates = []
for (store_into, update_val) in iter_over_pairs(updates):
if not isinstance(store_into, SharedVariable):
raise TypeError('update target must be a SharedVariable'
, store_into)
raise TypeError('update target must be a SharedVariable',
store_into)
if store_into in update_d:
raise ValueError(('this shared variable already has an update '
'expression'),
(store_into, update_d[store_into]))
raise ValueError('this shared variable already has an update '
'expression',
(store_into, update_d[store_into]))
# filter_variable ensure smooth conversion of cpu/gpu Types
update_val = store_into.type.filter_variable(update_val)
if update_val.type != store_into.type:
err_msg = ( 'an update must have the same type as the '
'original shared variable(dest, dest.type, '
'update_val, update_val.type)')
err_arg = ( store_into
, store_into.type
, update_val
, update_val.type)
raise TypeError(err_msg, err_arg )
err_msg = ('an update must have the same type as the '
'original shared variable(dest, dest.type, '
'update_val, update_val.type)')
err_arg = (store_into,
store_into.type,
update_val,
update_val.type)
raise TypeError(err_msg, err_arg)
update_d[store_into] = update_val
update_expr.append((store_into, update_val))
......@@ -215,8 +214,8 @@ def rebuild_collect_shared( outputs
copy_inputs_over)
cloned_outputs.append(Out(cloned_v, borrow=v.borrow))
else:
raise TypeError( ( 'outputs must be theano Variable or '
'Out instances'), v)
raise TypeError('outputs must be theano Variable or '
'Out instances', v)
#computed_list.append(cloned_v)
else:
if isinstance(outputs, Variable):
......@@ -229,12 +228,11 @@ def rebuild_collect_shared( outputs
cloned_outputs = Out(cloned_v, borrow=outputs.borrow)
#computed_list.append(cloned_v)
elif outputs is None:
cloned_outputs = [] # TODO: get Function.__call__ to return None
cloned_outputs = [] # TODO: get Function.__call__ to return None
else:
raise TypeError( ('output must be a theano Variable or Out '
'instance (or list of them)')
, outputs)
raise TypeError('output must be a theano Variable or Out '
'instance (or list of them)',
outputs)
# Iterate over update_expr, cloning its elements, and updating
# shared_inputs, update_d and update_expr from the SharedVariables
......@@ -244,7 +242,7 @@ def rebuild_collect_shared( outputs
# Note: we extend update_expr while iterating over it.
i = 0
while i<len(update_expr):
while i < len(update_expr):
v, v_update = update_expr[i]
cloned_v_update = clone_v_get_shared_updates(v_update,
copy_inputs_over)
......@@ -253,12 +251,13 @@ def rebuild_collect_shared( outputs
shared_inputs.append(v)
i += 1
return ( input_variables, cloned_outputs
, [clone_d, update_d, update_expr, shared_inputs] )
return (input_variables, cloned_outputs,
[clone_d, update_d, update_expr, shared_inputs])
class Param(object):
def __init__(self, variable, default=None, name=None, mutable=False,
strict=False, allow_downcast=None, implicit=None, borrow = None):
strict=False, allow_downcast=None, implicit=None, borrow=None):
"""
:param variable: A variable in an expression graph to use as a compiled-function parameter
......@@ -295,7 +294,7 @@ class Param(object):
# mutable=True should require borrow=True. Raise warning when borrow is explicitely set
# to False with mutable=True.
if mutable:
if borrow==False:
if not borrow:
_logger.warning("Symbolic input for variable %s (name=%s) has "
"flags mutable=True, borrow=False. This combination is "
"incompatible since mutable=True implies that the "
......@@ -308,6 +307,7 @@ class Param(object):
self.implicit = implicit
self.borrow = borrow
def pfunc(params, outputs=None, mode=None, updates=[], givens=[],
no_default_updates=False, accept_inplace=False, name=None,
rebuild_strict=True, allow_input_downcast=None,
......@@ -398,27 +398,25 @@ def pfunc(params, outputs=None, mode=None, updates=[], givens=[],
# No need to block other objects being passed through though. It might be
# useful.
if not isinstance(params,(list,tuple)):
if not isinstance(params, (list, tuple)):
raise Exception("in pfunc() the first argument must be a list or a tuple")
if not isinstance(no_default_updates, bool)\
and not isinstance(no_default_updates, list):
raise TypeError("no_default_update should be either a boolean or a list")
# transform params into theano.compile.In objects.
inputs = [_pfunc_param_to_in(p, allow_downcast=allow_input_downcast)
for p in params]
in_variables = [ input.variable for input in inputs ]
output_vars = rebuild_collect_shared(
outputs
, in_variables
, replace = givens
, updates = updates
, rebuild_strict = True
, copy_inputs_over = True
, no_default_updates = no_default_updates )
in_variables = [input.variable for input in inputs]
output_vars = rebuild_collect_shared(outputs,
in_variables,
replace=givens,
updates=updates,
rebuild_strict=True,
copy_inputs_over=True,
no_default_updates=no_default_updates)
# extracting the arguments
input_variables, cloned_outputs, other_stuff = output_vars
clone_d, update_d, update_expr, shared_inputs = other_stuff
......@@ -431,14 +429,13 @@ def pfunc(params, outputs=None, mode=None, updates=[], givens=[],
#value will be stored in the resulting functions' defaults list
#but since the value of shared variables never needs to be refed, it is not needed
if sv in update_d:
si = In(variable=sv, value = sv.container, mutable=True,
borrow=True, update=update_d[sv], shared = True)
si = In(variable=sv, value=sv.container, mutable=True,
borrow=True, update=update_d[sv], shared=True)
else:
si = In(variable=sv, value = sv.container,
mutable=False, borrow=True, shared = True)
si = In(variable=sv, value=sv.container,
mutable=False, borrow=True, shared=True)
inputs.append(si)
return orig_function(inputs, cloned_outputs, mode,
accept_inplace=accept_inplace, name=name, profile=profile)
......@@ -449,7 +446,7 @@ def _pfunc_param_to_in(param, strict=False, allow_downcast=None):
#if isinstance(param, Value):
#return In(variable=param)
#raise NotImplementedError()
if isinstance(param, Variable): #N.B. includes Value and SharedVariable
if isinstance(param, Variable): # N.B. includes Value and SharedVariable
return In(variable=param, strict=strict, allow_downcast=allow_downcast)
elif isinstance(param, Param):
return In(
......@@ -458,9 +455,9 @@ def _pfunc_param_to_in(param, strict=False, allow_downcast=None):
value=param.default,
mutable=param.mutable,
strict=param.strict,
borrow = param.borrow,
borrow=param.borrow,
allow_downcast=param.allow_downcast,
implicit = param.implicit)
implicit=param.implicit)
raise TypeError('Unknown parameter type: %s' % type(param))
......
......@@ -15,7 +15,8 @@ compiledir_format_dict = {"platform": platform.platform(),
"theano_version": theano.__version__,
}
compiledir_format_keys = ", ".join(compiledir_format_dict.keys())
default_compiledir_format = "compiledir_%(platform)s-%(processor)s-%(python_version)s"
default_compiledir_format =\
"compiledir_%(platform)s-%(processor)s-%(python_version)s"
AddConfigVar("compiledir_format",
textwrap.fill(textwrap.dedent("""\
......@@ -53,7 +54,7 @@ def filter_compiledir(path):
# the same directory at the same time.
if e.errno != errno.EEXIST:
raise ValueError(
"Unable to create to create the compiledir directory"
"Unable to create the compiledir directory"
" '%s'. Check the permissions." % path)
# PROBLEM: sometimes the initial approach based on
......@@ -118,7 +119,7 @@ def print_compiledir_content():
compiledir = theano.config.compiledir
table = []
more_then_one_ops = 0
more_than_one_ops = 0
zeros_op = 0
for dir in os.listdir(compiledir):
file = None
......@@ -131,7 +132,7 @@ def print_compiledir_content():
if len(ops) == 0:
zeros_op += 1
elif len(ops) > 1:
more_then_one_ops += 1
more_than_one_ops += 1
else:
types = list(set([x for x in flatten(keydata.keys)
if isinstance(x, theano.gof.Type)]))
......@@ -142,7 +143,7 @@ def print_compiledir_content():
if file is not None:
file.close()
print "List %d compiled individual op in this theano cache %s:" % (
print "List of %d compiled individual ops in this theano cache %s:" % (
len(table), compiledir)
print "sub directory/Op/a set of the different associated Theano type"
table = sorted(table, key=lambda t: str(t[1]))
......@@ -153,13 +154,12 @@ def print_compiledir_content():
table_op_class[op.__class__] += 1
print
print "List %d of individual compiled Op class and" % (len(table_op_class)),
print " the number of time it got compiled"
print ("List of %d individual compiled Op classes and "
"the number of times they got compiled" % len(table_op_class))
table_op_class = sorted(table_op_class.iteritems(), key=lambda t: t[1])
for op_class, nb in table_op_class:
print op_class, nb
print ("Skipped %d files that contained more then"
" 1 op (was compiled with the c linker)" % (more_then_one_ops))
print ("Skipped %d files that contained 0 op"
"(Are they always theano.scalar ops?)" % (
more_then_one_ops))
print ("Skipped %d files that contained more than"
" 1 op (was compiled with the C linker)" % more_than_one_ops)
print ("Skipped %d files that contained 0 op "
"(are they always theano.scalar ops?)" % zeros_op)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论