提交 edd6fdba authored 作者: Olivier Delalleau's avatar Olivier Delalleau

Changed behavior of non implicit inputs with shared container as default value

It is now possible to overwrite the shared container's content by providing a value as input. This is mostly to make the code simpler to write, understand, and use, at the cost of a small risk of accidentally overwriting a container's content.
上级 6cf9e0e8
...@@ -72,20 +72,14 @@ The ``inputs`` argument to ``theano.function`` is a list, containing the ``Varia ...@@ -72,20 +72,14 @@ The ``inputs`` argument to ``theano.function`` is a list, containing the ``Varia
``implicit``: Bool or ``None`` (default: ``None``) ``implicit``: Bool or ``None`` (default: ``None``)
``True``: This input is implicit in the sense that the user is not allowed ``True``: This input is implicit in the sense that the user is not allowed
to provide a value for it. Requires ``value`` to be set. Setting an to provide a value for it. Requires ``value`` to be set.
input as implicit allows Theano to directly share containers when ``False``: The user can provide a value for this input. Be careful
``value`` is an existing container. when ``value`` is a container, because providing an input value will
``False``: The user can provide a value for this input. In this case, overwrite the content of this container.
containers will not be shared (to avoid accidentally overwriting a
container's content with an input value provided by the user).
This means a function will create its own container, and will copy in
it the content of ``value`` at call time when ``value`` is a container.
Updates (if ``update`` is not ``None``) will be stored into the ``value``
container and not in the function container (which will be filled
with ``None`` instead, to make sure no one tries to use it by mistake).
``None``: Automatically choose between ``True`` or ``False`` depending on the ``None``: Automatically choose between ``True`` or ``False`` depending on the
situation. It will be set to ``False`` in all cases except if 'value' situation. It will be set to ``False`` in all cases except if
is a container (so that it can be shared by default). ``value`` is a container (so that there is less risk of accidentally
overwriting its content without being aware of it).
Value: initial and default values Value: initial and default values
...@@ -165,13 +159,9 @@ Note that when an input's ``value`` parameter is a shared container, this ...@@ -165,13 +159,9 @@ Note that when an input's ``value`` parameter is a shared container, this
input is considered as implicit by default. This means it cannot be set by the input is considered as implicit by default. This means it cannot be set by the
user. user.
If ``implicit`` is manually set to ``False``, then it can be set by the user, If ``implicit`` is manually set to ``False``, then it can be set by the user,
but the container will not directly be shared: instead, the content of the but then it will overwrite the container's content, so one should be careful
container will be copied at call time into a separate container. However, when allowing this.
updates will be performed in the container given by ``value``. The behavior This is illustrated in the following example.
in such situations may not be obvious, and thus it is advised not to set
``implicit`` to ``False`` when ``value`` is a shared container, unless you
understand what it means.
The following code illustrates this.
>>> dec(1, 0) # Try to manually set an implicit input >>> dec(1, 0) # Try to manually set an implicit input
<type 'exceptions.TypeError'>: Tried to provide value for implicit input: s <type 'exceptions.TypeError'>: Tried to provide value for implicit input: s
...@@ -181,16 +171,14 @@ None ...@@ -181,16 +171,14 @@ None
>>> inc[s] = 2 >>> inc[s] = 2
>>> dec(1) >>> dec(1)
[] []
>>> print inc[s] # Calling dec did decrease the value in inc's container >>> print inc[s] # Calling dec decreased the value in inc's container
1.0 1.0
>>> print dec[s] # Set back to None instead of storing the update
None
>>> dec(1, 0) # Update inc[s] with 0 - 1 = -1 >>> dec(1, 0) # Update inc[s] with 0 - 1 = -1
[] []
>>> print inc[s] >>> print inc[s]
-1.0 -1.0
>>> print dec[s] # Still set back to None >>> print dec[s] # Still shared.
None -1.0
Input Argument Restrictions Input Argument Restrictions
--------------------------- ---------------------------
......
...@@ -236,29 +236,14 @@ class Function(object): ...@@ -236,29 +236,14 @@ class Function(object):
if input.strict: if input.strict:
c.strict = True c.strict = True
# Whether the default value will be directly accessible within
# the function's container (c.copy_from_container = None), or
# if the function has its own container and thus needs to copy
# the default value at each call (c.copy_from_container =
# pointer towards it).
# Shared containers are only used for implicit inputs (so that
# there is no risk of overwriting their content with a user-
# provided value).
c.copy_from_container = None
if value is not None: if value is not None:
# Always initialize the storage. # Always initialize the storage.
if isinstance(value, gof.Container): if isinstance(value, gof.Container):
# There is no point in obtaining the current value # There is no point in obtaining the current value
# stored in the container, since: # stored in the container, since the container is
# - for an implicit input, the container is shared # shared.
# - for a non implicit input, the value may change # For safety, we make sure 'refeed' is False, since
# the function is called. # there is no need to refeed the default value.
if not input.implicit:
c.copy_from_container = value
else:
# Safety check: the container will be shared, so
# there should be no need to refeed the default
# value.
assert not refeed assert not refeed
else: else:
c.value = value c.value = value
...@@ -379,14 +364,6 @@ class Function(object): ...@@ -379,14 +364,6 @@ class Function(object):
# Check if inputs are missing, or if inputs were set more than once, or # Check if inputs are missing, or if inputs were set more than once, or
# if we tried to provide inputs that are supposed to be implicit. # if we tried to provide inputs that are supposed to be implicit.
# Also initialize default values that are obtained from an external
# container. This is required because this container's value may be
# modified between function calls.
# Other types of default values should not need to be re-initialized:
# - shared containers are updated automatically
# - default values defined directly by their value are re-fed into the
# input storage after a function call, and any modification possibly
# made to them (for mutable types) will be reflected there as well.
for c in self.input_storage: for c in self.input_storage:
if c.required and not c.provided: if c.required and not c.provided:
raise TypeError("Missing required input: %s" % getattr(self.inv_finder[c], 'variable', self.inv_finder[c])) raise TypeError("Missing required input: %s" % getattr(self.inv_finder[c], 'variable', self.inv_finder[c]))
...@@ -396,12 +373,6 @@ class Function(object): ...@@ -396,12 +373,6 @@ class Function(object):
raise TypeError('Tried to provide value for implicit input: %s' raise TypeError('Tried to provide value for implicit input: %s'
% getattr(self.inv_finder[c], 'variable', % getattr(self.inv_finder[c], 'variable',
self.inv_finder[c])) self.inv_finder[c]))
if c.provided == 0 and c.copy_from_container is not None:
# Copy default value from another (non shared) container.
# Safety check, may be removed in the future.
assert not c.implicit
c.value = c.copy_from_container.value
# TODO Would it be better to use self[..] = value?
# Do the actual work # Do the actual work
self.fn() self.fn()
...@@ -409,8 +380,8 @@ class Function(object): ...@@ -409,8 +380,8 @@ class Function(object):
# Retrieve the values that were computed # Retrieve the values that were computed
outputs = [x.data for x in self.output_storage] outputs = [x.data for x in self.output_storage]
#remove internal references to required inputs # Remove internal references to required inputs.
#these can't be re-used anyway # These cannot be re-used anyway.
for x in self.input_storage: for x in self.input_storage:
if c.required: if c.required:
c.storage[0] = None c.storage[0] = None
...@@ -428,18 +399,7 @@ class Function(object): ...@@ -428,18 +399,7 @@ class Function(object):
# Update the inputs that have an update function # Update the inputs that have an update function
for input, storage in reversed(zip(self.maker.expanded_inputs, self.input_storage)): for input, storage in reversed(zip(self.maker.expanded_inputs, self.input_storage)):
if input.update is not None: if input.update is not None:
# If the storage is getting its value from another container,
# we want to update that other container.
store_into = getattr(storage, 'copy_from_container', None)
if store_into is None:
storage.data = outputs.pop() storage.data = outputs.pop()
else:
store_into.data = outputs.pop()
# Also store None in the function's storage. This ensures
# no one tries to use it by mistake (since it simply mirrors
# the content of 'store_into', but may not always be in
# synch with it).
storage.data = None
# Put default values back in the storage # Put default values back in the storage
for i, (required, refeed, value) in enumerate(self.defaults): for i, (required, refeed, value) in enumerate(self.defaults):
...@@ -722,10 +682,9 @@ class FunctionMaker(object): ...@@ -722,10 +682,9 @@ class FunctionMaker(object):
input_storage_i = input_storage_i.container input_storage_i = input_storage_i.container
if isinstance(input_storage_i, gof.Container): if isinstance(input_storage_i, gof.Container):
# If the default is a gof.Container and it is an implicit # If the default is a gof.Container, this means we want to
# input, this means we want to share the same storage. This is # share the same storage. This is done by appending
# done by appending input_storage_i.storage to # input_storage_i.storage to input_storage_lists.
# input_storage_lists.
if indices is not None: if indices is not None:
raise TypeError("Cannot take a Container instance as default for a SymbolicInputKit.") raise TypeError("Cannot take a Container instance as default for a SymbolicInputKit.")
input_storage_lists.append(input_storage_i.storage) input_storage_lists.append(input_storage_i.storage)
......
...@@ -148,21 +148,17 @@ class In(SymbolicInput): ...@@ -148,21 +148,17 @@ class In(SymbolicInput):
implicit: Bool or None (default: None) implicit: Bool or None (default: None)
True: This input is implicit in the sense that the user is not allowed True: This input is implicit in the sense that the user is not allowed
to provide a value for it. Requires 'value' to be set. Setting an to provide a value for it. Requires 'value' to be set.
input as implicit allows Theano to directly share containers when False: The user can provide a value for this input. Be careful when
'value' is an existing container. 'value' is a container, because providing an input value will
False: The user can provide a value for this input. In this case, overwrite the content of this container.
containers will not be shared (to avoid accidentally overwriting a
container's content with an input value provided by the user). This
means a function will create its own container, and will copy in it
the content of 'value' at call time when 'value' is a container.
Updates (if 'update' is not None) will be stored into the 'value'
container and not in the function container (which will be filled
with None instead, to make sure no one tries to use it by mistake).
None: Automatically choose between True or False depending on the None: Automatically choose between True or False depending on the
situation. It will be set to False in all cases except if 'value' situation. It will be set to False in all cases except if 'value'
is a container (so that it can be shared by default). is a container (so that there is less risk of accidentally
overwriting its content without being aware of it).
""" """
# Note: the documentation above is duplicated in doc/topics/function.txt,
# try to keep it synchronized.
def __init__(self, variable, name=None, value=None, update=None, def __init__(self, variable, name=None, value=None, update=None,
mutable=None, strict=False, autoname=True, mutable=None, strict=False, autoname=True,
implicit=None): implicit=None):
......
...@@ -7,7 +7,7 @@ from theano.compile.function_module import * ...@@ -7,7 +7,7 @@ from theano.compile.function_module import *
from theano import tensor from theano import tensor
from theano import tensor as T from theano import tensor as T
import random import random, theano
import numpy as N import numpy as N
...@@ -265,15 +265,13 @@ class T_function(unittest.TestCase): ...@@ -265,15 +265,13 @@ class T_function(unittest.TestCase):
inc = function([x, In(s, update=(s+x), value=10.0)], []) inc = function([x, In(s, update=(s+x), value=10.0)], [])
dec = function([x, In(s, update=(s-x), value=inc.container[s], dec = function([x, In(s, update=(s-x), value=inc.container[s],
implicit = False)], []) implicit = False)], [])
self.failUnless(inc[s] is not dec[s]) self.failUnless(dec[s] is inc[s])
self.failUnless(dec[s] is None)
inc[s] = 2 inc[s] = 2
dec(1) dec(1)
self.failUnless(inc[s] == 1) self.failUnless(inc[s] == 1)
self.failUnless(dec[s] is None)
dec(1, 0) dec(1, 0)
self.failUnless(inc[s] == -1) self.failUnless(inc[s] == -1)
self.failUnless(dec[s] is None) self.failUnless(dec[s] == -1)
class T_picklefunction(unittest.TestCase): class T_picklefunction(unittest.TestCase):
...@@ -527,7 +525,7 @@ if __name__ == '__main__': ...@@ -527,7 +525,7 @@ if __name__ == '__main__':
if 1: if 1:
unittest.main() unittest.main()
else: elif 0:
testcases = [] testcases = []
testcases.append(T_function) testcases.append(T_function)
...@@ -538,3 +536,11 @@ if __name__ == '__main__': ...@@ -538,3 +536,11 @@ if __name__ == '__main__':
suite.addTest(testloader.loadTestsFromTestCase(testcase)) suite.addTest(testloader.loadTestsFromTestCase(testcase))
unittest.TextTestRunner(verbosity=2).run(suite) unittest.TextTestRunner(verbosity=2).run(suite)
#</boilerplate> #</boilerplate>
elif 0:
theano.compile.mode.default_mode = 'FAST_COMPILE'
t = T_picklefunction()
def fu(b):
assert b
t.failUnless = fu
t.test_deepcopy_shared_container()
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论