提交 002ffea2 authored 作者: james@X40's avatar james@X40

corrections to the pickling of Function after ODs shared state changes

上级 f9c10791
...@@ -229,6 +229,7 @@ class Function(object): ...@@ -229,6 +229,7 @@ class Function(object):
#setters = [] #setters = []
# Initialize the storage # Initialize the storage
# this loop works by modifying the elements (as variable c) of self.input_storage inplace.
for i, ((input, indices, sinputs), (required, refeed, value)) in enumerate(zip(self.indices, defaults)): for i, ((input, indices, sinputs), (required, refeed, value)) in enumerate(zip(self.indices, defaults)):
if indices is None: # this is true iff input is not a SymbolicInputKit if indices is None: # this is true iff input is not a SymbolicInputKit
c = containers[0] #containers is being used as a stack. Here we pop off the next one. c = containers[0] #containers is being used as a stack. Here we pop off the next one.
...@@ -468,26 +469,26 @@ class Function(object): ...@@ -468,26 +469,26 @@ class Function(object):
def _pickle_Function(f): def _pickle_Function(f):
#copy of the input storage list #copy of the input storage list
ins = list(f.input_storage) ins = list(f.input_storage)
defaults = [] input_storage = []
for (input, indices, inputs), (required, refeed, default) in zip(f.indices, f.defaults): for (input, indices, inputs), (required, refeed, default) in zip(f.indices, f.defaults):
if isinstance(input, SymbolicInputKit): if isinstance(input, SymbolicInputKit):
li = len(indices) li = len(indices)
if not default: if not default:
defaults.append(ins[:li]) input_storage.append(ins[:li])
else: else:
defaults.append(default) input_storage.append(default)
ins[:li] = [] ins[:li] = []
else: else:
defaults.append(ins[0]) input_storage.append(ins[0])
del ins[0] del ins[0]
inputs_data = [x.data for x in f.input_storage] inputs_data = [x.data for x in f.input_storage]
# HACK to detect aliased storage. # HACK to detect aliased storage.
# aliased relationships will not be preserved across the pickle operation # This is here because aliased relationships are not [currently] preserved across the pickle operation
if not (f.pickle_aliased_memory_strategy == 'ignore'): if not (f.pickle_aliased_memory_strategy == 'ignore'):
all_data = defaults + inputs_data all_data = input_storage + inputs_data # addition here means list append
for i, d_i in enumerate(all_data): for i, d_i in enumerate(all_data):
for j, d_j in enumerate(all_data): for j, d_j in enumerate(all_data):
if (i < j) and isinstance(d_i, numpy.ndarray) and isinstance(d_j, numpy.ndarray): if (i < j) and isinstance(d_i, numpy.ndarray) and isinstance(d_j, numpy.ndarray):
...@@ -500,14 +501,14 @@ def _pickle_Function(f): ...@@ -500,14 +501,14 @@ def _pickle_Function(f):
else: else:
raise AliasedMemoryError(d_i, d_j) raise AliasedMemoryError(d_i, d_j)
rval = (_constructor_Function, (f.maker, defaults, inputs_data)) rval = (_constructor_Function, (f.maker, input_storage, inputs_data))
return rval return rval
def _constructor_Function(maker, defaults, data): def _constructor_Function(maker, input_storage, inputs_data):
f = maker.create(defaults, trustme = True) f = maker.create(input_storage, trustme = True)
assert len(f.input_storage) == len(data) assert len(f.input_storage) == len(inputs_data)
for container, x in zip(f.input_storage, data): for container, x in zip(f.input_storage, inputs_data):
container.data = x assert (container.data is x) or (container.data == x)
return f return f
copy_reg.pickle(Function, _pickle_Function) copy_reg.pickle(Function, _pickle_Function)
...@@ -690,7 +691,12 @@ class FunctionMaker(object): ...@@ -690,7 +691,12 @@ class FunctionMaker(object):
self.accept_inplace = accept_inplace self.accept_inplace = accept_inplace
self.function_builder = function_builder self.function_builder = function_builder
def create(self, defaults = None, trustme = False): self.required = [(i.value == None) for i in self.inputs]
self.refeed = [
(i.value != None and not isinstance(i.value, gof.Container) and i.update == None)
for i in self.inputs]
def create(self, input_storage = None, trustme = False):
""" """
Create a function. Create a function.
...@@ -700,101 +706,38 @@ class FunctionMaker(object): ...@@ -700,101 +706,38 @@ class FunctionMaker(object):
acts as initialization. acts as initialization.
trustme -> disables some exceptions, used internally trustme -> disables some exceptions, used internally
""" """
if defaults is None: if input_storage is None:
defaults = [None]*len(self.inputs) input_storage = [None]*len(self.inputs)
input_storage = [] # list of independent one-element lists, will be passed to the linker input_storage_lists = [] # list of independent one-element lists, will be passed to the linker
_defaults = [] defaults = []
# The following loop is to fill in the input_storage and _defaults lists. # The following loop is to fill in the input_storage_lists and defaults lists.
for (input, indices, subinputs), default in zip(self.indices, defaults): assert len(self.indices) == len(input_storage)
for i, ((input, indices, subinputs), input_storage_i) in enumerate(zip(self.indices, input_storage)):
# Replace any default value given as a variable by its container. # Replace any default value given as a variable by its container.
# Note that this makes sense only in the context of shared variables, # Note that this makes sense only in the context of shared variables,
# but for now we avoid dealing directly with them to avoid dependency # but for now we avoid dealing directly with them to avoid dependency
# on the shared variables work-in-progress repository. # on the shared variables work-in-progress repository.
if isinstance(default, gof.Variable): if isinstance(input_storage_i, gof.Variable):
default = default.container input_storage_i = input_storage_i.container
__default = default
if isinstance(default, gof.Container) and input.implicit: if isinstance(input_storage_i, gof.Container):
# If the default is a gof.Container and it is an implicit # If the default is a gof.Container and it is an implicit
# input, this means we want to share the same storage. This is # input, this means we want to share the same storage. This is
# done by appending default.storage to input_storage # done by appending default.storage to input_storage_lists
if indices is not None: if indices is not None:
raise TypeError("Cannot take a Container instance as default for a SymbolicInputKit.") raise TypeError("Cannot take a Container instance as default for a SymbolicInputKit.")
input_storage.append(default.storage) input_storage_lists.append(input_storage_i.storage)
default = None defaults.append((self.required[i],
required = False self.refeed[i],
elif isinstance(input, SymbolicInputKit): input_storage_i.storage[0]))
# If the input is a SymbolicInputKit, it represents more than
# one storage unit. The indices and subinputs lists represent which
# of the kit's inputs are active in this graph, so we make as many
# storage units as needed
if isinstance(default, (list, tuple)) \
and all(isinstance(x, gof.Container) for x in default) \
and input.implicit:
if len(default) == len(indices):
input_storage += [x.storage for x in default]
elif len(default) > len(indices):
input_storage += [default[i].storage for i in indices]
else:
raise ValueError('Not enough storage for SymbolicInputKit', input, indices, default)
default = NODEFAULT
else:
input_storage += [[None] for i in indices]
else: else:
# Normal case: one new, independent storage unit # Normal case: one new, independent storage unit
input_storage.append([None]) input_storage_lists.append([input_storage_i])
defaults.append((self.required[i], self.refeed[i], input_storage_i))
# Filling _defaults. Each entry is a tuple of three elements:
# (required, refeed, value)
# - required means that the user must provide a value when calling the function
# - refeed means that we want to put the default back in the storage after each function call
# - value is the value that will be put in the storage initially
# Even though a SymbolicInputKit represents more than one input,
# we still only have one entry for the defaults list.
if isinstance(input, SymbolicInputKit):
if default is NODEFAULT:
_defaults.append((False, False, None))
elif default is None:
_defaults.append((True, True, None))
else:
_defaults.append((False, False, default))
elif input.update is not None:
# If the input has an update, then (logically) it is not required since
# it is just a parameter and of course we don't want to refeed the default
# back into the storage as it would defeat the point of updating it. We
# always do this policy.
if default is None:
if (trustme or (isinstance(__default, gof.Container)
and input.implicit)):
_defaults.append((False, False, None))
else:
# This might catch some bugs early
raise ValueError("A default (initial) value is required for an input which can update itself.", input)
else:
_defaults.append((False, False, default))
else:
if default is None:
if (trustme or (isinstance(__default, gof.Container)
and input.implicit)):
_defaults.append((False, False, None))
else:
# No default, so this is a required input. Nothing to feed back, initial value is None.
_defaults.append((True, False, None))
else:
# Default value. It is not required, but we want to put it back into the storage
# everytime so it behaves like most programming languages' default values.
# Note (OD): why is it not required? If it was not put back
# into the storage, then the default value may be incorrect
# on subsequent calls. Thus, setting 'refeed' to True seems
# very important here.
_defaults.append((False, True, default))
defaults = _defaults
# Get a function instance # Get a function instance
_fn, _i, _o = self.linker.make_thunk(input_storage = input_storage) _fn, _i, _o = self.linker.make_thunk(input_storage = input_storage_lists)
fn = self.function_builder(_fn, _i, _o, self.indices, self.outputs, defaults, self.unpack_single, self.return_none, self) fn = self.function_builder(_fn, _i, _o, self.indices, self.outputs, defaults, self.unpack_single, self.return_none, self)
return fn return fn
......
...@@ -250,7 +250,9 @@ class T_function(unittest.TestCase): ...@@ -250,7 +250,9 @@ class T_function(unittest.TestCase):
self.failUnless(f[s] == 2) self.failUnless(f[s] == 2)
self.failUnless(g[s] == 2) self.failUnless(g[s] == 2)
f(1, 2) f(1, 2)
g(1, 2) self.failUnless(f[s] == 4)
self.failUnless(g[s] == 4)
g(1, 2) # has no effect on state
self.failUnless(f[s] == 4) self.failUnless(f[s] == 4)
self.failUnless(g[s] == 4) self.failUnless(g[s] == 4)
...@@ -301,7 +303,7 @@ class T_picklefunction(unittest.TestCase): ...@@ -301,7 +303,7 @@ class T_picklefunction(unittest.TestCase):
print 'f.defaults = %s' % (f.defaults, ) print 'f.defaults = %s' % (f.defaults, )
print 'g.defaults = %s' % (g.defaults, ) print 'g.defaults = %s' % (g.defaults, )
self.failUnless(all([f_req == g_req and f_feed == g_feed and self.failUnless(all([f_req == g_req and f_feed == g_feed and
type(f_val) == type(g_val) f_val == g_val
for ((f_req, f_feed, f_val), (g_req, g_feed, g_val)) in zip( for ((f_req, f_feed, f_val), (g_req, g_feed, g_val)) in zip(
f.defaults, g.defaults)])) f.defaults, g.defaults)]))
...@@ -314,7 +316,7 @@ class T_picklefunction(unittest.TestCase): ...@@ -314,7 +316,7 @@ class T_picklefunction(unittest.TestCase):
f(1,2) # put them out of sync f(1,2) # put them out of sync
self.failIf(f(1, 2) == g(1, 2)) #they should not be equal anymore. self.failIf(f(1, 2) == g(1, 2)) #they should not be equal anymore.
g(1, 2) # put them back in sync g(1, 2) # put them back in sync
self.failIf(f(3) == g(3)) # They should be in sync again. self.failUnless(f(3) == g(3)) # They should be in sync again.
def test_pickle(self): def test_pickle(self):
a = T.scalar() # the a is for 'anonymous' (un-named). a = T.scalar() # the a is for 'anonymous' (un-named).
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论