提交 63dc46bc authored 作者: Frédéric Bastien's avatar Frédéric Bastien 提交者: GitHub

Merge pull request #6275 from nouiz/function_pickle

Make sure to pickle all attribute of Function (and FunctionMaker)
...@@ -2189,7 +2189,8 @@ class _Maker(FunctionMaker): # inheritance buys a few helper functions ...@@ -2189,7 +2189,8 @@ class _Maker(FunctionMaker): # inheritance buys a few helper functions
profile=None, profile=None,
on_unused_input=None, on_unused_input=None,
fgraph=None, # If present the optimized graph. we ignore it. fgraph=None, # If present the optimized graph. we ignore it.
output_keys=None): output_keys=None,
name=None):
self.profile = profile self.profile = profile
optimizer = mode.optimizer optimizer = mode.optimizer
# Handle the case where inputs and/or outputs is a single # Handle the case where inputs and/or outputs is a single
...@@ -2320,6 +2321,7 @@ class _Maker(FunctionMaker): # inheritance buys a few helper functions ...@@ -2320,6 +2321,7 @@ class _Maker(FunctionMaker): # inheritance buys a few helper functions
self.mode = mode self.mode = mode
self.on_unused_input = on_unused_input # Used for the pickling/copy self.on_unused_input = on_unused_input # Used for the pickling/copy
self.output_keys = output_keys self.output_keys = output_keys
self.name = name
def create(self, defaults=None, trustme=False, storage_map=None): def create(self, defaults=None, trustme=False, storage_map=None):
""" """
...@@ -2406,7 +2408,8 @@ class _Maker(FunctionMaker): # inheritance buys a few helper functions ...@@ -2406,7 +2408,8 @@ class _Maker(FunctionMaker): # inheritance buys a few helper functions
storage_map=storage_map) storage_map=storage_map)
fn = self.function_builder(_fn, _i, _o, self.indices, fn = self.function_builder(_fn, _i, _o, self.indices,
self.outputs, defaults, self.unpack_single, self.outputs, defaults, self.unpack_single,
self.return_none, self.output_keys, self) self.return_none, self.output_keys, self,
name=self.name)
return fn return fn
......
...@@ -10,7 +10,6 @@ import traceback as tb ...@@ -10,7 +10,6 @@ import traceback as tb
import re import re
from six import string_types from six import string_types
from theano.compile.io import In
from theano.compile.function_module import orig_function from theano.compile.function_module import orig_function
from theano.compile.pfunc import pfunc from theano.compile.pfunc import pfunc
import warnings import warnings
...@@ -289,13 +288,6 @@ def function(inputs, outputs=None, mode=None, updates=None, givens=None, ...@@ -289,13 +288,6 @@ def function(inputs, outputs=None, mode=None, updates=None, givens=None,
uses_updates = bool(updates) uses_updates = bool(updates)
uses_givens = bool(givens) uses_givens = bool(givens)
# See if we have any mutable / borrow inputs
check_for_aliased_inputs = False
for i in inputs:
if (isinstance(i, In) and ((hasattr(i, 'borrow') and i.borrow) or
(hasattr(i, 'mutable') and i.mutable))):
check_for_aliased_inputs = True
if uses_tuple: if uses_tuple:
# we must use old semantics in this case. # we must use old semantics in this case.
if profile: if profile:
...@@ -323,7 +315,4 @@ def function(inputs, outputs=None, mode=None, updates=None, givens=None, ...@@ -323,7 +315,4 @@ def function(inputs, outputs=None, mode=None, updates=None, givens=None,
on_unused_input=on_unused_input, on_unused_input=on_unused_input,
profile=profile, profile=profile,
output_keys=output_keys) output_keys=output_keys)
# We need to add the flag check_aliased inputs if we have any mutable or
# borrowed used defined inputs
fn._check_for_aliased_inputs = check_for_aliased_inputs
return fn return fn
...@@ -359,7 +359,8 @@ class Function(object): ...@@ -359,7 +359,8 @@ class Function(object):
""" """
def __init__(self, fn, input_storage, output_storage, indices, outputs, def __init__(self, fn, input_storage, output_storage, indices, outputs,
defaults, unpack_single, return_none, output_keys, maker): defaults, unpack_single, return_none, output_keys, maker,
name=None):
self.fn = fn self.fn = fn
self.input_storage = input_storage self.input_storage = input_storage
self.output_storage = output_storage self.output_storage = output_storage
...@@ -371,10 +372,19 @@ class Function(object): ...@@ -371,10 +372,19 @@ class Function(object):
self.maker = maker self.maker = maker
self.profile = None # reassigned in FunctionMaker.create self.profile = None # reassigned in FunctionMaker.create
self.trust_input = False # If True, we don't check the input parameter self.trust_input = False # If True, we don't check the input parameter
self.name = None self.name = name
self.nodes_with_inner_function = [] self.nodes_with_inner_function = []
self.output_keys = output_keys self.output_keys = output_keys
# See if we have any mutable / borrow inputs
# TODO: this only need to be set if there is more then 1 input
self._check_for_aliased_inputs = False
for i in maker.inputs:
if (isinstance(i, In) and ((hasattr(i, 'borrow') and i.borrow) or
(hasattr(i, 'mutable') and i.mutable))):
self._check_for_aliased_inputs = True
break
# We will be popping stuff off this `containers` object. It is a copy. # We will be popping stuff off this `containers` object. It is a copy.
containers = list(self.input_storage) containers = list(self.input_storage)
finder = {} finder = {}
...@@ -821,6 +831,7 @@ class Function(object): ...@@ -821,6 +831,7 @@ class Function(object):
self[k] = arg self[k] = arg
if (not self.trust_input and if (not self.trust_input and
# The getattr is only needed for old pickle
getattr(self, '_check_for_aliased_inputs', True)): getattr(self, '_check_for_aliased_inputs', True)):
# Collect aliased inputs among the storage space # Collect aliased inputs among the storage space
args_share_memory = [] args_share_memory = []
...@@ -1047,19 +1058,25 @@ def _pickle_Function(f): ...@@ -1047,19 +1058,25 @@ def _pickle_Function(f):
(str(d_i), str(d_j))) (str(d_i), str(d_j)))
else: else:
raise AliasedMemoryError(d_i, d_j) raise AliasedMemoryError(d_i, d_j)
rval = (_constructor_Function, (f.maker, input_storage, inputs_data)) # The user can override trust_input. Our doc tell that. We should
# not do that anymore and make sure the Maker have all the
# information needed.
rval = (_constructor_Function,
(f.maker, input_storage, inputs_data, f.trust_input))
return rval return rval
def _constructor_Function(maker, input_storage, inputs_data): def _constructor_Function(maker, input_storage, inputs_data, trust_input=False):
if not theano.config.unpickle_function: if not theano.config.unpickle_function:
return None return None
f = maker.create(input_storage, trustme=True) f = maker.create(input_storage, trustme=True)
assert len(f.input_storage) == len(inputs_data) assert len(f.input_storage) == len(inputs_data)
for container, x in zip(f.input_storage, inputs_data): for container, x in zip(f.input_storage, inputs_data):
assert (container.data is x) or \ assert (container.data is x) or \
(isinstance(x, np.ndarray) and (container.data == x).all()) or \ (isinstance(x, np.ndarray) and (container.data == x).all()) or \
(container.data == x) (container.data == x)
f.trust_input = trust_input
return f return f
copyreg.pickle(Function, _pickle_Function) copyreg.pickle(Function, _pickle_Function)
...@@ -1185,6 +1202,9 @@ class FunctionMaker(object): ...@@ -1185,6 +1202,9 @@ class FunctionMaker(object):
- 'warn': log a warning - 'warn': log a warning
- 'ignore': do not do anything - 'ignore': do not do anything
- None: Use the value in the Theano flags on_unused_input. - None: Use the value in the Theano flags on_unused_input.
name : str
An optional name for this function. If used, the profile mode will
print the time spent in this function.
""" """
...@@ -1399,7 +1419,12 @@ class FunctionMaker(object): ...@@ -1399,7 +1419,12 @@ class FunctionMaker(object):
def __init__(self, inputs, outputs, def __init__(self, inputs, outputs,
mode=None, accept_inplace=False, function_builder=Function, mode=None, accept_inplace=False, function_builder=Function,
profile=None, on_unused_input=None, fgraph=None, profile=None, on_unused_input=None, fgraph=None,
output_keys=None): output_keys=None, name=None):
# Save the provided mode, not the instanciated mode.
# The instanciated mode don't pickle and if we unpickle a Theano
# function and it get re-compiled, we want the current optimizer to be
# used, not the optimizer when it was saved.
self.mode = mode
mode = theano.compile.mode.get_mode(mode) mode = theano.compile.mode.get_mode(mode)
# Assert old way of working isn't used # Assert old way of working isn't used
...@@ -1538,18 +1563,18 @@ class FunctionMaker(object): ...@@ -1538,18 +1563,18 @@ class FunctionMaker(object):
# hacky thing so VMLinker knows about updates # hacky thing so VMLinker knows about updates
self.linker.accept_var_updates( self.linker.accept_var_updates(
fgraph_updated_vars(fgraph, inputs)) fgraph_updated_vars(fgraph, inputs))
fgraph.name = name
self.indices = indices self.indices = indices
self.inputs = inputs self.inputs = inputs
self.expanded_inputs = inputs self.expanded_inputs = inputs
self.outputs = outputs self.outputs = outputs
self.unpack_single = unpack_single self.unpack_single = unpack_single
self.return_none = return_none self.return_none = return_none
self.mode = mode
self.accept_inplace = accept_inplace self.accept_inplace = accept_inplace
self.function_builder = function_builder self.function_builder = function_builder
self.on_unused_input = on_unused_input # Used for the pickling/copy self.on_unused_input = on_unused_input # Used for the pickling/copy
self.output_keys = output_keys self.output_keys = output_keys
self.name = name
self.required = [(i.value is None) for i in self.inputs] self.required = [(i.value is None) for i in self.inputs]
self.refeed = [ self.refeed = [
...@@ -1696,26 +1721,16 @@ class FunctionMaker(object): ...@@ -1696,26 +1721,16 @@ class FunctionMaker(object):
fn = self.function_builder(_fn, _i, _o, self.indices, self.outputs, fn = self.function_builder(_fn, _i, _o, self.indices, self.outputs,
defaults, self.unpack_single, defaults, self.unpack_single,
self.return_none, self.output_keys, self) self.return_none, self.output_keys, self,
name=self.name)
fn.profile = self.profile fn.profile = self.profile
return fn return fn
def _pickle_FunctionMaker(self):
kwargs = dict(
inputs=self.inputs,
outputs=self.orig_outputs,
fgraph=self.fgraph,
mode=self.mode,
accept_inplace=self.accept_inplace,
function_builder=self.function_builder,
profile=self.profile,
on_unused_input=self.on_unused_input)
return (_constructor_FunctionMaker, (kwargs,))
def _constructor_FunctionMaker(kwargs): def _constructor_FunctionMaker(kwargs):
# Needed for old pickle
# Old pickle have at least the problem that output_keys where not saved.
if theano.config.unpickle_function: if theano.config.unpickle_function:
if theano.config.reoptimize_unpickled_function: if theano.config.reoptimize_unpickled_function:
del kwargs['fgraph'] del kwargs['fgraph']
...@@ -1723,8 +1738,6 @@ def _constructor_FunctionMaker(kwargs): ...@@ -1723,8 +1738,6 @@ def _constructor_FunctionMaker(kwargs):
else: else:
return None return None
copyreg.pickle(FunctionMaker, _pickle_FunctionMaker)
__checkers = [] __checkers = []
...@@ -1814,7 +1827,8 @@ def orig_function(inputs, outputs, mode=None, accept_inplace=False, ...@@ -1814,7 +1827,8 @@ def orig_function(inputs, outputs, mode=None, accept_inplace=False,
accept_inplace=accept_inplace, accept_inplace=accept_inplace,
profile=profile, profile=profile,
on_unused_input=on_unused_input, on_unused_input=on_unused_input,
output_keys=output_keys) output_keys=output_keys,
name=name)
with theano.configparser.change_flags(compute_test_value="off"): with theano.configparser.change_flags(compute_test_value="off"):
fn = m.create(defaults) fn = m.create(defaults)
finally: finally:
...@@ -1824,8 +1838,6 @@ def orig_function(inputs, outputs, mode=None, accept_inplace=False, ...@@ -1824,8 +1838,6 @@ def orig_function(inputs, outputs, mode=None, accept_inplace=False,
# TODO: append # TODO: append
profile.nb_nodes = len(fn.maker.fgraph.apply_nodes) profile.nb_nodes = len(fn.maker.fgraph.apply_nodes)
fn.name = name
fn.maker.fgraph.name = name
return fn return fn
......
...@@ -291,7 +291,6 @@ class Mode(object): ...@@ -291,7 +291,6 @@ class Mode(object):
self._optimizer = optimizer self._optimizer = optimizer
self.call_time = 0 self.call_time = 0
self.fn_time = 0 self.fn_time = 0
linker.mode = self # TODO: WHY IS THIS HERE?
def __str__(self): def __str__(self):
return "%s(linker = %s, optimizer = %s)" % (self.__class__.__name__, return "%s(linker = %s, optimizer = %s)" % (self.__class__.__name__,
......
...@@ -590,8 +590,8 @@ class T_picklefunction(unittest.TestCase): ...@@ -590,8 +590,8 @@ class T_picklefunction(unittest.TestCase):
x, s = T.scalars('xs') x, s = T.scalars('xs')
f = function([x, In(a, value=1.0, name='a'), f = function([x, In(a, value=1.0, name='a'),
In(s, value=0.0, update=s + a * x, mutable=True)], s + a * x) In(s, value=0.0, update=s + a * x, mutable=True)],
s + a * x)
try: try:
g = copy.deepcopy(f) g = copy.deepcopy(f)
except NotImplementedError as e: except NotImplementedError as e:
...@@ -609,6 +609,9 @@ class T_picklefunction(unittest.TestCase): ...@@ -609,6 +609,9 @@ class T_picklefunction(unittest.TestCase):
self.assertFalse(x in g.container) self.assertFalse(x in g.container)
self.assertFalse(x in g.value) self.assertFalse(x in g.value)
self.assertTrue(len(f.defaults) == len(g.defaults)) self.assertTrue(len(f.defaults) == len(g.defaults))
self.assertTrue(f._check_for_aliased_inputs is g._check_for_aliased_inputs)
self.assertTrue(f.name == g.name)
self.assertTrue(f.maker.fgraph.name == f.maker.fgraph.name)
# print 'f.defaults = %s' % (f.defaults, ) # print 'f.defaults = %s' % (f.defaults, )
# print 'g.defaults = %s' % (g.defaults, ) # print 'g.defaults = %s' % (g.defaults, )
self.assertTrue(all([f_req == g_req and f_feed == g_feed and self.assertTrue(all([f_req == g_req and f_feed == g_feed and
...@@ -627,6 +630,34 @@ class T_picklefunction(unittest.TestCase): ...@@ -627,6 +630,34 @@ class T_picklefunction(unittest.TestCase):
g(1, 2) # put them back in sync g(1, 2) # put them back in sync
self.assertTrue(f(3) == g(3)) # They should be in sync again. self.assertTrue(f(3) == g(3)) # They should be in sync again.
def test_deepcopy_trust_input(self):
a = T.dscalar() # the a is for 'anonymous' (un-named).
x, s = T.dscalars('xs')
f = function([x, In(a, value=1.0, name='a'),
In(s, value=0.0, update=s + a * x, mutable=True)],
s + a * x)
f.trust_input = True
try:
g = copy.deepcopy(f)
except NotImplementedError as e:
if e[0].startswith('DebugMode is not picklable'):
return
else:
raise
self.assertTrue(f.trust_input is g.trust_input)
f(np.asarray(2.))
self.assertRaises((ValueError, AttributeError), f, 2.)
g(np.asarray(2.))
self.assertRaises((ValueError, AttributeError), g, 2.)
def test_output_keys(self):
x = T.vector()
f = theano.function([x], {'vec': x**2})
assert isinstance(f([2, 3, 4]), dict)
g = copy.deepcopy(f)
assert isinstance(g([2, 3, 4]), dict)
def test_deepcopy_shared_container(self): def test_deepcopy_shared_container(self):
# Ensure that shared containers remain shared after a deep copy. # Ensure that shared containers remain shared after a deep copy.
a, x = T.scalars('ax') a, x = T.scalars('ax')
......
...@@ -201,7 +201,7 @@ def deprecated_gpuarray_sync(val): ...@@ -201,7 +201,7 @@ def deprecated_gpuarray_sync(val):
AddConfigVar('gpuarray.sync', AddConfigVar('gpuarray.sync',
"""This flag is deprecated and will be removed in next Theano release.""", """This flag is deprecated and will be removed in next Theano release.""",
ConfigParam(False, allow_override=False, filter=deprecated_gpuarray_sync), ConfigParam(False, allow_override=False, filter=deprecated_gpuarray_sync),
in_c_key=True) in_c_key=False)
AddConfigVar('gpuarray.preallocate', AddConfigVar('gpuarray.preallocate',
"""If negative it disables the allocation cache. If """If negative it disables the allocation cache. If
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论