提交 7f862fdc authored 作者: Olivier Delalleau's avatar Olivier Delalleau

PEP8 fixes

上级 5c576b65
...@@ -5,7 +5,8 @@ from theano import gof ...@@ -5,7 +5,8 @@ from theano import gof
from sharedvalue import SharedVariable from sharedvalue import SharedVariable
import logging import logging
_logger=logging.getLogger("theano.compile.io") _logger = logging.getLogger("theano.compile.io")
class SymbolicInput(object): class SymbolicInput(object):
""" """
...@@ -49,7 +50,7 @@ class SymbolicInput(object): ...@@ -49,7 +50,7 @@ class SymbolicInput(object):
def __init__(self, variable, name=None, update=None, mutable=None, def __init__(self, variable, name=None, update=None, mutable=None,
strict=False, allow_downcast=None, autoname=True, strict=False, allow_downcast=None, autoname=True,
implicit=False): implicit=False):
assert implicit is not None # Safety check. assert implicit is not None # Safety check.
self.variable = variable self.variable = variable
if (autoname and name is None): if (autoname and name is None):
self.name = variable.name self.name = variable.name
...@@ -194,8 +195,7 @@ class In(SymbolicInput): ...@@ -194,8 +195,7 @@ class In(SymbolicInput):
# try to keep it synchronized. # try to keep it synchronized.
def __init__(self, variable, name=None, value=None, update=None, def __init__(self, variable, name=None, value=None, update=None,
mutable=None, strict=False, allow_downcast=None, autoname=True, mutable=None, strict=False, allow_downcast=None, autoname=True,
implicit=None, borrow=None, shared = False): implicit=None, borrow=None, shared=False):
#if shared, an input's value comes from its persistent storage, not from a default stored #if shared, an input's value comes from its persistent storage, not from a default stored
#in the function or from the caller #in the function or from the caller
...@@ -206,7 +206,7 @@ class In(SymbolicInput): ...@@ -206,7 +206,7 @@ class In(SymbolicInput):
# mutable=True should require borrow=True. Raise warning when borrow is explicitely set # mutable=True should require borrow=True. Raise warning when borrow is explicitely set
# to False with mutable=True. # to False with mutable=True.
if mutable: if mutable:
if borrow==False: if borrow == False:
_logger.warning("Symbolic input for variable %s (name=%s) has " _logger.warning("Symbolic input for variable %s (name=%s) has "
"flags mutable=True, borrow=False. This combination is " "flags mutable=True, borrow=False. This combination is "
"incompatible since mutable=True implies that the " "incompatible since mutable=True implies that the "
......
"""Provide a simple user friendly API """ """Provide a simple user friendly API """
__docformat__ = 'restructuredtext en' __docformat__ = 'restructuredtext en'
import numpy # for backport to 2.4, to get any().
from profiling import ProfileStats from profiling import ProfileStats
from theano.gof import Container, Variable, generic, graph, Constant, Value
from theano import config
from theano.compile import orig_function, In, Out from theano.compile import orig_function, In, Out
from theano.compile.sharedvalue import SharedVariable, shared from theano.compile.sharedvalue import SharedVariable, shared
from theano import config from theano.gof import Container, Variable, generic, graph, Constant, Value
from theano.gof.python25 import any
import logging import logging
_logger=logging.getLogger("theano.compile.pfunc") _logger = logging.getLogger("theano.compile.pfunc")
def rebuild_collect_shared( outputs
, inputs = None def rebuild_collect_shared(outputs,
, replace = None inputs=None,
, updates = None replace=None,
, rebuild_strict = True updates=None,
, copy_inputs_over = True rebuild_strict=True,
, no_default_updates = False copy_inputs_over=True,
no_default_updates=False,
): ):
""" """
Function that allows replacing subgraphs of a computational Function that allows replacing subgraphs of a computational
...@@ -60,7 +63,7 @@ def rebuild_collect_shared( outputs ...@@ -60,7 +63,7 @@ def rebuild_collect_shared( outputs
""" """
if isinstance(outputs,tuple): if isinstance(outputs, tuple):
outputs = list(outputs) outputs = list(outputs)
## This function implements similar functionality as graph.clone ## This function implements similar functionality as graph.clone
...@@ -71,7 +74,6 @@ def rebuild_collect_shared( outputs ...@@ -71,7 +74,6 @@ def rebuild_collect_shared( outputs
# list of shared inputs that are used as inputs of the graph # list of shared inputs that are used as inputs of the graph
shared_inputs = [] shared_inputs = []
def clone_v_get_shared_updates(v, copy_inputs_over): def clone_v_get_shared_updates(v, copy_inputs_over):
''' '''
Clones a variable and its inputs recursively until all are in Clones a variable and its inputs recursively until all are in
...@@ -88,36 +90,34 @@ def rebuild_collect_shared( outputs ...@@ -88,36 +90,34 @@ def rebuild_collect_shared( outputs
return clone_d[v] return clone_d[v]
if v.owner: if v.owner:
clone_a(v.owner, copy_inputs_over) clone_a(v.owner, copy_inputs_over)
return clone_d.setdefault(v,v) return clone_d.setdefault(v, v)
elif isinstance(v, SharedVariable): elif isinstance(v, SharedVariable):
if v not in shared_inputs: if v not in shared_inputs:
shared_inputs.append(v) shared_inputs.append(v)
if hasattr(v, 'default_update'): if hasattr(v, 'default_update'):
# Check that v should not be excluded from the default # Check that v should not be excluded from the default
# updates list # updates list
if ( no_default_updates is False or if (no_default_updates is False or
( isinstance(no_default_updates, list) and (isinstance(no_default_updates, list) and
v not in no_default_updates v not in no_default_updates)):
)
):
# Do not use default_update if a "real" update was # Do not use default_update if a "real" update was
# provided # provided
if v not in update_d: if v not in update_d:
v_update = v.type.filter_variable(v.default_update) v_update = v.type.filter_variable(v.default_update)
if v_update.type != v.type: if v_update.type != v.type:
raise TypeError( raise TypeError(
( 'an update must have the same type as ' 'an update must have the same type as '
'the original shared variable' ) 'the original shared variable',
, (v, v.type, v_update, v_update.type)) (v, v.type, v_update, v_update.type))
update_d[v] = v_update update_d[v] = v_update
update_expr.append((v, v_update)) update_expr.append((v, v_update))
if not copy_inputs_over or (isinstance(v, Constant) and if not copy_inputs_over or (isinstance(v, Constant) and
hasattr(v,'env')): hasattr(v, 'env')):
### Cloning shared variables implies copying their underlying ### Cloning shared variables implies copying their underlying
### memory buffer ?? No. ### memory buffer ?? No.
return clone_d.setdefault(v,v.clone()) return clone_d.setdefault(v, v.clone())
else: else:
return clone_d.setdefault(v,v) return clone_d.setdefault(v, v)
def clone_a(a, copy_inputs_over): def clone_a(a, copy_inputs_over):
''' '''
...@@ -132,12 +132,11 @@ def rebuild_collect_shared( outputs ...@@ -132,12 +132,11 @@ def rebuild_collect_shared( outputs
clone_d[a] = a.clone_with_new_inputs([clone_d[i] for i in clone_d[a] = a.clone_with_new_inputs([clone_d[i] for i in
a.inputs], a.inputs],
strict = rebuild_strict) strict=rebuild_strict)
for old_o, new_o in zip(a.outputs, clone_d[a].outputs): for old_o, new_o in zip(a.outputs, clone_d[a].outputs):
clone_d.setdefault(old_o,new_o) clone_d.setdefault(old_o, new_o)
return clone_d[a] return clone_d[a]
# intialize the clone_d mapping with the replace dictionary # intialize the clone_d mapping with the replace dictionary
if replace is None: if replace is None:
replace = [] replace = []
...@@ -147,9 +146,9 @@ def rebuild_collect_shared( outputs ...@@ -147,9 +146,9 @@ def rebuild_collect_shared( outputs
replace_pairs = replace replace_pairs = replace
for v_orig, v_repl in replace_pairs: for v_orig, v_repl in replace_pairs:
if not isinstance(v_orig,Variable): if not isinstance(v_orig, Variable):
raise TypeError('given keys must be Variable', v_orig) raise TypeError('given keys must be Variable', v_orig)
if not isinstance(v_repl,Variable): if not isinstance(v_repl, Variable):
v_repl = shared(v_repl) v_repl = shared(v_repl)
assert v_orig not in clone_d assert v_orig not in clone_d
clone_d[v_orig] = clone_v_get_shared_updates(v_repl, clone_d[v_orig] = clone_v_get_shared_updates(v_repl,
...@@ -160,9 +159,9 @@ def rebuild_collect_shared( outputs ...@@ -160,9 +159,9 @@ def rebuild_collect_shared( outputs
def clone_inputs(i): def clone_inputs(i):
if not copy_inputs_over: if not copy_inputs_over:
return clone_d.setdefault(i,i.clone()) return clone_d.setdefault(i, i.clone())
else: else:
return clone_d.setdefault(i,i) return clone_d.setdefault(i, i)
input_variables = [clone_inputs(i) for i in inputs] input_variables = [clone_inputs(i) for i in inputs]
...@@ -171,7 +170,7 @@ def rebuild_collect_shared( outputs ...@@ -171,7 +170,7 @@ def rebuild_collect_shared( outputs
# it is also not clear when/how to use the value of that shared # it is also not clear when/how to use the value of that shared
# variable (is it a default? ignored?, if the shared variable changes, # variable (is it a default? ignored?, if the shared variable changes,
# does that function default also change?). # does that function default also change?).
if numpy.any([isinstance(v, SharedVariable) for v in input_variables]): if any([isinstance(v, SharedVariable) for v in input_variables]):
raise TypeError(('Cannot use a shared variable (%s) as explicit ' raise TypeError(('Cannot use a shared variable (%s) as explicit '
'input. Consider substituting a non-shared' 'input. Consider substituting a non-shared'
' variable via the `givens` parameter') % v) ' variable via the `givens` parameter') % v)
...@@ -181,25 +180,25 @@ def rebuild_collect_shared( outputs ...@@ -181,25 +180,25 @@ def rebuild_collect_shared( outputs
updates = [] updates = []
for (store_into, update_val) in iter_over_pairs(updates): for (store_into, update_val) in iter_over_pairs(updates):
if not isinstance(store_into, SharedVariable): if not isinstance(store_into, SharedVariable):
raise TypeError('update target must be a SharedVariable' raise TypeError('update target must be a SharedVariable',
, store_into) store_into)
if store_into in update_d: if store_into in update_d:
raise ValueError(('this shared variable already has an update ' raise ValueError('this shared variable already has an update '
'expression'), 'expression',
(store_into, update_d[store_into])) (store_into, update_d[store_into]))
# filter_variable ensure smooth conversion of cpu/gpu Types # filter_variable ensure smooth conversion of cpu/gpu Types
update_val = store_into.type.filter_variable(update_val) update_val = store_into.type.filter_variable(update_val)
if update_val.type != store_into.type: if update_val.type != store_into.type:
err_msg = ( 'an update must have the same type as the ' err_msg = ('an update must have the same type as the '
'original shared variable(dest, dest.type, ' 'original shared variable(dest, dest.type, '
'update_val, update_val.type)') 'update_val, update_val.type)')
err_arg = ( store_into err_arg = (store_into,
, store_into.type store_into.type,
, update_val update_val,
, update_val.type) update_val.type)
raise TypeError(err_msg, err_arg ) raise TypeError(err_msg, err_arg)
update_d[store_into] = update_val update_d[store_into] = update_val
update_expr.append((store_into, update_val)) update_expr.append((store_into, update_val))
...@@ -215,8 +214,8 @@ def rebuild_collect_shared( outputs ...@@ -215,8 +214,8 @@ def rebuild_collect_shared( outputs
copy_inputs_over) copy_inputs_over)
cloned_outputs.append(Out(cloned_v, borrow=v.borrow)) cloned_outputs.append(Out(cloned_v, borrow=v.borrow))
else: else:
raise TypeError( ( 'outputs must be theano Variable or ' raise TypeError('outputs must be theano Variable or '
'Out instances'), v) 'Out instances', v)
#computed_list.append(cloned_v) #computed_list.append(cloned_v)
else: else:
if isinstance(outputs, Variable): if isinstance(outputs, Variable):
...@@ -229,12 +228,11 @@ def rebuild_collect_shared( outputs ...@@ -229,12 +228,11 @@ def rebuild_collect_shared( outputs
cloned_outputs = Out(cloned_v, borrow=outputs.borrow) cloned_outputs = Out(cloned_v, borrow=outputs.borrow)
#computed_list.append(cloned_v) #computed_list.append(cloned_v)
elif outputs is None: elif outputs is None:
cloned_outputs = [] # TODO: get Function.__call__ to return None cloned_outputs = [] # TODO: get Function.__call__ to return None
else: else:
raise TypeError( ('output must be a theano Variable or Out ' raise TypeError('output must be a theano Variable or Out '
'instance (or list of them)') 'instance (or list of them)',
, outputs) outputs)
# Iterate over update_expr, cloning its elements, and updating # Iterate over update_expr, cloning its elements, and updating
# shared_inputs, update_d and update_expr from the SharedVariables # shared_inputs, update_d and update_expr from the SharedVariables
...@@ -244,7 +242,7 @@ def rebuild_collect_shared( outputs ...@@ -244,7 +242,7 @@ def rebuild_collect_shared( outputs
# Note: we extend update_expr while iterating over it. # Note: we extend update_expr while iterating over it.
i = 0 i = 0
while i<len(update_expr): while i < len(update_expr):
v, v_update = update_expr[i] v, v_update = update_expr[i]
cloned_v_update = clone_v_get_shared_updates(v_update, cloned_v_update = clone_v_get_shared_updates(v_update,
copy_inputs_over) copy_inputs_over)
...@@ -253,12 +251,13 @@ def rebuild_collect_shared( outputs ...@@ -253,12 +251,13 @@ def rebuild_collect_shared( outputs
shared_inputs.append(v) shared_inputs.append(v)
i += 1 i += 1
return ( input_variables, cloned_outputs return (input_variables, cloned_outputs,
, [clone_d, update_d, update_expr, shared_inputs] ) [clone_d, update_d, update_expr, shared_inputs])
class Param(object): class Param(object):
def __init__(self, variable, default=None, name=None, mutable=False, def __init__(self, variable, default=None, name=None, mutable=False,
strict=False, allow_downcast=None, implicit=None, borrow = None): strict=False, allow_downcast=None, implicit=None, borrow=None):
""" """
:param variable: A variable in an expression graph to use as a compiled-function parameter :param variable: A variable in an expression graph to use as a compiled-function parameter
...@@ -295,7 +294,7 @@ class Param(object): ...@@ -295,7 +294,7 @@ class Param(object):
# mutable=True should require borrow=True. Raise warning when borrow is explicitely set # mutable=True should require borrow=True. Raise warning when borrow is explicitely set
# to False with mutable=True. # to False with mutable=True.
if mutable: if mutable:
if borrow==False: if not borrow:
_logger.warning("Symbolic input for variable %s (name=%s) has " _logger.warning("Symbolic input for variable %s (name=%s) has "
"flags mutable=True, borrow=False. This combination is " "flags mutable=True, borrow=False. This combination is "
"incompatible since mutable=True implies that the " "incompatible since mutable=True implies that the "
...@@ -308,6 +307,7 @@ class Param(object): ...@@ -308,6 +307,7 @@ class Param(object):
self.implicit = implicit self.implicit = implicit
self.borrow = borrow self.borrow = borrow
def pfunc(params, outputs=None, mode=None, updates=[], givens=[], def pfunc(params, outputs=None, mode=None, updates=[], givens=[],
no_default_updates=False, accept_inplace=False, name=None, no_default_updates=False, accept_inplace=False, name=None,
rebuild_strict=True, allow_input_downcast=None, rebuild_strict=True, allow_input_downcast=None,
...@@ -398,27 +398,25 @@ def pfunc(params, outputs=None, mode=None, updates=[], givens=[], ...@@ -398,27 +398,25 @@ def pfunc(params, outputs=None, mode=None, updates=[], givens=[],
# No need to block other objects being passed through though. It might be # No need to block other objects being passed through though. It might be
# useful. # useful.
if not isinstance(params,(list,tuple)): if not isinstance(params, (list, tuple)):
raise Exception("in pfunc() the first argument must be a list or a tuple") raise Exception("in pfunc() the first argument must be a list or a tuple")
if not isinstance(no_default_updates, bool)\ if not isinstance(no_default_updates, bool)\
and not isinstance(no_default_updates, list): and not isinstance(no_default_updates, list):
raise TypeError("no_default_update should be either a boolean or a list") raise TypeError("no_default_update should be either a boolean or a list")
# transform params into theano.compile.In objects. # transform params into theano.compile.In objects.
inputs = [_pfunc_param_to_in(p, allow_downcast=allow_input_downcast) inputs = [_pfunc_param_to_in(p, allow_downcast=allow_input_downcast)
for p in params] for p in params]
in_variables = [ input.variable for input in inputs ] in_variables = [input.variable for input in inputs]
output_vars = rebuild_collect_shared( output_vars = rebuild_collect_shared(outputs,
outputs in_variables,
, in_variables replace=givens,
, replace = givens updates=updates,
, updates = updates rebuild_strict=True,
, rebuild_strict = True copy_inputs_over=True,
, copy_inputs_over = True no_default_updates=no_default_updates)
, no_default_updates = no_default_updates )
# extracting the arguments # extracting the arguments
input_variables, cloned_outputs, other_stuff = output_vars input_variables, cloned_outputs, other_stuff = output_vars
clone_d, update_d, update_expr, shared_inputs = other_stuff clone_d, update_d, update_expr, shared_inputs = other_stuff
...@@ -431,14 +429,13 @@ def pfunc(params, outputs=None, mode=None, updates=[], givens=[], ...@@ -431,14 +429,13 @@ def pfunc(params, outputs=None, mode=None, updates=[], givens=[],
#value will be stored in the resulting functions' defaults list #value will be stored in the resulting functions' defaults list
#but since the value of shared variables never needs to be refed, it is not needed #but since the value of shared variables never needs to be refed, it is not needed
if sv in update_d: if sv in update_d:
si = In(variable=sv, value = sv.container, mutable=True, si = In(variable=sv, value=sv.container, mutable=True,
borrow=True, update=update_d[sv], shared = True) borrow=True, update=update_d[sv], shared=True)
else: else:
si = In(variable=sv, value = sv.container, si = In(variable=sv, value=sv.container,
mutable=False, borrow=True, shared = True) mutable=False, borrow=True, shared=True)
inputs.append(si) inputs.append(si)
return orig_function(inputs, cloned_outputs, mode, return orig_function(inputs, cloned_outputs, mode,
accept_inplace=accept_inplace, name=name, profile=profile) accept_inplace=accept_inplace, name=name, profile=profile)
...@@ -449,7 +446,7 @@ def _pfunc_param_to_in(param, strict=False, allow_downcast=None): ...@@ -449,7 +446,7 @@ def _pfunc_param_to_in(param, strict=False, allow_downcast=None):
#if isinstance(param, Value): #if isinstance(param, Value):
#return In(variable=param) #return In(variable=param)
#raise NotImplementedError() #raise NotImplementedError()
if isinstance(param, Variable): #N.B. includes Value and SharedVariable if isinstance(param, Variable): # N.B. includes Value and SharedVariable
return In(variable=param, strict=strict, allow_downcast=allow_downcast) return In(variable=param, strict=strict, allow_downcast=allow_downcast)
elif isinstance(param, Param): elif isinstance(param, Param):
return In( return In(
...@@ -458,9 +455,9 @@ def _pfunc_param_to_in(param, strict=False, allow_downcast=None): ...@@ -458,9 +455,9 @@ def _pfunc_param_to_in(param, strict=False, allow_downcast=None):
value=param.default, value=param.default,
mutable=param.mutable, mutable=param.mutable,
strict=param.strict, strict=param.strict,
borrow = param.borrow, borrow=param.borrow,
allow_downcast=param.allow_downcast, allow_downcast=param.allow_downcast,
implicit = param.implicit) implicit=param.implicit)
raise TypeError('Unknown parameter type: %s' % type(param)) raise TypeError('Unknown parameter type: %s' % type(param))
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论