提交 4ce2c854 authored 作者: Olivier Delalleau's avatar Olivier Delalleau

Better handling of implicit inputs to allow shared containers when possible.

上级 534c5f6c
......@@ -234,15 +234,30 @@ class Function(object):
c = containers[0] #containers is being used as a stack. Here we pop off the next one.
if input.strict:
c.strict = True
# Whether the default value will be directly accessible within
# the function's container (c.copy_from_container = None), or
# if the function has its own container and thus needs to copy
# the default value at each call (c.copy_from_container =
# pointer towards it).
# Shared containers are only used for implicit inputs (so that
# there is no risk of overwriting their content with a user-
# provided value).
c.copy_from_container = None
if value is not None:
# Always initialize the storage.
if isinstance(value, gof.Container):
# We obtain the default value from whatever value is currently
# stored in the default container.
assert len(value.storage) == 1
value = value.storage[0]
c.data = value
# There is no point in obtaining the current value
# stored in the container, since:
# - for an implicit input, the container is shared
# - for a non implicit input, the value may change
# the function is called.
if not input.implicit:
c.copy_from_container = value
else:
c.value = value
c.required = required
c.implicit = input.implicit
c.provided = 0 # this is a count of how many times the input has been provided (reinitialized to 0 on __call__)
finder[i] = c
finder[input.variable] = c
......@@ -252,6 +267,9 @@ class Function(object):
#setters.append(partial(assign, c))
containers[:1] = []
else:
# TODO The following code may need to do something to handle
# implicit inputs.
# The input is a SymbolicInputKit, so we take as many containers as the Kit provides inputs
cs = containers[:len(indices)]
# distribute does the initialization of the containers
......@@ -352,12 +370,33 @@ class Function(object):
# Set keyword arguments
for k, arg in kwargs.iteritems():
self[k] = arg
# Check if inputs are missing or if inputs were set more than once
# Check if inputs are missing, or if inputs were set more than once, or
# if we tried to provide inputs that are supposed to be implicit.
# Also initialize default values that are obtained from an external
# container. This is required because this container's value may be
# modified between function calls.
# Other types of default values should not need to be re-initialized:
# - shared containers are updated automatically
# - default values defined directly by their value are re-fed into the
# input storage after a function call, and any modification possibly
# made to them (for mutable types) will be reflected there as well.
for c in self.input_storage:
if c.required and not c.provided:
raise TypeError("Missing required input: %s" % getattr(self.inv_finder[c], 'variable', self.inv_finder[c]))
if c.provided > 1:
raise TypeError("Multiple values for input: %s" % getattr(self.inv_finder[c], 'variable', self.inv_finder[c]))
if c.implicit and c.provided > 0:
raise TypeError('Tried to provide value for implicit input: %s'
% getattr(self.inv_finder[c], 'variable',
self.inv_finder[c]))
if c.provided == 0 and c.copy_from_container is not None:
# Copy default value from another (non shared) container.
# Safety check, may be removed in the future.
assert not c.implicit
c.value = c.copy_from_container.value
# TODO Would it be better to use self[..] = value?
# Do the actual work
self.fn()
......@@ -382,14 +421,22 @@ class Function(object):
# Update the inputs that have an update function
for input, storage in reversed(zip(self.maker.expanded_inputs, self.input_storage)):
if input.update:
if input.update is not None:
storage.data = outputs.pop()
# Put default values back in the storage
for i, (required, refeed, value) in enumerate(self.defaults):
if refeed:
if isinstance(value, gof.Container):
if value.implicit and value.copy_from_container is None:
# This is a shared container, so there is no need to
# re-feed anything. Thus refeed should be false.
# This safety check may be removed in the future since
# it should not be needed.
assert False
value = value.storage[0]
self[i] = value
if self.return_none:
return None
elif self.unpack_single and len(outputs) == 1:
......@@ -659,30 +706,31 @@ class FunctionMaker(object):
__default = default
#if isinstance(default, gof.Container):
## If the default is a gof.Container, this means we want to share
## the same storage. This is done by appending default.storage
## to input_storage
#if indices is not None:
#raise TypeError("Cannot take a Container instance as default for a SymbolicInputKit.")
#input_storage.append(default.storage)
#default = None
#required = False
if isinstance(input, SymbolicInputKit):
if isinstance(default, gof.Container) and input.implicit:
# If the default is a gof.Container and it is an implicit
# input, this means we want to share the same storage. This is
# done by appending default.storage to input_storage
if indices is not None:
raise TypeError("Cannot take a Container instance as default for a SymbolicInputKit.")
input_storage.append(default.storage)
default = None
required = False
elif isinstance(input, SymbolicInputKit):
# If the input is a SymbolicInputKit, it represents more than
# one storage unit. The indices and subinputs lists represent which
# of the kit's inputs are active in this graph, so we make as many
# storage units as needed
#if isinstance(default, (list, tuple)) \
#and all(isinstance(x, gof.Container) for x in default):
#if len(default) == len(indices):
#input_storage += [x.storage for x in default]
#elif len(default) > len(indices):
#input_storage += [default[i].storage for i in indices]
#else:
#raise ValueError('Not enough storage for SymbolicInputKit', input, indices, default)
#default = NODEFAULT
#else:
if isinstance(default, (list, tuple)) \
and all(isinstance(x, gof.Container) for x in default) \
and input.implicit:
if len(default) == len(indices):
input_storage += [x.storage for x in default]
elif len(default) > len(indices):
input_storage += [default[i].storage for i in indices]
else:
raise ValueError('Not enough storage for SymbolicInputKit', input, indices, default)
default = NODEFAULT
else:
input_storage += [[None] for i in indices]
else:
# Normal case: one new, independent storage unit
......@@ -709,7 +757,8 @@ class FunctionMaker(object):
# back into the storage as it would defeat the point of updating it. We
# always do this policy.
if default is None:
if trustme: #or isinstance(__default, gof.Container):
if (trustme or (isinstance(__default, gof.Container)
and input.implicit)):
_defaults.append((False, False, None))
else:
# This might catch some bugs early
......@@ -718,14 +767,15 @@ class FunctionMaker(object):
_defaults.append((False, False, default))
else:
if default is None:
if trustme: #or isinstance(__default, gof.Container):
if (trustme or (isinstance(__default, gof.Container)
and input.implicit)):
_defaults.append((False, False, None))
else:
# No default, so this is a required input. Nothing to feed back, initial value is None.
_defaults.append((True, False, None))
else:
# Default value. It is not required, but we want to put it back into the storage
# everytime so it behaves like most programming languages' default values
# everytime so it behaves like most programming languages' default values.
_defaults.append((False, True, default))
defaults = _defaults
......@@ -819,6 +869,7 @@ def function(inputs, outputs, mode=None, accept_inplace = False):
"""
mode = mode if mode is not None else mode_module.default_mode
inputs = map(convert_function_input, inputs)
if outputs is not None:
outputs = map(FunctionMaker.wrap_out, outputs) if isinstance(outputs, (list, tuple)) else FunctionMaker.wrap_out(outputs)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论