提交 13276e95 authored 作者: James Bergstra's avatar James Bergstra

fixed shared state bug when no update is given

上级 9eecfab3
...@@ -335,6 +335,36 @@ class T_function(unittest.TestCase): ...@@ -335,6 +335,36 @@ class T_function(unittest.TestCase):
self.failUnless(f[s] == 0) self.failUnless(f[s] == 0)
self.failUnless(g[s] == 0) self.failUnless(g[s] == 0)
def test_shared_state1(self):
a = T.scalar() # the a is for 'anonymous' (un-named).
x,s = T.scalars('xs')
f = function([x, In(a, value=1.0,name='a'), In(s, value=0.0, update=s+a*x, mutable=True)], s+a*x)
g = function([x, In(a, value=1.0,name='a'), In(s, value=f.container[s])], s+a*x)
f(1, 2)
self.failUnless(f[s] == 2)
self.failUnless(g[s] == 2)
f(1, 2)
g(1, 2)
self.failUnless(f[s] == 4)
self.failUnless(g[s] == 4)
def test_shared_state2(self):
a = T.scalar() # the a is for 'anonymous' (un-named).
x,s = T.scalars('xs')
f = function([x, In(a, value=1.0,name='a'), In(s, value=0.0, update=s+a*x,
mutable=False)], s+a*x)
g = function([x, In(a, value=1.0,name='a'), In(s, value=f.container[s])], s+a*x)
f(1, 2)
self.failUnless(f[s] == 2)
self.failUnless(g[s] == 2)
f(1, 2)
g(1, 2)
self.failUnless(f[s] == 4)
self.failUnless(g[s] == 4)
# class T_function_examples(unittest.TestCase): # class T_function_examples(unittest.TestCase):
# def test_accumulator(self): # def test_accumulator(self):
......
...@@ -468,22 +468,23 @@ class FunctionMaker(object): ...@@ -468,22 +468,23 @@ class FunctionMaker(object):
for (input, indices, subinputs), default in zip(self.indices, defaults): for (input, indices, subinputs), default in zip(self.indices, defaults):
__default = default __default = default
if isinstance(default, gof.Container):
# If the default is a gof.Container, this means we want to share # If the default is a gof.Container, this means we want to share
# the same storage. This is done by appending default.storage # the same storage. This is done by appending default.storage
# to input_storage # to input_storage
if isinstance(default, gof.Container):
if indices is not None: if indices is not None:
raise TypeError("Cannot take a Container instance as default for a SymbolicInputKit.") raise TypeError("Cannot take a Container instance as default for a SymbolicInputKit.")
input_storage.append(default.storage) input_storage.append(default.storage)
default = None default = None
required = False
elif isinstance(input, SymbolicInputKit):
# If the input is a SymbolicInputKit, it represents more than # If the input is a SymbolicInputKit, it represents more than
# one storage unit. The indices and subinputs lists represent which # one storage unit. The indices and subinputs lists represent which
# of the kit's inputs are active in this graph, so we make as many # of the kit's inputs are active in this graph, so we make as many
# storage units as needed # storage units as needed
elif isinstance(input, SymbolicInputKit):
input_storage += [[None] for i in indices] input_storage += [[None] for i in indices]
# Normal case: one new, independent storage unit
else: else:
# Normal case: one new, independent storage unit
input_storage.append([None]) input_storage.append([None])
# Filling _defaults. Each entry is a tuple of three elements: # Filling _defaults. Each entry is a tuple of three elements:
...@@ -506,7 +507,7 @@ class FunctionMaker(object): ...@@ -506,7 +507,7 @@ class FunctionMaker(object):
# always do this policy. # always do this policy.
if default is None: if default is None:
if trustme or isinstance(__default, gof.Container): if trustme or isinstance(__default, gof.Container):
_defaults.append((False, False, default)) _defaults.append((False, False, None))
else: else:
# This might catch some bugs early # This might catch some bugs early
raise ValueError("A default (initial) value is required for an input which can update itself.", input) raise ValueError("A default (initial) value is required for an input which can update itself.", input)
...@@ -514,6 +515,9 @@ class FunctionMaker(object): ...@@ -514,6 +515,9 @@ class FunctionMaker(object):
_defaults.append((False, False, default)) _defaults.append((False, False, default))
else: else:
if default is None: if default is None:
if trustme or isinstance(__default, gof.Container):
_defaults.append((False, False, None))
else:
# No default, so this is a required input. Nothing to feed back, initial value is None. # No default, so this is a required input. Nothing to feed back, initial value is None.
_defaults.append((True, False, None)) _defaults.append((True, False, None))
else: else:
...@@ -651,8 +655,9 @@ def function(inputs, outputs, mode='FAST_RUN', accept_inplace = False): ...@@ -651,8 +655,9 @@ def function(inputs, outputs, mode='FAST_RUN', accept_inplace = False):
if not mode: if not mode:
raise ValueError("Please provide at least one mode.") raise ValueError("Please provide at least one mode.")
elif len(mode) == 1: elif len(mode) == 1:
mode = mode[0] fn = FunctionMaker(inputs, outputs, mode[0], accept_inplace = accept_inplace).create(defaults)
else: else:
#return a different kind of function
def dup_defaults(): def dup_defaults():
return [copy(default.value) if isinstance(default, gof.Container) else copy(default) return [copy(default.value) if isinstance(default, gof.Container) else copy(default)
for default in defaults] for default in defaults]
...@@ -661,8 +666,7 @@ def function(inputs, outputs, mode='FAST_RUN', accept_inplace = False): ...@@ -661,8 +666,7 @@ def function(inputs, outputs, mode='FAST_RUN', accept_inplace = False):
builder = partial(SanityCheckFunction, fns, check_equal_numpy) builder = partial(SanityCheckFunction, fns, check_equal_numpy)
maker1 = FunctionMaker(inputs, outputs, mode[0], accept_inplace = accept_inplace, function_builder = builder) maker1 = FunctionMaker(inputs, outputs, mode[0], accept_inplace = accept_inplace, function_builder = builder)
fn = maker1.create(defaults) fn = maker1.create(defaults)
return fn else:
fn = FunctionMaker(inputs, outputs, mode, accept_inplace = accept_inplace).create(defaults) fn = FunctionMaker(inputs, outputs, mode, accept_inplace = accept_inplace).create(defaults)
return fn return fn
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论