提交 905a639c authored 作者: Frederic's avatar Frederic

some pep8

上级 5c173a6b
...@@ -351,12 +351,12 @@ class Function(object): ...@@ -351,12 +351,12 @@ class Function(object):
It maps container -> SymbolicInput It maps container -> SymbolicInput
""" """
def __init__(self, fn, input_storage, output_storage, indices, outputs, defaults, unpack_single, return_none, maker): def __init__(self, fn, input_storage, output_storage, indices, outputs,
defaults, unpack_single, return_none, maker):
""" """
Initialize attributes. create finder, inv_finder. Initialize attributes. create finder, inv_finder.
""" """
self.fn = fn self.fn = fn
self.input_storage = input_storage self.input_storage = input_storage
self.output_storage = output_storage self.output_storage = output_storage
...@@ -487,7 +487,8 @@ class Function(object): ...@@ -487,7 +487,8 @@ class Function(object):
except KeyError: except KeyError:
# Print informative error message. # Print informative error message.
msg = get_info_on_inputs(named_inputs, n_unnamed_inputs) msg = get_info_on_inputs(named_inputs, n_unnamed_inputs)
raise TypeError("Unknown input or state: %s. %s" % (str(item), msg)) raise TypeError("Unknown input or state: %s. %s" %
(str(item), msg))
if s is DUPLICATE: if s is DUPLICATE:
raise TypeError("Ambiguous name: %s - please check the names "\ raise TypeError("Ambiguous name: %s - please check the names "\
"of the inputs of your function for duplicates." % str(item)) "of the inputs of your function for duplicates." % str(item))
...@@ -531,11 +532,12 @@ class Function(object): ...@@ -531,11 +532,12 @@ class Function(object):
def __setitem__(self, item, value): def __setitem__(self, item, value):
self.value[item] = value self.value[item] = value
def __copy__(self): def __copy__(self):
defaults = [default for _1, _2, default in self.defaults] defaults = [default for _1, _2, default in self.defaults]
cpy = self.maker.create(defaults, trustme = True) cpy = self.maker.create(defaults, trustme=True)
for (input,_1,_2), here, there in zip(self.indices, self.input_storage, cpy.input_storage): for (input, _1, _2), here, there in zip(self.indices,
self.input_storage,
cpy.input_storage):
if input.mutable and here is not None: if input.mutable and here is not None:
there.data = copy.copy(here.data) there.data = copy.copy(here.data)
else: else:
...@@ -550,7 +552,7 @@ class Function(object): ...@@ -550,7 +552,7 @@ class Function(object):
for c in self.input_storage: for c in self.input_storage:
c.provided = 0 c.provided = 0
if len(args)+len(kwargs)>len(self.input_storage): if len(args) + len(kwargs) > len(self.input_storage):
raise TypeError("Too many parameter passed to theano function") raise TypeError("Too many parameter passed to theano function")
# Set positional arguments # Set positional arguments
...@@ -569,18 +571,18 @@ class Function(object): ...@@ -569,18 +571,18 @@ class Function(object):
allow_downcast=s.allow_downcast) allow_downcast=s.allow_downcast)
except Exception, e: except Exception, e:
function_name="theano function" function_name = "theano function"
if self.name: if self.name:
function_name += 'with name "'+self.name+'" ' function_name += 'with name "' + self.name + '" '
#end if #end if
e.args = tuple(["Bad input argument to " + function_name + e.args = tuple(["Bad input argument to " + function_name +
" at index %d(0-based)" % i] + list(e.args)) " at index %d(0-based)" % i] +
list(e.args))
raise raise
#end except #end except
#end if #end if
s.provided += 1 s.provided += 1
i+=1 i += 1
# Set keyword arguments # Set keyword arguments
if kwargs: # for speed, skip the iteritems for empty kwargs if kwargs: # for speed, skip the iteritems for empty kwargs
...@@ -594,7 +596,7 @@ class Function(object): ...@@ -594,7 +596,7 @@ class Function(object):
for i in xrange(len(self.input_storage)): for i in xrange(len(self.input_storage)):
i_var = self.maker.inputs[i].variable i_var = self.maker.inputs[i].variable
i_val = self.input_storage[i].storage[0] i_val = self.input_storage[i].storage[0]
if hasattr( i_var.type, 'may_share_memory'): if hasattr(i_var.type, 'may_share_memory'):
is_aliased = False is_aliased = False
for j in xrange(len(args_share_memory)): for j in xrange(len(args_share_memory)):
...@@ -603,9 +605,9 @@ class Function(object): ...@@ -603,9 +605,9 @@ class Function(object):
in args_share_memory[j]], in args_share_memory[j]],
[self.input_storage[k].storage[0] for k [self.input_storage[k].storage[0] for k
in args_share_memory[j]]) in args_share_memory[j]])
if numpy.any([ (var.type is i_var.type and if numpy.any([(var.type is i_var.type and
var.type.may_share_memory(val,i_val) var.type.may_share_memory(val,i_val))
) for (var,val) in group_j]): for (var,val) in group_j]):
is_aliased = True is_aliased = True
args_share_memory[j].append(i) args_share_memory[j].append(i)
...@@ -619,23 +621,24 @@ class Function(object): ...@@ -619,23 +621,24 @@ class Function(object):
if len(group) > 1: if len(group) > 1:
# see if any of these arguments are mutable # see if any of these arguments are mutable
mutable = numpy.any([(self.maker.inputs[idx].mutable or mutable = numpy.any([(self.maker.inputs[idx].mutable or
self.maker.inputs[idx].borrow ) self.maker.inputs[idx].borrow)
for idx in group ]) for idx in group])
# copy all but the first # copy all but the first
for idx in group[1:]: for idx in group[1:]:
self.input_storage[i].storage[0] = copy.copy( self.input_storage[i].storage[0] = copy.copy(
self.input_storage[i].storage[0]) self.input_storage[i].storage[0])
# Check if inputs are missing, or if inputs were set more than once, or # Check if inputs are missing, or if inputs were set more than once, or
# if we tried to provide inputs that are supposed to be implicit. # if we tried to provide inputs that are supposed to be implicit.
for c in self.input_storage: for c in self.input_storage:
if c.required and not c.provided: if c.required and not c.provided:
raise TypeError("Missing required input: %s" % getattr(self.inv_finder[c], 'variable', self.inv_finder[c])) raise TypeError("Missing required input: %s" %
getattr(self.inv_finder[c], 'variable',
self.inv_finder[c]))
if c.provided > 1: if c.provided > 1:
raise TypeError("Multiple values for input: %s" % getattr(self.inv_finder[c], 'variable', self.inv_finder[c])) raise TypeError("Multiple values for input: %s" %
getattr(self.inv_finder[c], 'variable',
self.inv_finder[c]))
if c.implicit and c.provided > 0: if c.implicit and c.provided > 0:
raise TypeError('Tried to provide value for implicit input: %s' raise TypeError('Tried to provide value for implicit input: %s'
% getattr(self.inv_finder[c], 'variable', % getattr(self.inv_finder[c], 'variable',
...@@ -671,11 +674,12 @@ class Function(object): ...@@ -671,11 +674,12 @@ class Function(object):
if c.required: if c.required:
c.storage[0] = None c.storage[0] = None
# if we are allowing garbage collection, remove the input and output reference from the internal # if we are allowing garbage collection, remove the input and
# storage cells # output reference from the internal storage cells
if getattr(self.fn, 'allow_gc', False): if getattr(self.fn, 'allow_gc', False):
assert len(self.output_storage) == len(self.maker.fgraph.outputs) assert len(self.output_storage) == len(self.maker.fgraph.outputs)
for o_container, o_variable in zip(self.output_storage, self.maker.fgraph.outputs): for o_container, o_variable in zip(self.output_storage,
self.maker.fgraph.outputs):
if o_variable.owner is not None: if o_variable.owner is not None:
# this node is the variable of computation # this node is the variable of computation
# WARNING: This circumvents the 'readonly' attribute in x # WARNING: This circumvents the 'readonly' attribute in x
...@@ -683,7 +687,8 @@ class Function(object): ...@@ -683,7 +687,8 @@ class Function(object):
if getattr(self.fn, 'need_update_inputs', True): if getattr(self.fn, 'need_update_inputs', True):
# Update the inputs that have an update function # Update the inputs that have an update function
for input, storage in reversed(zip(self.maker.expanded_inputs, self.input_storage)): for input, storage in reversed(zip(self.maker.expanded_inputs,
self.input_storage)):
if input.update is not None: if input.update is not None:
storage.data = outputs.pop() storage.data = outputs.pop()
else: else:
...@@ -719,7 +724,7 @@ class Function(object): ...@@ -719,7 +724,7 @@ class Function(object):
value = property( value = property(
lambda self: self._value, lambda self: self._value,
None, # this property itself is not settable None, # this property itself is not settable
doc="""dictionary-like access to the values associated with Variables""") doc="dictionary-like access to the values associated with Variables")
container = property( container = property(
lambda self: self._container, lambda self: self._container,
None, # this property itself is not settable None, # this property itself is not settable
...@@ -727,6 +732,7 @@ class Function(object): ...@@ -727,6 +732,7 @@ class Function(object):
# pickling/deepcopy support for Function # pickling/deepcopy support for Function
def _pickle_Function(f): def _pickle_Function(f):
#copy of the input storage list #copy of the input storage list
ins = list(f.input_storage) ins = list(f.input_storage)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论