提交 905a639c authored 作者: Frederic's avatar Frederic

some pep8

上级 5c173a6b
......@@ -351,12 +351,12 @@ class Function(object):
It maps container -> SymbolicInput
"""
def __init__(self, fn, input_storage, output_storage, indices, outputs, defaults, unpack_single, return_none, maker):
def __init__(self, fn, input_storage, output_storage, indices, outputs,
defaults, unpack_single, return_none, maker):
"""
Initialize attributes. create finder, inv_finder.
"""
self.fn = fn
self.input_storage = input_storage
self.output_storage = output_storage
......@@ -366,7 +366,7 @@ class Function(object):
self.unpack_single = unpack_single
self.return_none = return_none
self.maker = maker
self.profile = None # reassigned in FunctionMaker.create
self.profile = None # reassigned in FunctionMaker.create
# We will be popping stuff off this `containers` object. It is a copy.
containers = list(self.input_storage)
......@@ -487,7 +487,8 @@ class Function(object):
except KeyError:
# Print informative error message.
msg = get_info_on_inputs(named_inputs, n_unnamed_inputs)
raise TypeError("Unknown input or state: %s. %s" % (str(item), msg))
raise TypeError("Unknown input or state: %s. %s" %
(str(item), msg))
if s is DUPLICATE:
raise TypeError("Ambiguous name: %s - please check the names "\
"of the inputs of your function for duplicates." % str(item))
......@@ -531,11 +532,12 @@ class Function(object):
def __setitem__(self, item, value):
self.value[item] = value
def __copy__(self):
defaults = [default for _1, _2, default in self.defaults]
cpy = self.maker.create(defaults, trustme = True)
for (input,_1,_2), here, there in zip(self.indices, self.input_storage, cpy.input_storage):
cpy = self.maker.create(defaults, trustme=True)
for (input, _1, _2), here, there in zip(self.indices,
self.input_storage,
cpy.input_storage):
if input.mutable and here is not None:
there.data = copy.copy(here.data)
else:
......@@ -550,7 +552,7 @@ class Function(object):
for c in self.input_storage:
c.provided = 0
if len(args)+len(kwargs)>len(self.input_storage):
if len(args) + len(kwargs) > len(self.input_storage):
raise TypeError("Too many parameter passed to theano function")
# Set positional arguments
......@@ -569,18 +571,18 @@ class Function(object):
allow_downcast=s.allow_downcast)
except Exception, e:
function_name="theano function"
function_name = "theano function"
if self.name:
function_name += 'with name "'+self.name+'" '
function_name += 'with name "' + self.name + '" '
#end if
e.args = tuple(["Bad input argument to " + function_name +
" at index %d(0-based)" % i] + list(e.args))
" at index %d(0-based)" % i] +
list(e.args))
raise
#end except
#end if
s.provided += 1
i+=1
i += 1
# Set keyword arguments
if kwargs: # for speed, skip the iteritems for empty kwargs
......@@ -594,7 +596,7 @@ class Function(object):
for i in xrange(len(self.input_storage)):
i_var = self.maker.inputs[i].variable
i_val = self.input_storage[i].storage[0]
if hasattr( i_var.type, 'may_share_memory'):
if hasattr(i_var.type, 'may_share_memory'):
is_aliased = False
for j in xrange(len(args_share_memory)):
......@@ -603,9 +605,9 @@ class Function(object):
in args_share_memory[j]],
[self.input_storage[k].storage[0] for k
in args_share_memory[j]])
if numpy.any([ (var.type is i_var.type and
var.type.may_share_memory(val,i_val)
) for (var,val) in group_j]):
if numpy.any([(var.type is i_var.type and
var.type.may_share_memory(val,i_val))
for (var,val) in group_j]):
is_aliased = True
args_share_memory[j].append(i)
......@@ -619,23 +621,24 @@ class Function(object):
if len(group) > 1:
# see if any of these arguments are mutable
mutable = numpy.any([(self.maker.inputs[idx].mutable or
self.maker.inputs[idx].borrow )
for idx in group ])
self.maker.inputs[idx].borrow)
for idx in group])
# copy all but the first
for idx in group[1:]:
self.input_storage[i].storage[0] = copy.copy(
self.input_storage[i].storage[0])
# Check if inputs are missing, or if inputs were set more than once, or
# if we tried to provide inputs that are supposed to be implicit.
for c in self.input_storage:
if c.required and not c.provided:
raise TypeError("Missing required input: %s" % getattr(self.inv_finder[c], 'variable', self.inv_finder[c]))
raise TypeError("Missing required input: %s" %
getattr(self.inv_finder[c], 'variable',
self.inv_finder[c]))
if c.provided > 1:
raise TypeError("Multiple values for input: %s" % getattr(self.inv_finder[c], 'variable', self.inv_finder[c]))
raise TypeError("Multiple values for input: %s" %
getattr(self.inv_finder[c], 'variable',
self.inv_finder[c]))
if c.implicit and c.provided > 0:
raise TypeError('Tried to provide value for implicit input: %s'
% getattr(self.inv_finder[c], 'variable',
......@@ -671,11 +674,12 @@ class Function(object):
if c.required:
c.storage[0] = None
# if we are allowing garbage collection, remove the input and output reference from the internal
# storage cells
# if we are allowing garbage collection, remove the input and
# output reference from the internal storage cells
if getattr(self.fn, 'allow_gc', False):
assert len(self.output_storage) == len(self.maker.fgraph.outputs)
for o_container, o_variable in zip(self.output_storage, self.maker.fgraph.outputs):
for o_container, o_variable in zip(self.output_storage,
self.maker.fgraph.outputs):
if o_variable.owner is not None:
# this node is the variable of computation
# WARNING: This circumvents the 'readonly' attribute in x
......@@ -683,7 +687,8 @@ class Function(object):
if getattr(self.fn, 'need_update_inputs', True):
# Update the inputs that have an update function
for input, storage in reversed(zip(self.maker.expanded_inputs, self.input_storage)):
for input, storage in reversed(zip(self.maker.expanded_inputs,
self.input_storage)):
if input.update is not None:
storage.data = outputs.pop()
else:
......@@ -718,15 +723,16 @@ class Function(object):
value = property(
lambda self: self._value,
None, # this property itself is not settable
doc="""dictionary-like access to the values associated with Variables""")
None, # this property itself is not settable
doc="dictionary-like access to the values associated with Variables")
container = property(
lambda self: self._container,
None, # this property itself is not settable
None, # this property itself is not settable
doc="""dictionary-like access to the containers associated with Variables""")
# pickling/deepcopy support for Function
def _pickle_Function(f):
#copy of the input storage list
ins = list(f.input_storage)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论