提交 002ffea2 authored 作者: james@X40's avatar james@X40

corrections to the pickling of Function after ODs shared state changes

上级 f9c10791
......@@ -229,6 +229,7 @@ class Function(object):
#setters = []
# Initialize the storage
# this loop works by modifying the elements (as variable c) of self.input_storage inplace.
for i, ((input, indices, sinputs), (required, refeed, value)) in enumerate(zip(self.indices, defaults)):
if indices is None: # this is true iff input is not a SymbolicInputKit
c = containers[0] #containers is being used as a stack. Here we pop off the next one.
......@@ -468,26 +469,26 @@ class Function(object):
def _pickle_Function(f):
#copy of the input storage list
ins = list(f.input_storage)
defaults = []
input_storage = []
for (input, indices, inputs), (required, refeed, default) in zip(f.indices, f.defaults):
if isinstance(input, SymbolicInputKit):
li = len(indices)
if not default:
defaults.append(ins[:li])
input_storage.append(ins[:li])
else:
defaults.append(default)
input_storage.append(default)
ins[:li] = []
else:
defaults.append(ins[0])
input_storage.append(ins[0])
del ins[0]
inputs_data = [x.data for x in f.input_storage]
# HACK to detect aliased storage.
# aliased relationships will not be preserved across the pickle operation
# This is here because aliased relationships are not [currently] preserved across the pickle operation
if not (f.pickle_aliased_memory_strategy == 'ignore'):
all_data = defaults + inputs_data
all_data = input_storage + inputs_data # addition here means list append
for i, d_i in enumerate(all_data):
for j, d_j in enumerate(all_data):
if (i < j) and isinstance(d_i, numpy.ndarray) and isinstance(d_j, numpy.ndarray):
......@@ -500,14 +501,14 @@ def _pickle_Function(f):
else:
raise AliasedMemoryError(d_i, d_j)
rval = (_constructor_Function, (f.maker, defaults, inputs_data))
rval = (_constructor_Function, (f.maker, input_storage, inputs_data))
return rval
def _constructor_Function(maker, defaults, data):
f = maker.create(defaults, trustme = True)
assert len(f.input_storage) == len(data)
for container, x in zip(f.input_storage, data):
container.data = x
def _constructor_Function(maker, input_storage, inputs_data):
f = maker.create(input_storage, trustme = True)
assert len(f.input_storage) == len(inputs_data)
for container, x in zip(f.input_storage, inputs_data):
assert (container.data is x) or (container.data == x)
return f
copy_reg.pickle(Function, _pickle_Function)
......@@ -690,7 +691,12 @@ class FunctionMaker(object):
self.accept_inplace = accept_inplace
self.function_builder = function_builder
def create(self, defaults = None, trustme = False):
self.required = [(i.value == None) for i in self.inputs]
self.refeed = [
(i.value != None and not isinstance(i.value, gof.Container) and i.update == None)
for i in self.inputs]
def create(self, input_storage = None, trustme = False):
"""
Create a function.
......@@ -700,101 +706,38 @@ class FunctionMaker(object):
acts as initialization.
trustme -> disables some exceptions, used internally
"""
if defaults is None:
defaults = [None]*len(self.inputs)
input_storage = [] # list of independent one-element lists, will be passed to the linker
_defaults = []
# The following loop is to fill in the input_storage and _defaults lists.
for (input, indices, subinputs), default in zip(self.indices, defaults):
if input_storage is None:
input_storage = [None]*len(self.inputs)
input_storage_lists = [] # list of independent one-element lists, will be passed to the linker
defaults = []
# The following loop is to fill in the input_storage_lists and defaults lists.
assert len(self.indices) == len(input_storage)
for i, ((input, indices, subinputs), input_storage_i) in enumerate(zip(self.indices, input_storage)):
# Replace any default value given as a variable by its container.
# Note that this makes sense only in the context of shared variables,
# but for now we avoid dealing directly with them to avoid dependency
# on the shared variables work-in-progress repository.
if isinstance(default, gof.Variable):
default = default.container
__default = default
if isinstance(input_storage_i, gof.Variable):
input_storage_i = input_storage_i.container
if isinstance(default, gof.Container) and input.implicit:
if isinstance(input_storage_i, gof.Container):
# If the default is a gof.Container and it is an implicit
# input, this means we want to share the same storage. This is
# done by appending default.storage to input_storage
# done by appending default.storage to input_storage_lists
if indices is not None:
raise TypeError("Cannot take a Container instance as default for a SymbolicInputKit.")
input_storage.append(default.storage)
default = None
required = False
elif isinstance(input, SymbolicInputKit):
# If the input is a SymbolicInputKit, it represents more than
# one storage unit. The indices and subinputs lists represent which
# of the kit's inputs are active in this graph, so we make as many
# storage units as needed
if isinstance(default, (list, tuple)) \
and all(isinstance(x, gof.Container) for x in default) \
and input.implicit:
if len(default) == len(indices):
input_storage += [x.storage for x in default]
elif len(default) > len(indices):
input_storage += [default[i].storage for i in indices]
else:
raise ValueError('Not enough storage for SymbolicInputKit', input, indices, default)
default = NODEFAULT
else:
input_storage += [[None] for i in indices]
input_storage_lists.append(input_storage_i.storage)
defaults.append((self.required[i],
self.refeed[i],
input_storage_i.storage[0]))
else:
# Normal case: one new, independent storage unit
input_storage.append([None])
# Filling _defaults. Each entry is a tuple of three elements:
# (required, refeed, value)
# - required means that the user must provide a value when calling the function
# - refeed means that we want to put the default back in the storage after each function call
# - value is the value that will be put in the storage initially
# Even though a SymbolicInputKit represents more than one input,
# we still only have one entry for the defaults list.
if isinstance(input, SymbolicInputKit):
if default is NODEFAULT:
_defaults.append((False, False, None))
elif default is None:
_defaults.append((True, True, None))
else:
_defaults.append((False, False, default))
elif input.update is not None:
# If the input has an update, then (logically) it is not required since
# it is just a parameter and of course we don't want to refeed the default
# back into the storage as it would defeat the point of updating it. We
# always do this policy.
if default is None:
if (trustme or (isinstance(__default, gof.Container)
and input.implicit)):
_defaults.append((False, False, None))
else:
# This might catch some bugs early
raise ValueError("A default (initial) value is required for an input which can update itself.", input)
else:
_defaults.append((False, False, default))
else:
if default is None:
if (trustme or (isinstance(__default, gof.Container)
and input.implicit)):
_defaults.append((False, False, None))
else:
# No default, so this is a required input. Nothing to feed back, initial value is None.
_defaults.append((True, False, None))
else:
# Default value. It is not required, but we want to put it back into the storage
# everytime so it behaves like most programming languages' default values.
# Note (OD): why is it not required? If it was not put back
# into the storage, then the default value may be incorrect
# on subsequent calls. Thus, setting 'refeed' to True seems
# very important here.
_defaults.append((False, True, default))
defaults = _defaults
input_storage_lists.append([input_storage_i])
defaults.append((self.required[i], self.refeed[i], input_storage_i))
# Get a function instance
_fn, _i, _o = self.linker.make_thunk(input_storage = input_storage)
_fn, _i, _o = self.linker.make_thunk(input_storage = input_storage_lists)
fn = self.function_builder(_fn, _i, _o, self.indices, self.outputs, defaults, self.unpack_single, self.return_none, self)
return fn
......
......@@ -250,7 +250,9 @@ class T_function(unittest.TestCase):
self.failUnless(f[s] == 2)
self.failUnless(g[s] == 2)
f(1, 2)
g(1, 2)
self.failUnless(f[s] == 4)
self.failUnless(g[s] == 4)
g(1, 2) # has no effect on state
self.failUnless(f[s] == 4)
self.failUnless(g[s] == 4)
......@@ -301,7 +303,7 @@ class T_picklefunction(unittest.TestCase):
print 'f.defaults = %s' % (f.defaults, )
print 'g.defaults = %s' % (g.defaults, )
self.failUnless(all([f_req == g_req and f_feed == g_feed and
type(f_val) == type(g_val)
f_val == g_val
for ((f_req, f_feed, f_val), (g_req, g_feed, g_val)) in zip(
f.defaults, g.defaults)]))
......@@ -314,7 +316,7 @@ class T_picklefunction(unittest.TestCase):
f(1,2) # put them out of sync
self.failIf(f(1, 2) == g(1, 2)) #they should not be equal anymore.
g(1, 2) # put them back in sync
self.failIf(f(3) == g(3)) # They should be in sync again.
self.failUnless(f(3) == g(3)) # They should be in sync again.
def test_pickle(self):
a = T.scalar() # the a is for 'anonymous' (un-named).
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论