提交 5fc50232 authored 作者: Olivier Delalleau's avatar Olivier Delalleau

Merged

...@@ -234,10 +234,35 @@ class Function(object): ...@@ -234,10 +234,35 @@ class Function(object):
c = containers[0] #containers is being used as a stack. Here we pop off the next one. c = containers[0] #containers is being used as a stack. Here we pop off the next one.
if input.strict: if input.strict:
c.strict = True c.strict = True
# Whether the default value will be directly accessible within
# the function's container (c.copy_from_container = None), or
# if the function has its own container and thus needs to copy
# the default value at each call (c.copy_from_container =
# pointer towards it).
# Shared containers are only used for implicit inputs (so that
# there is no risk of overwriting their content with a user-
# provided value).
c.copy_from_container = None
if value is not None: if value is not None:
# always initialize the storage # Always initialize the storage.
c.data = value if isinstance(value, gof.Container):
# There is no point in obtaining the current value
# stored in the container, since:
# - for an implicit input, the container is shared
# - for a non implicit input, the value may change
# the function is called.
if not input.implicit:
c.copy_from_container = value
else:
# Safety check: the container will be shared, so
# there should be no need to refeed the default
# value.
assert not refeed
else:
c.value = value
c.required = required c.required = required
c.implicit = input.implicit
c.provided = 0 # this is a count of how many times the input has been provided (reinitialized to 0 on __call__) c.provided = 0 # this is a count of how many times the input has been provided (reinitialized to 0 on __call__)
finder[i] = c finder[i] = c
finder[input.variable] = c finder[input.variable] = c
...@@ -247,6 +272,9 @@ class Function(object): ...@@ -247,6 +272,9 @@ class Function(object):
#setters.append(partial(assign, c)) #setters.append(partial(assign, c))
containers[:1] = [] containers[:1] = []
else: else:
# TODO The following code may need to do something to handle
# implicit inputs.
# The input is a SymbolicInputKit, so we take as many containers as the Kit provides inputs # The input is a SymbolicInputKit, so we take as many containers as the Kit provides inputs
cs = containers[:len(indices)] cs = containers[:len(indices)]
# distribute does the initialization of the containers # distribute does the initialization of the containers
...@@ -347,12 +375,33 @@ class Function(object): ...@@ -347,12 +375,33 @@ class Function(object):
# Set keyword arguments # Set keyword arguments
for k, arg in kwargs.iteritems(): for k, arg in kwargs.iteritems():
self[k] = arg self[k] = arg
# Check if inputs are missing or if inputs were set more than once
# Check if inputs are missing, or if inputs were set more than once, or
# if we tried to provide inputs that are supposed to be implicit.
# Also initialize default values that are obtained from an external
# container. This is required because this container's value may be
# modified between function calls.
# Other types of default values should not need to be re-initialized:
# - shared containers are updated automatically
# - default values defined directly by their value are re-fed into the
# input storage after a function call, and any modification possibly
# made to them (for mutable types) will be reflected there as well.
for c in self.input_storage: for c in self.input_storage:
if c.required and not c.provided: if c.required and not c.provided:
raise TypeError("Missing required input: %s" % getattr(self.inv_finder[c], 'variable', self.inv_finder[c])) raise TypeError("Missing required input: %s" % getattr(self.inv_finder[c], 'variable', self.inv_finder[c]))
if c.provided > 1: if c.provided > 1:
raise TypeError("Multiple values for input: %s" % getattr(self.inv_finder[c], 'variable', self.inv_finder[c])) raise TypeError("Multiple values for input: %s" % getattr(self.inv_finder[c], 'variable', self.inv_finder[c]))
if c.implicit and c.provided > 0:
raise TypeError('Tried to provide value for implicit input: %s'
% getattr(self.inv_finder[c], 'variable',
self.inv_finder[c]))
if c.provided == 0 and c.copy_from_container is not None:
# Copy default value from another (non shared) container.
# Safety check, may be removed in the future.
assert not c.implicit
c.value = c.copy_from_container.value
# TODO Would it be better to use self[..] = value?
# Do the actual work # Do the actual work
self.fn() self.fn()
...@@ -377,12 +426,16 @@ class Function(object): ...@@ -377,12 +426,16 @@ class Function(object):
# Update the inputs that have an update function # Update the inputs that have an update function
for input, storage in reversed(zip(self.maker.expanded_inputs, self.input_storage)): for input, storage in reversed(zip(self.maker.expanded_inputs, self.input_storage)):
if input.update: if input.update is not None:
storage.data = outputs.pop() storage.data = outputs.pop()
# Put default values back in the storage # Put default values back in the storage
for i, (required, refeed, value) in enumerate(self.defaults): for i, (required, refeed, value) in enumerate(self.defaults):
if refeed: if refeed:
if isinstance(value, gof.Container):
value = value.storage[0]
self[i] = value self[i] = value
if self.return_none: if self.return_none:
return None return None
elif self.unpack_single and len(outputs) == 1: elif self.unpack_single and len(outputs) == 1:
...@@ -643,12 +696,19 @@ class FunctionMaker(object): ...@@ -643,12 +696,19 @@ class FunctionMaker(object):
# The following loop is to fill in the input_storage and _defaults lists. # The following loop is to fill in the input_storage and _defaults lists.
for (input, indices, subinputs), default in zip(self.indices, defaults): for (input, indices, subinputs), default in zip(self.indices, defaults):
# Replace any default value given as a variable by its container.
# Note that this makes sense only in the context of shared variables,
# but for now we avoid dealing directly with them to avoid dependency
# on the shared variables work-in-progress repository.
if isinstance(default, gof.Variable):
default = default.container
__default = default __default = default
if isinstance(default, gof.Container): if isinstance(default, gof.Container) and input.implicit:
# If the default is a gof.Container, this means we want to share # If the default is a gof.Container and it is an implicit
# the same storage. This is done by appending default.storage # input, this means we want to share the same storage. This is
# to input_storage # done by appending default.storage to input_storage
if indices is not None: if indices is not None:
raise TypeError("Cannot take a Container instance as default for a SymbolicInputKit.") raise TypeError("Cannot take a Container instance as default for a SymbolicInputKit.")
input_storage.append(default.storage) input_storage.append(default.storage)
...@@ -660,7 +720,8 @@ class FunctionMaker(object): ...@@ -660,7 +720,8 @@ class FunctionMaker(object):
# of the kit's inputs are active in this graph, so we make as many # of the kit's inputs are active in this graph, so we make as many
# storage units as needed # storage units as needed
if isinstance(default, (list, tuple)) \ if isinstance(default, (list, tuple)) \
and all(isinstance(x, gof.Container) for x in default): and all(isinstance(x, gof.Container) for x in default) \
and input.implicit:
if len(default) == len(indices): if len(default) == len(indices):
input_storage += [x.storage for x in default] input_storage += [x.storage for x in default]
elif len(default) > len(indices): elif len(default) > len(indices):
...@@ -695,7 +756,8 @@ class FunctionMaker(object): ...@@ -695,7 +756,8 @@ class FunctionMaker(object):
# back into the storage as it would defeat the point of updating it. We # back into the storage as it would defeat the point of updating it. We
# always do this policy. # always do this policy.
if default is None: if default is None:
if trustme or isinstance(__default, gof.Container): if (trustme or (isinstance(__default, gof.Container)
and input.implicit)):
_defaults.append((False, False, None)) _defaults.append((False, False, None))
else: else:
# This might catch some bugs early # This might catch some bugs early
...@@ -704,7 +766,8 @@ class FunctionMaker(object): ...@@ -704,7 +766,8 @@ class FunctionMaker(object):
_defaults.append((False, False, default)) _defaults.append((False, False, default))
else: else:
if default is None: if default is None:
if trustme or isinstance(__default, gof.Container): if (trustme or (isinstance(__default, gof.Container)
and input.implicit)):
_defaults.append((False, False, None)) _defaults.append((False, False, None))
else: else:
# No default, so this is a required input. Nothing to feed back, initial value is None. # No default, so this is a required input. Nothing to feed back, initial value is None.
...@@ -809,6 +872,7 @@ def function(inputs, outputs, mode=None, accept_inplace = False): ...@@ -809,6 +872,7 @@ def function(inputs, outputs, mode=None, accept_inplace = False):
""" """
mode = mode if mode is not None else mode_module.default_mode mode = mode if mode is not None else mode_module.default_mode
inputs = map(convert_function_input, inputs) inputs = map(convert_function_input, inputs)
if outputs is not None: if outputs is not None:
outputs = map(FunctionMaker.wrap_out, outputs) if isinstance(outputs, (list, tuple)) else FunctionMaker.wrap_out(outputs) outputs = map(FunctionMaker.wrap_out, outputs) if isinstance(outputs, (list, tuple)) else FunctionMaker.wrap_out(outputs)
...@@ -824,6 +888,7 @@ def function(inputs, outputs, mode=None, accept_inplace = False): ...@@ -824,6 +888,7 @@ def function(inputs, outputs, mode=None, accept_inplace = False):
else: else:
#return a different kind of function #return a different kind of function
def dup_defaults(): def dup_defaults():
# TODO This may need to be changed to use containers as defaults.
return [copy.copy(default.value) if isinstance(default, gof.Container) else return [copy.copy(default.value) if isinstance(default, gof.Container) else
copy.copy(default) copy.copy(default)
for default in defaults] for default in defaults]
......
"""Define `SymbolicInput`, `SymbolicOutput`, `In`, `Out` """ """Define `SymbolicInput`, `SymbolicOutput`, `In`, `Out` """
__docformat__ = 'restructuredtext en' __docformat__ = 'restructuredtext en'
from theano import gof
class SymbolicInput(object): class SymbolicInput(object):
""" """
Represents a symbolic input for use with function or FunctionMaker. Represents a symbolic input for use with function or FunctionMaker.
...@@ -27,9 +29,15 @@ class SymbolicInput(object): ...@@ -27,9 +29,15 @@ class SymbolicInput(object):
autoname: Bool (default: True) autoname: Bool (default: True)
See the name option. See the name option.
implicit: Bool (default: False)
See help(In). Note that 'None' is not allowed here, since we are in the
symbolic case.
""" """
def __init__(self, variable, name=None, update=None, mutable=None, strict=False, autoname=True): def __init__(self, variable, name=None, update=None, mutable=None, strict=False, autoname=True,
implicit=False):
assert implicit is not None # Safety check.
self.variable = variable self.variable = variable
self.name = variable.name if (autoname and name is None) else name self.name = variable.name if (autoname and name is None) else name
if self.name is not None and not isinstance(self.name, str): if self.name is not None and not isinstance(self.name, str):
...@@ -37,6 +45,7 @@ class SymbolicInput(object): ...@@ -37,6 +45,7 @@ class SymbolicInput(object):
self.update = update self.update = update
self.mutable = mutable if (mutable is not None) else (update is not None) self.mutable = mutable if (mutable is not None) else (update is not None)
self.strict = strict self.strict = strict
self.implicit = implicit
def __str__(self): def __str__(self):
if self.update: if self.update:
...@@ -136,10 +145,29 @@ class In(SymbolicInput): ...@@ -136,10 +145,29 @@ class In(SymbolicInput):
autoname: Bool (default: True) autoname: Bool (default: True)
See the name option. See the name option.
implicit: Bool or None (default: None)
True: This input is implicit in the sense that the user is not allowed
to provide a value for it. Requires 'value' to be set. Setting an
input as implicit allows Theano to directly share containers when
'value' is an existing container.
False: The user can provide a value for this input. In this case,
containers will not be shared (to avoid accidentally overwriting a
container's content with an input value provided by the user).
None: Automatically choose between True or False depending on the
situation. It will be set to False in all cases except if 'value'
is a container (so that it can be shared by default).
""" """
def __init__(self, variable, name=None, value=None, update=None, mutable=None, strict=False, autoname=True): def __init__(self, variable, name=None, value=None, update=None,
super(In, self).__init__(variable, name, update, mutable, strict, autoname) mutable=None, strict=False, autoname=True,
implicit=None):
if implicit is None:
implicit = isinstance(value, gof.Container)
super(In, self).__init__(variable, name, update, mutable, strict,
autoname, implicit = implicit)
self.value = value self.value = value
if self.implicit and value is None:
raise TypeError('An implicit input must be given a default value')
class SymbolicOutput(object): class SymbolicOutput(object):
......
...@@ -405,7 +405,8 @@ class Method(Component): ...@@ -405,7 +405,8 @@ class Method(Component):
variable=k, variable=k,
update=v, update=v,
value=get_storage(k, not allocate_all).value, value=get_storage(k, not allocate_all).value,
mutable=True) mutable=True,
implicit = True)
inputs.append(input_k) inputs.append(input_k)
else: else:
raise ValueError(('Variable listed in both inputs and updates.' raise ValueError(('Variable listed in both inputs and updates.'
...@@ -437,6 +438,13 @@ class Method(Component): ...@@ -437,6 +438,13 @@ class Method(Component):
assert storage.mutable == False assert storage.mutable == False
else: else:
storage = get_storage(input, not allocate_all) storage = get_storage(input, not allocate_all)
# Declare as an implicit input.
# TODO Note from OD: is this dangerous? (in case this storage
# is shared, and would sometimes need to be implicit, sometimes
# not).
storage.implicit = True
assert type(storage) is io.In assert type(storage) is io.In
inputs.append(storage) inputs.append(storage)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论