提交 35da0e0d authored 作者: Olivier Delalleau's avatar Olivier Delalleau

Fix to avoid overwriting a container's content when used as a function's default value

上级 16c870a4
...@@ -235,7 +235,12 @@ class Function(object): ...@@ -235,7 +235,12 @@ class Function(object):
if input.strict: if input.strict:
c.strict = True c.strict = True
if value is not None: if value is not None:
# always initialize the storage # Always initialize the storage.
if isinstance(value, gof.Container):
# We obtain the default value from whatever value is currently
# stored in the default container.
assert len(value.storage) == 1
value = value.storage[0]
c.data = value c.data = value
c.required = required c.required = required
c.provided = 0 # this is a count of how many times the input has been provided (reinitialized to 0 on __call__) c.provided = 0 # this is a count of how many times the input has been provided (reinitialized to 0 on __call__)
...@@ -382,6 +387,8 @@ class Function(object): ...@@ -382,6 +387,8 @@ class Function(object):
# Put default values back in the storage # Put default values back in the storage
for i, (required, refeed, value) in enumerate(self.defaults): for i, (required, refeed, value) in enumerate(self.defaults):
if refeed: if refeed:
if isinstance(value, gof.Container):
value = value.storage[0]
self[i] = value self[i] = value
if self.return_none: if self.return_none:
return None return None
...@@ -643,32 +650,39 @@ class FunctionMaker(object): ...@@ -643,32 +650,39 @@ class FunctionMaker(object):
# The following loop is to fill in the input_storage and _defaults lists. # The following loop is to fill in the input_storage and _defaults lists.
for (input, indices, subinputs), default in zip(self.indices, defaults): for (input, indices, subinputs), default in zip(self.indices, defaults):
# Replace any default value given as a variable by its container.
# Note that this makes sense only in the context of shared variables,
# but for now we avoid dealing directly with them to avoid dependency
# on the shared variables work-in-progress repository.
if isinstance(default, gof.Variable):
default = default.container
__default = default __default = default
if isinstance(default, gof.Container): #if isinstance(default, gof.Container):
# If the default is a gof.Container, this means we want to share ## If the default is a gof.Container, this means we want to share
# the same storage. This is done by appending default.storage ## the same storage. This is done by appending default.storage
# to input_storage ## to input_storage
if indices is not None: #if indices is not None:
raise TypeError("Cannot take a Container instance as default for a SymbolicInputKit.") #raise TypeError("Cannot take a Container instance as default for a SymbolicInputKit.")
input_storage.append(default.storage) #input_storage.append(default.storage)
default = None #default = None
required = False #required = False
elif isinstance(input, SymbolicInputKit): if isinstance(input, SymbolicInputKit):
# If the input is a SymbolicInputKit, it represents more than # If the input is a SymbolicInputKit, it represents more than
# one storage unit. The indices and subinputs lists represent which # one storage unit. The indices and subinputs lists represent which
# of the kit's inputs are active in this graph, so we make as many # of the kit's inputs are active in this graph, so we make as many
# storage units as needed # storage units as needed
if isinstance(default, (list, tuple)) \ #if isinstance(default, (list, tuple)) \
and all(isinstance(x, gof.Container) for x in default): #and all(isinstance(x, gof.Container) for x in default):
if len(default) == len(indices): #if len(default) == len(indices):
input_storage += [x.storage for x in default] #input_storage += [x.storage for x in default]
elif len(default) > len(indices): #elif len(default) > len(indices):
input_storage += [default[i].storage for i in indices] #input_storage += [default[i].storage for i in indices]
else: #else:
raise ValueError('Not enough storage for SymbolicInputKit', input, indices, default) #raise ValueError('Not enough storage for SymbolicInputKit', input, indices, default)
default = NODEFAULT #default = NODEFAULT
else: #else:
input_storage += [[None] for i in indices] input_storage += [[None] for i in indices]
else: else:
# Normal case: one new, independent storage unit # Normal case: one new, independent storage unit
...@@ -695,7 +709,7 @@ class FunctionMaker(object): ...@@ -695,7 +709,7 @@ class FunctionMaker(object):
# back into the storage as it would defeat the point of updating it. We # back into the storage as it would defeat the point of updating it. We
# always do this policy. # always do this policy.
if default is None: if default is None:
if trustme or isinstance(__default, gof.Container): if trustme: #or isinstance(__default, gof.Container):
_defaults.append((False, False, None)) _defaults.append((False, False, None))
else: else:
# This might catch some bugs early # This might catch some bugs early
...@@ -704,7 +718,7 @@ class FunctionMaker(object): ...@@ -704,7 +718,7 @@ class FunctionMaker(object):
_defaults.append((False, False, default)) _defaults.append((False, False, default))
else: else:
if default is None: if default is None:
if trustme or isinstance(__default, gof.Container): if trustme: #or isinstance(__default, gof.Container):
_defaults.append((False, False, None)) _defaults.append((False, False, None))
else: else:
# No default, so this is a required input. Nothing to feed back, initial value is None. # No default, so this is a required input. Nothing to feed back, initial value is None.
...@@ -820,6 +834,7 @@ def function(inputs, outputs, mode=None, accept_inplace = False): ...@@ -820,6 +834,7 @@ def function(inputs, outputs, mode=None, accept_inplace = False):
else: else:
#return a different kind of function #return a different kind of function
def dup_defaults(): def dup_defaults():
# TODO This may need to be changed to use containers as defaults.
return [copy.copy(default.value) if isinstance(default, gof.Container) else return [copy.copy(default.value) if isinstance(default, gof.Container) else
copy.copy(default) copy.copy(default)
for default in defaults] for default in defaults]
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论