提交 4ce2c854 authored 作者: Olivier Delalleau's avatar Olivier Delalleau

Better handling of implicit inputs to allow shared containers when possible.

上级 534c5f6c
...@@ -234,15 +234,30 @@ class Function(object): ...@@ -234,15 +234,30 @@ class Function(object):
c = containers[0] #containers is being used as a stack. Here we pop off the next one. c = containers[0] #containers is being used as a stack. Here we pop off the next one.
if input.strict: if input.strict:
c.strict = True c.strict = True
# Whether the default value will be directly accessible within
# the function's container (c.copy_from_container = None), or
# if the function has its own container and thus needs to copy
# the default value at each call (c.copy_from_container =
# pointer towards it).
# Shared containers are only used for implicit inputs (so that
# there is no risk of overwriting their content with a user-
# provided value).
c.copy_from_container = None
if value is not None: if value is not None:
# Always initialize the storage. # Always initialize the storage.
if isinstance(value, gof.Container): if isinstance(value, gof.Container):
# We obtain the default value from whatever value is currently # There is no point in obtaining the current value
# stored in the default container. # stored in the container, since:
assert len(value.storage) == 1 # - for an implicit input, the container is shared
value = value.storage[0] # - for a non implicit input, the value may change
c.data = value # the function is called.
if not input.implicit:
c.copy_from_container = value
else:
c.value = value
c.required = required c.required = required
c.implicit = input.implicit
c.provided = 0 # this is a count of how many times the input has been provided (reinitialized to 0 on __call__) c.provided = 0 # this is a count of how many times the input has been provided (reinitialized to 0 on __call__)
finder[i] = c finder[i] = c
finder[input.variable] = c finder[input.variable] = c
...@@ -252,6 +267,9 @@ class Function(object): ...@@ -252,6 +267,9 @@ class Function(object):
#setters.append(partial(assign, c)) #setters.append(partial(assign, c))
containers[:1] = [] containers[:1] = []
else: else:
# TODO The following code may need to do something to handle
# implicit inputs.
# The input is a SymbolicInputKit, so we take as many containers as the Kit provides inputs # The input is a SymbolicInputKit, so we take as many containers as the Kit provides inputs
cs = containers[:len(indices)] cs = containers[:len(indices)]
# distribute does the initialization of the containers # distribute does the initialization of the containers
...@@ -352,12 +370,33 @@ class Function(object): ...@@ -352,12 +370,33 @@ class Function(object):
# Set keyword arguments # Set keyword arguments
for k, arg in kwargs.iteritems(): for k, arg in kwargs.iteritems():
self[k] = arg self[k] = arg
# Check if inputs are missing or if inputs were set more than once
# Check if inputs are missing, or if inputs were set more than once, or
# if we tried to provide inputs that are supposed to be implicit.
# Also initialize default values that are obtained from an external
# container. This is required because this container's value may be
# modified between function calls.
# Other types of default values should not need to be re-initialized:
# - shared containers are updated automatically
# - default values defined directly by their value are re-fed into the
# input storage after a function call, and any modification possibly
# made to them (for mutable types) will be reflected there as well.
for c in self.input_storage: for c in self.input_storage:
if c.required and not c.provided: if c.required and not c.provided:
raise TypeError("Missing required input: %s" % getattr(self.inv_finder[c], 'variable', self.inv_finder[c])) raise TypeError("Missing required input: %s" % getattr(self.inv_finder[c], 'variable', self.inv_finder[c]))
if c.provided > 1: if c.provided > 1:
raise TypeError("Multiple values for input: %s" % getattr(self.inv_finder[c], 'variable', self.inv_finder[c])) raise TypeError("Multiple values for input: %s" % getattr(self.inv_finder[c], 'variable', self.inv_finder[c]))
if c.implicit and c.provided > 0:
raise TypeError('Tried to provide value for implicit input: %s'
% getattr(self.inv_finder[c], 'variable',
self.inv_finder[c]))
if c.provided == 0 and c.copy_from_container is not None:
# Copy default value from another (non shared) container.
# Safety check, may be removed in the future.
assert not c.implicit
c.value = c.copy_from_container.value
# TODO Would it be better to use self[..] = value?
# Do the actual work # Do the actual work
self.fn() self.fn()
...@@ -382,14 +421,22 @@ class Function(object): ...@@ -382,14 +421,22 @@ class Function(object):
# Update the inputs that have an update function # Update the inputs that have an update function
for input, storage in reversed(zip(self.maker.expanded_inputs, self.input_storage)): for input, storage in reversed(zip(self.maker.expanded_inputs, self.input_storage)):
if input.update: if input.update is not None:
storage.data = outputs.pop() storage.data = outputs.pop()
# Put default values back in the storage # Put default values back in the storage
for i, (required, refeed, value) in enumerate(self.defaults): for i, (required, refeed, value) in enumerate(self.defaults):
if refeed: if refeed:
if isinstance(value, gof.Container): if isinstance(value, gof.Container):
if value.implicit and value.copy_from_container is None:
# This is a shared container, so there is no need to
# re-feed anything. Thus refeed should be false.
# This safety check may be removed in the future since
# it should not be needed.
assert False
value = value.storage[0] value = value.storage[0]
self[i] = value self[i] = value
if self.return_none: if self.return_none:
return None return None
elif self.unpack_single and len(outputs) == 1: elif self.unpack_single and len(outputs) == 1:
...@@ -659,30 +706,31 @@ class FunctionMaker(object): ...@@ -659,30 +706,31 @@ class FunctionMaker(object):
__default = default __default = default
#if isinstance(default, gof.Container): if isinstance(default, gof.Container) and input.implicit:
## If the default is a gof.Container, this means we want to share # If the default is a gof.Container and it is an implicit
## the same storage. This is done by appending default.storage # input, this means we want to share the same storage. This is
## to input_storage # done by appending default.storage to input_storage
#if indices is not None: if indices is not None:
#raise TypeError("Cannot take a Container instance as default for a SymbolicInputKit.") raise TypeError("Cannot take a Container instance as default for a SymbolicInputKit.")
#input_storage.append(default.storage) input_storage.append(default.storage)
#default = None default = None
#required = False required = False
if isinstance(input, SymbolicInputKit): elif isinstance(input, SymbolicInputKit):
# If the input is a SymbolicInputKit, it represents more than # If the input is a SymbolicInputKit, it represents more than
# one storage unit. The indices and subinputs lists represent which # one storage unit. The indices and subinputs lists represent which
# of the kit's inputs are active in this graph, so we make as many # of the kit's inputs are active in this graph, so we make as many
# storage units as needed # storage units as needed
#if isinstance(default, (list, tuple)) \ if isinstance(default, (list, tuple)) \
#and all(isinstance(x, gof.Container) for x in default): and all(isinstance(x, gof.Container) for x in default) \
#if len(default) == len(indices): and input.implicit:
#input_storage += [x.storage for x in default] if len(default) == len(indices):
#elif len(default) > len(indices): input_storage += [x.storage for x in default]
#input_storage += [default[i].storage for i in indices] elif len(default) > len(indices):
#else: input_storage += [default[i].storage for i in indices]
#raise ValueError('Not enough storage for SymbolicInputKit', input, indices, default) else:
#default = NODEFAULT raise ValueError('Not enough storage for SymbolicInputKit', input, indices, default)
#else: default = NODEFAULT
else:
input_storage += [[None] for i in indices] input_storage += [[None] for i in indices]
else: else:
# Normal case: one new, independent storage unit # Normal case: one new, independent storage unit
...@@ -709,7 +757,8 @@ class FunctionMaker(object): ...@@ -709,7 +757,8 @@ class FunctionMaker(object):
# back into the storage as it would defeat the point of updating it. We # back into the storage as it would defeat the point of updating it. We
# always do this policy. # always do this policy.
if default is None: if default is None:
if trustme: #or isinstance(__default, gof.Container): if (trustme or (isinstance(__default, gof.Container)
and input.implicit)):
_defaults.append((False, False, None)) _defaults.append((False, False, None))
else: else:
# This might catch some bugs early # This might catch some bugs early
...@@ -718,14 +767,15 @@ class FunctionMaker(object): ...@@ -718,14 +767,15 @@ class FunctionMaker(object):
_defaults.append((False, False, default)) _defaults.append((False, False, default))
else: else:
if default is None: if default is None:
if trustme: #or isinstance(__default, gof.Container): if (trustme or (isinstance(__default, gof.Container)
and input.implicit)):
_defaults.append((False, False, None)) _defaults.append((False, False, None))
else: else:
# No default, so this is a required input. Nothing to feed back, initial value is None. # No default, so this is a required input. Nothing to feed back, initial value is None.
_defaults.append((True, False, None)) _defaults.append((True, False, None))
else: else:
# Default value. It is not required, but we want to put it back into the storage # Default value. It is not required, but we want to put it back into the storage
# everytime so it behaves like most programming languages' default values # everytime so it behaves like most programming languages' default values.
_defaults.append((False, True, default)) _defaults.append((False, True, default))
defaults = _defaults defaults = _defaults
...@@ -819,6 +869,7 @@ def function(inputs, outputs, mode=None, accept_inplace = False): ...@@ -819,6 +869,7 @@ def function(inputs, outputs, mode=None, accept_inplace = False):
""" """
mode = mode if mode is not None else mode_module.default_mode mode = mode if mode is not None else mode_module.default_mode
inputs = map(convert_function_input, inputs) inputs = map(convert_function_input, inputs)
if outputs is not None: if outputs is not None:
outputs = map(FunctionMaker.wrap_out, outputs) if isinstance(outputs, (list, tuple)) else FunctionMaker.wrap_out(outputs) outputs = map(FunctionMaker.wrap_out, outputs) if isinstance(outputs, (list, tuple)) else FunctionMaker.wrap_out(outputs)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论