提交 13276e95 authored 作者: James Bergstra's avatar James Bergstra

fixed shared state bug when no update is given

上级 9eecfab3
......@@ -335,6 +335,36 @@ class T_function(unittest.TestCase):
self.failUnless(f[s] == 0)
self.failUnless(g[s] == 0)
def test_shared_state1(self):
a = T.scalar() # the a is for 'anonymous' (un-named).
x,s = T.scalars('xs')
f = function([x, In(a, value=1.0,name='a'), In(s, value=0.0, update=s+a*x, mutable=True)], s+a*x)
g = function([x, In(a, value=1.0,name='a'), In(s, value=f.container[s])], s+a*x)
f(1, 2)
self.failUnless(f[s] == 2)
self.failUnless(g[s] == 2)
f(1, 2)
g(1, 2)
self.failUnless(f[s] == 4)
self.failUnless(g[s] == 4)
def test_shared_state2(self):
a = T.scalar() # the a is for 'anonymous' (un-named).
x,s = T.scalars('xs')
f = function([x, In(a, value=1.0,name='a'), In(s, value=0.0, update=s+a*x,
mutable=False)], s+a*x)
g = function([x, In(a, value=1.0,name='a'), In(s, value=f.container[s])], s+a*x)
f(1, 2)
self.failUnless(f[s] == 2)
self.failUnless(g[s] == 2)
f(1, 2)
g(1, 2)
self.failUnless(f[s] == 4)
self.failUnless(g[s] == 4)
# class T_function_examples(unittest.TestCase):
# def test_accumulator(self):
......
......@@ -468,22 +468,23 @@ class FunctionMaker(object):
for (input, indices, subinputs), default in zip(self.indices, defaults):
__default = default
# If the default is a gof.Container, this means we want to share
# the same storage. This is done by appending default.storage
# to input_storage
if isinstance(default, gof.Container):
# If the default is a gof.Container, this means we want to share
# the same storage. This is done by appending default.storage
# to input_storage
if indices is not None:
raise TypeError("Cannot take a Container instance as default for a SymbolicInputKit.")
input_storage.append(default.storage)
default = None
# If the input is a SymbolicInputKit, it represents more than
# one storage unit. The indices and subinputs lists represent which
# of the kit's inputs are active in this graph, so we make as many
# storage units as needed
required = False
elif isinstance(input, SymbolicInputKit):
# If the input is a SymbolicInputKit, it represents more than
# one storage unit. The indices and subinputs lists represent which
# of the kit's inputs are active in this graph, so we make as many
# storage units as needed
input_storage += [[None] for i in indices]
# Normal case: one new, independent storage unit
else:
# Normal case: one new, independent storage unit
input_storage.append([None])
# Filling _defaults. Each entry is a tuple of three elements:
......@@ -506,7 +507,7 @@ class FunctionMaker(object):
# always do this policy.
if default is None:
if trustme or isinstance(__default, gof.Container):
_defaults.append((False, False, default))
_defaults.append((False, False, None))
else:
# This might catch some bugs early
raise ValueError("A default (initial) value is required for an input which can update itself.", input)
......@@ -514,8 +515,11 @@ class FunctionMaker(object):
_defaults.append((False, False, default))
else:
if default is None:
# No default, so this is a required input. Nothing to feed back, initial value is None.
_defaults.append((True, False, None))
if trustme or isinstance(__default, gof.Container):
_defaults.append((False, False, None))
else:
# No default, so this is a required input. Nothing to feed back, initial value is None.
_defaults.append((True, False, None))
else:
# Default value. It is not required, but we want to put it back into the storage
# everytime so it behaves like most programming languages' default values
......@@ -646,13 +650,14 @@ def function(inputs, outputs, mode='FAST_RUN', accept_inplace = False):
outputs = map(wrap_out, outputs) if isinstance(outputs, (list, tuple)) else wrap_out(outputs)
defaults = [getattr(input, 'value', None) for input in inputs]
if isinstance(mode, (list, tuple)): # "mode comparison" semantics
if not mode:
raise ValueError("Please provide at least one mode.")
elif len(mode) == 1:
mode = mode[0]
else:
fn = FunctionMaker(inputs, outputs, mode[0], accept_inplace = accept_inplace).create(defaults)
else:
#return a different kind of function
def dup_defaults():
return [copy(default.value) if isinstance(default, gof.Container) else copy(default)
for default in defaults]
......@@ -661,9 +666,8 @@ def function(inputs, outputs, mode='FAST_RUN', accept_inplace = False):
builder = partial(SanityCheckFunction, fns, check_equal_numpy)
maker1 = FunctionMaker(inputs, outputs, mode[0], accept_inplace = accept_inplace, function_builder = builder)
fn = maker1.create(defaults)
return fn
fn = FunctionMaker(inputs, outputs, mode, accept_inplace = accept_inplace).create(defaults)
else:
fn = FunctionMaker(inputs, outputs, mode, accept_inplace = accept_inplace).create(defaults)
return fn
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论