提交 35da0e0d authored 作者: Olivier Delalleau's avatar Olivier Delalleau

Fix to avoid overwriting a container's content when used as a function's default value

上级 16c870a4
......@@ -235,7 +235,12 @@ class Function(object):
if input.strict:
c.strict = True
if value is not None:
# always initialize the storage
# Always initialize the storage.
if isinstance(value, gof.Container):
# We obtain the default value from whatever value is currently
# stored in the default container.
assert len(value.storage) == 1
value = value.storage[0]
c.data = value
c.required = required
c.provided = 0 # this is a count of how many times the input has been provided (reinitialized to 0 on __call__)
......@@ -382,6 +387,8 @@ class Function(object):
# Put default values back in the storage
for i, (required, refeed, value) in enumerate(self.defaults):
if refeed:
if isinstance(value, gof.Container):
value = value.storage[0]
self[i] = value
if self.return_none:
return None
......@@ -643,32 +650,39 @@ class FunctionMaker(object):
# The following loop is to fill in the input_storage and _defaults lists.
for (input, indices, subinputs), default in zip(self.indices, defaults):
# Replace any default value given as a variable by its container.
# Note that this makes sense only in the context of shared variables,
# but for now we avoid dealing directly with them to avoid dependency
# on the shared variables work-in-progress repository.
if isinstance(default, gof.Variable):
default = default.container
__default = default
if isinstance(default, gof.Container):
# If the default is a gof.Container, this means we want to share
# the same storage. This is done by appending default.storage
# to input_storage
if indices is not None:
raise TypeError("Cannot take a Container instance as default for a SymbolicInputKit.")
input_storage.append(default.storage)
default = None
required = False
elif isinstance(input, SymbolicInputKit):
#if isinstance(default, gof.Container):
## If the default is a gof.Container, this means we want to share
## the same storage. This is done by appending default.storage
## to input_storage
#if indices is not None:
#raise TypeError("Cannot take a Container instance as default for a SymbolicInputKit.")
#input_storage.append(default.storage)
#default = None
#required = False
if isinstance(input, SymbolicInputKit):
# If the input is a SymbolicInputKit, it represents more than
# one storage unit. The indices and subinputs lists represent which
# of the kit's inputs are active in this graph, so we make as many
# storage units as needed
if isinstance(default, (list, tuple)) \
and all(isinstance(x, gof.Container) for x in default):
if len(default) == len(indices):
input_storage += [x.storage for x in default]
elif len(default) > len(indices):
input_storage += [default[i].storage for i in indices]
else:
raise ValueError('Not enough storage for SymbolicInputKit', input, indices, default)
default = NODEFAULT
else:
#if isinstance(default, (list, tuple)) \
#and all(isinstance(x, gof.Container) for x in default):
#if len(default) == len(indices):
#input_storage += [x.storage for x in default]
#elif len(default) > len(indices):
#input_storage += [default[i].storage for i in indices]
#else:
#raise ValueError('Not enough storage for SymbolicInputKit', input, indices, default)
#default = NODEFAULT
#else:
input_storage += [[None] for i in indices]
else:
# Normal case: one new, independent storage unit
......@@ -695,7 +709,7 @@ class FunctionMaker(object):
# back into the storage as it would defeat the point of updating it. We
# always do this policy.
if default is None:
if trustme or isinstance(__default, gof.Container):
if trustme: #or isinstance(__default, gof.Container):
_defaults.append((False, False, None))
else:
# This might catch some bugs early
......@@ -704,7 +718,7 @@ class FunctionMaker(object):
_defaults.append((False, False, default))
else:
if default is None:
if trustme or isinstance(__default, gof.Container):
if trustme: #or isinstance(__default, gof.Container):
_defaults.append((False, False, None))
else:
# No default, so this is a required input. Nothing to feed back, initial value is None.
......@@ -820,6 +834,7 @@ def function(inputs, outputs, mode=None, accept_inplace = False):
else:
#return a different kind of function
def dup_defaults():
# TODO This may need to be changed to use containers as defaults.
return [copy.copy(default.value) if isinstance(default, gof.Container) else
copy.copy(default)
for default in defaults]
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论