提交 165eb4e6 authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Merge pull request #2201 from nouiz/yaoli-pickle_theano_function-cp

Allow Pickle/Unpickle of Theano function without recompiling
...@@ -4,7 +4,6 @@ ...@@ -4,7 +4,6 @@
========================== ==========================
Frequently Asked Questions Frequently Asked Questions
========================== ==========================
TypeError: object of type 'TensorVariable' has no len() TypeError: object of type 'TensorVariable' has no len()
------------------------------------------------------- -------------------------------------------------------
...@@ -63,6 +62,13 @@ compilation but it will also use more memory because ...@@ -63,6 +62,13 @@ compilation but it will also use more memory because
``optimizer_excluding=inplace`` excludes inplace optimizations resulting ``optimizer_excluding=inplace`` excludes inplace optimizations resulting
in a trade off between speed of compilation and memory usage. in a trade off between speed of compilation and memory usage.
The Theano flag `reoptimize_unpickled_function` controls whether an unpickled theano
function should reoptimize its graph. Theano users can use the standard python pickle
tools to save a compiled theano function. When pickling, the graphs from both before and
after the optimization are saved, including shared variables. When the flag is True,
the graph is reoptimized when it is unpickled. Otherwise, the graph optimization is
skipped and the optimized graph from the pickled file is used directly.
Faster Theano function Faster Theano function
---------------------- ----------------------
......
...@@ -683,6 +683,16 @@ import theano and print the config variable, as in: ...@@ -683,6 +683,16 @@ import theano and print the config variable, as in:
optimization phase. Theano user's do not need to use this. This is optimization phase. Theano user's do not need to use this. This is
to help debug shape error in Theano optimization. to help debug shape error in Theano optimization.
.. attribute:: config.reoptimize_unpickled_function
Bool value, default: True
Theano users can use the standard python pickle tools to save a compiled
theano function. When pickling, both graph before and after the optimization
are saved, including shared variables. When set to True, the graph is
reoptimized when being unpickled. Otherwise, skip the graph optimization and
use directly the optimized graph.
.. attribute:: config.exception_verbosity .. attribute:: config.exception_verbosity
String Value: ``'low'``, ``'high'``. String Value: ``'low'``, ``'high'``.
......
...@@ -118,6 +118,7 @@ AddConfigVar('print_active_device', ...@@ -118,6 +118,7 @@ AddConfigVar('print_active_device',
BoolParam(True, allow_override=False), BoolParam(True, allow_override=False),
in_c_key=False) in_c_key=False)
# Do not add FAST_RUN_NOGC to this list (nor any other ALL CAPS shortcut). # Do not add FAST_RUN_NOGC to this list (nor any other ALL CAPS shortcut).
# The way to get FAST_RUN_NOGC is with the flag 'linker=c|py_nogc'. # The way to get FAST_RUN_NOGC is with the flag 'linker=c|py_nogc'.
# The old all capital letter way of working is deprecated as it is not # The old all capital letter way of working is deprecated as it is not
...@@ -465,6 +466,12 @@ AddConfigVar('unpickle_function', ...@@ -465,6 +466,12 @@ AddConfigVar('unpickle_function',
BoolParam(True), BoolParam(True),
in_c_key=False) in_c_key=False)
# Controls whether a theano function unpickled from disk re-optimizes its
# graph (True, the default) or reuses the optimized graph stored in the
# pickle as-is.
AddConfigVar('reoptimize_unpickled_function',
"Re-optimize the graph when a theano function is unpickled from the disk.",
BoolParam(True, allow_override=True),
in_c_key=False)
"""Note to developers: """Note to developers:
Generally your exceptions should use an apply node's __str__ Generally your exceptions should use an apply node's __str__
method when exception_verbosity == 'low'. When exception_verbosity method when exception_verbosity == 'low'. When exception_verbosity
...@@ -538,3 +545,11 @@ AddConfigVar('check_input', ...@@ -538,3 +545,11 @@ AddConfigVar('check_input',
"(particularly for scalars) and reduce the number of generated C " "(particularly for scalars) and reduce the number of generated C "
"files.", "files.",
BoolParam(True)) BoolParam(True))
# Experimental on-disk cache of optimized graphs (not functional yet).
# NOTE: the original adjacent string literals were concatenated without
# separating spaces, yielding a garbled help message ("yet.Specify",
# "willany", "lotthe", "bugs.Use"); fixed here, along with grammar.
AddConfigVar('cache_optimizations',
             "WARNING: work in progress, does not work yet. "
             "Specify if the optimization cache should be used. This cache "
             "will store any optimized graph and its optimization. It "
             "actually slows down the first optimization a lot, and could "
             "possibly still contain some bugs. Use at your own risk.",
             BoolParam(False))
...@@ -662,6 +662,7 @@ class DestroyHandler(toolbox.Bookkeeper): ...@@ -662,6 +662,7 @@ class DestroyHandler(toolbox.Bookkeeper):
The following data structures remain to be converted: The following data structures remain to be converted:
<unknown> <unknown>
""" """
pickle_rm_attr = ["destroyers"]
def __init__(self, do_imports_on_attach=True): def __init__(self, do_imports_on_attach=True):
self.fgraph = None self.fgraph = None
...@@ -720,15 +721,7 @@ class DestroyHandler(toolbox.Bookkeeper): ...@@ -720,15 +721,7 @@ class DestroyHandler(toolbox.Bookkeeper):
" or in conflict with another plugin.") " or in conflict with another plugin.")
####### Annotate the FunctionGraph ############ ####### Annotate the FunctionGraph ############
self.unpickle(fgraph)
def get_destroyers_of(r):
droot, impact, root_destroyer = self.refresh_droot_impact()
try:
return [root_destroyer[droot[r]]]
except Exception:
return []
fgraph.destroyers = get_destroyers_of
fgraph.destroy_handler = self fgraph.destroy_handler = self
self.fgraph = fgraph self.fgraph = fgraph
...@@ -743,6 +736,15 @@ class DestroyHandler(toolbox.Bookkeeper): ...@@ -743,6 +736,15 @@ class DestroyHandler(toolbox.Bookkeeper):
if self.do_imports_on_attach: if self.do_imports_on_attach:
toolbox.Bookkeeper.on_attach(self, fgraph) toolbox.Bookkeeper.on_attach(self, fgraph)
def unpickle(self, fgraph):
    """Re-install the ``destroyers`` closure on `fgraph`.

    The closure is stripped before pickling (closures are not picklable,
    see ``pickle_rm_attr``), so it must be rebuilt when the feature is
    re-attached or the fgraph is unpickled.
    """
    handler = self

    def destroyers_of(r):
        # Recompute the destroy maps, then look up the destroyer of the
        # destroy-root that variable ``r`` belongs to; absent entries
        # simply mean "no destroyer".
        droot, _impact, root_destroyer = handler.refresh_droot_impact()
        try:
            return [root_destroyer[droot[r]]]
        except Exception:
            return []

    fgraph.destroyers = destroyers_of
def refresh_droot_impact(self): def refresh_droot_impact(self):
""" """
Makes sure self.droot, self.impact, and self.root_destroyer are Makes sure self.droot, self.impact, and self.root_destroyer are
......
...@@ -87,6 +87,11 @@ class FunctionGraph(utils.object2): ...@@ -87,6 +87,11 @@ class FunctionGraph(utils.object2):
#TODO: document what variables are[not] set in the FunctionGraph when a feature #TODO: document what variables are[not] set in the FunctionGraph when a feature
is added via the constructor. How constructed is the FunctionGraph? is added via the constructor. How constructed is the FunctionGraph?
Note: the intermediate nodes between 'inputs' and 'outputs' are not explicitly
passed.

:param inputs: input nodes of the graph, usually declared by the user
:param outputs: output nodes of the graph.
:param clone: If true, we will clone the graph. This is :param clone: If true, we will clone the graph. This is
useful to remove the constant cache problem. useful to remove the constant cache problem.
...@@ -724,17 +729,42 @@ class FunctionGraph(utils.object2): ...@@ -724,17 +729,42 @@ class FunctionGraph(utils.object2):
return self.__str__() return self.__str__()
### clone ### ### clone ###
def clone(self): def clone(self, check_integrity=True):
"""WRITEME""" """WRITEME"""
return self.clone_get_equiv()[0] return self.clone_get_equiv(check_integrity)[0]
def clone_get_equiv(self): def clone_get_equiv(self, check_integrity=True):
"""WRITEME""" """WRITEME"""
equiv = graph.clone_get_equiv(self.inputs, self.outputs) equiv = graph.clone_get_equiv(self.inputs, self.outputs)
self.check_integrity() if check_integrity:
self.check_integrity()
e = FunctionGraph([equiv[i] for i in self.inputs], e = FunctionGraph([equiv[i] for i in self.inputs],
[equiv[o] for o in self.outputs]) [equiv[o] for o in self.outputs])
e.check_integrity() if check_integrity:
e.check_integrity()
for feature in self._features: for feature in self._features:
e.attach_feature(feature) e.attach_feature(feature)
return e, equiv return e, equiv
def __getstate__(self):
    """Return a picklable copy of ``self.__dict__``.

    Some attached features install bound methods or closures on the
    FunctionGraph (named in their ``pickle_rm_attr`` lists); those are
    not picklable, so they are dropped here and re-created by the
    features' ``unpickle`` methods in ``__setstate__``.
    """
    d = self.__dict__.copy()
    # Each feature lists, in pickle_rm_attr, the fgraph attributes it
    # installed that cannot be pickled (e.g. instancemethods).
    for feature in self._features:
        for attr in getattr(feature, "pickle_rm_attr", []):
            del d[attr]
    # execute_callbacks_times holds references to optimizers; those can
    # carry decorated (hence unpicklable) callables, so drop the whole
    # dict rather than pickle it.
    if "execute_callbacks_times" in d:
        del d["execute_callbacks_times"]
    return d
def __setstate__(self, dct):
    """Restore pickled state, then let each feature re-install the
    unpicklable attributes it removed in ``__getstate__``.
    """
    self.__dict__.update(dct)
    for feature in self._features:
        # Features that stripped attributes for pickling define
        # unpickle() to rebuild them (e.g. fgraph.destroyers,
        # fgraph.validate).
        if hasattr(feature, "unpickle"):
            feature.unpickle(self)
...@@ -878,6 +878,7 @@ def is_same_graph(var1, var2, givens=None, debug=False): ...@@ -878,6 +878,7 @@ def is_same_graph(var1, var2, givens=None, debug=False):
# Get result from the merge-based function. # Get result from the merge-based function.
rval1 = is_same_graph_with_merge(var1=var1, var2=var2, givens=givens) rval1 = is_same_graph_with_merge(var1=var1, var2=var2, givens=givens)
# Get result from the function `equal_computations` from scan_utils. # Get result from the function `equal_computations` from scan_utils.
use_equal_computations = True use_equal_computations = True
if givens: if givens:
# We need to build the `in_xs` and `in_ys` lists. To do this, we need # We need to build the `in_xs` and `in_ys` lists. To do this, we need
......
...@@ -22,8 +22,6 @@ import theano ...@@ -22,8 +22,6 @@ import theano
from theano import config from theano import config
from theano.gof.python25 import any, all, deque from theano.gof.python25 import any, all, deque
#if sys.version_info[:2] >= (2,5):
# from collections import defaultdict
_logger = logging.getLogger('theano.gof.opt') _logger = logging.getLogger('theano.gof.opt')
...@@ -1241,6 +1239,30 @@ class PatternSub(LocalOptimizer): ...@@ -1241,6 +1239,30 @@ class PatternSub(LocalOptimizer):
# Use the following classes to apply LocalOptimizers # Use the following classes to apply LocalOptimizers
class Updater:
    """Feature adapter forwarding fgraph events to optional callbacks.

    A module-level (hence picklable) replacement for the nested class
    previously built per-call inside NavigatorOptimizer.attach_updater.
    Any of the three callbacks may be None, in which case the
    corresponding event is ignored.
    """

    def __init__(self, importer, pruner, chin):
        self.importer = importer
        self.pruner = pruner
        self.chin = chin

    def on_import(self, fgraph, node, reason):
        cb = self.importer
        if cb:
            cb(node)

    def on_prune(self, fgraph, node, reason):
        cb = self.pruner
        if cb:
            cb(node)

    def on_change_input(self, fgraph, node, i, r, new_r, reason):
        cb = self.chin
        if cb:
            cb(node, i, r, new_r, reason)

    def on_detach(self, fgraph):
        # Drop the (possibly lambda) callbacks so that the feature, and
        # any fgraph still holding it, remains picklable after detach.
        self.importer = self.pruner = self.chin = None
class NavigatorOptimizer(Optimizer): class NavigatorOptimizer(Optimizer):
"""Abstract class """Abstract class
...@@ -1329,18 +1351,7 @@ class NavigatorOptimizer(Optimizer): ...@@ -1329,18 +1351,7 @@ class NavigatorOptimizer(Optimizer):
if importer is None and pruner is None: if importer is None and pruner is None:
return None return None
class Updater: u = Updater(importer, pruner, chin)
if importer is not None:
def on_import(self, fgraph, node, reason):
importer(node)
if pruner is not None:
def on_prune(self, fgraph, node, reason):
pruner(node)
if chin is not None:
def on_change_input(self, fgraph, node, i, r, new_r, reason):
chin(node, i, r, new_r, reason)
u = Updater()
fgraph.attach_feature(u) fgraph.attach_feature(u)
return u return u
......
import pickle
import unittest import unittest
import theano import theano
from theano.gof import CachedConstantError, FunctionGraph from theano.gof import CachedConstantError, FunctionGraph
from theano import tensor as tt
class TFunctionGraph(unittest.TestCase): class TFunctionGraph(unittest.TestCase):
...@@ -15,3 +17,10 @@ class TFunctionGraph(unittest.TestCase): ...@@ -15,3 +17,10 @@ class TFunctionGraph(unittest.TestCase):
v = theano.tensor.constant(1) v = theano.tensor.constant(1)
assert v.cached assert v.cached
FunctionGraph([], [v + 1]) FunctionGraph([], [v + 1])
def test_pickle(self):
    """A FunctionGraph must survive a pickle round-trip without error."""
    inp = tt.vector()
    fgraph = theano.gof.FunctionGraph([inp], [inp + 1])
    restored = pickle.loads(pickle.dumps(fgraph))
...@@ -104,7 +104,32 @@ class Bookkeeper(Feature): ...@@ -104,7 +104,32 @@ class Bookkeeper(Feature):
self.on_prune(fgraph, node, 'Bookkeeper.detach') self.on_prune(fgraph, node, 'Bookkeeper.detach')
class GetCheckpoint:
    """Picklable replacement for ``lambda: len(history[fgraph])``.

    Calling the instance returns the current length of the History
    feature's undo log for ``fgraph``; that length is the checkpoint
    token later passed to ``revert``.
    """

    def __init__(self, history, fgraph):
        self.h = history
        self.fgraph = fgraph

    def __call__(self):
        log = self.h.history[self.fgraph]
        return len(log)
class LambdExtract:
    """Picklable undo record used by the History feature.

    Replaces the per-event lambda: calling the instance re-applies the
    recorded ``change_input`` so the graph edit is reverted, with the
    reason tagged as ("Revert", original reason).
    """

    def __init__(self, fgraph, node, i, r, reason=None):
        self.fgraph = fgraph
        self.node = node
        self.i = i
        self.r = r
        self.reason = reason

    def __call__(self):
        return self.fgraph.change_input(
            self.node, self.i, self.r, reason=("Revert", self.reason))
class History(Feature): class History(Feature):
pickle_rm_attr = ["checkpoint", "revert"]
def __init__(self): def __init__(self):
self.history = {} self.history = {}
...@@ -114,7 +139,14 @@ class History(Feature): ...@@ -114,7 +139,14 @@ class History(Feature):
raise AlreadyThere("History feature is already present or in" raise AlreadyThere("History feature is already present or in"
" conflict with another plugin.") " conflict with another plugin.")
self.history[fgraph] = [] self.history[fgraph] = []
fgraph.checkpoint = lambda: len(self.history[fgraph]) # Don't call unpickle here, as ReplaceValidate.on_attach()
# call to History.on_attach() will call the
# ReplaceValidate.unpickle and not History.unpickle
fgraph.checkpoint = GetCheckpoint(self, fgraph)
fgraph.revert = partial(self.revert, fgraph)
def unpickle(self, fgraph):
    """Re-create the picklable checkpoint/revert entry points on
    `fgraph` after unpickling (they are stripped via pickle_rm_attr).
    """
    fgraph.checkpoint = GetCheckpoint(self, fgraph)
    fgraph.revert = partial(self.revert, fgraph)
def on_detach(self, fgraph): def on_detach(self, fgraph):
...@@ -126,8 +158,7 @@ class History(Feature): ...@@ -126,8 +158,7 @@ class History(Feature):
if self.history[fgraph] is None: if self.history[fgraph] is None:
return return
h = self.history[fgraph] h = self.history[fgraph]
h.append(lambda: fgraph.change_input(node, i, r, h.append(LambdExtract(fgraph, node, i, r, reason))
reason=("Revert", reason)))
def revert(self, fgraph, checkpoint): def revert(self, fgraph, checkpoint):
""" """
...@@ -144,47 +175,66 @@ class History(Feature): ...@@ -144,47 +175,66 @@ class History(Feature):
class Validator(Feature): class Validator(Feature):
pickle_rm_attr = ["validate", "consistent"]
def on_attach(self, fgraph): def on_attach(self, fgraph):
for attr in ('validate', 'validate_time'): for attr in ('validate', 'validate_time'):
if hasattr(fgraph, attr): if hasattr(fgraph, attr):
raise AlreadyThere("Validator feature is already present or in" raise AlreadyThere("Validator feature is already present or in"
" conflict with another plugin.") " conflict with another plugin.")
# Don't call unpickle here, as ReplaceValidate.on_attach()
# call to History.on_attach() will call the
# ReplaceValidate.unpickle and not History.unpickle
fgraph.validate = partial(self.validate_, fgraph)
fgraph.consistent = partial(self.consistent_, fgraph)
def validate(): def unpickle(self, fgraph):
t0 = time.time() fgraph.validate = partial(self.validate_, fgraph)
ret = fgraph.execute_callbacks('validate') fgraph.consistent = partial(self.consistent_, fgraph)
t1 = time.time()
if fgraph.profile:
fgraph.profile.validate_time += t1 - t0
return ret
fgraph.validate = validate
def consistent():
try:
fgraph.validate()
return True
except Exception:
return False
fgraph.consistent = consistent
def on_detach(self, fgraph): def on_detach(self, fgraph):
del fgraph.validate del fgraph.validate
del fgraph.consistent del fgraph.consistent
def validate_(self, fgraph):
    """Run all 'validate' callbacks on `fgraph` and return their result.

    When profiling is enabled (fgraph.profile is truthy), the elapsed
    wall-clock time is accumulated into fgraph.profile.validate_time.
    """
    start = time.time()
    result = fgraph.execute_callbacks('validate')
    elapsed = time.time() - start
    if fgraph.profile:
        fgraph.profile.validate_time += elapsed
    return result
def consistent_(self, fgraph):
    """Return True iff fgraph.validate() completes without raising."""
    try:
        fgraph.validate()
    except Exception:
        return False
    return True
class ReplaceValidate(History, Validator): class ReplaceValidate(History, Validator):
pickle_rm_attr = ["replace_validate", "replace_all_validate",
"replace_all_validate_remove"] + \
History.pickle_rm_attr + Validator.pickle_rm_attr
def on_attach(self, fgraph): def on_attach(self, fgraph):
History.on_attach(self, fgraph) for attr in ('replace_validate', 'replace_all_validate',
Validator.on_attach(self, fgraph) 'replace_all_validate_remove'):
for attr in ('replace_validate', 'replace_all_validate'):
if hasattr(fgraph, attr): if hasattr(fgraph, attr):
raise AlreadyThere("ReplaceValidate feature is already present" raise AlreadyThere("ReplaceValidate feature is already present"
" or in conflict with another plugin.") " or in conflict with another plugin.")
History.on_attach(self, fgraph)
Validator.on_attach(self, fgraph)
self.unpickle(fgraph)
def unpickle(self, fgraph):
History.unpickle(self, fgraph)
Validator.unpickle(self, fgraph)
fgraph.replace_validate = partial(self.replace_validate, fgraph) fgraph.replace_validate = partial(self.replace_validate, fgraph)
fgraph.replace_all_validate = partial(self.replace_all_validate, fgraph) fgraph.replace_all_validate = partial(self.replace_all_validate,
fgraph)
fgraph.replace_all_validate_remove = partial( fgraph.replace_all_validate_remove = partial(
self.replace_all_validate_remove, fgraph) self.replace_all_validate_remove, fgraph)
...@@ -247,6 +297,12 @@ class ReplaceValidate(History, Validator): ...@@ -247,6 +297,12 @@ class ReplaceValidate(History, Validator):
print >> out, reason, replacements print >> out, reason, replacements
raise ReplacementDidntRemovedError() raise ReplacementDidntRemovedError()
def __getstate__(self):
    """Pickle helper: exclude the per-fgraph ``history`` dict
    (FunctionGraph -> list of undo closures), which can hold
    unpicklable state.
    """
    d = self.__dict__.copy()
    # NOTE(review): after unpickling, `history` is absent until
    # re-populated (on_attach does self.history[fgraph] = []) — confirm
    # no unpickle path calls revert before that happens.
    if "history" in d:
        del d["history"]
    return d
class NodeFinder(Bookkeeper): class NodeFinder(Bookkeeper):
......
...@@ -694,7 +694,7 @@ class VM_Linker(link.LocalLinker): ...@@ -694,7 +694,7 @@ class VM_Linker(link.LocalLinker):
if k.owner and k.clients: if k.owner and k.clients:
ls = [] ls = []
for cl in k.clients: for cl in k.clients:
if cl[0] is not 'output': if cl[0] != 'output':
ls += cl[0].outputs ls += cl[0].outputs
dependencies[k] += ls dependencies[k] += ls
return dependencies return dependencies
......
...@@ -44,6 +44,8 @@ if MutableSet is not None: ...@@ -44,6 +44,8 @@ if MutableSet is not None:
import weakref import weakref
class Link(object): class Link(object):
# This means that we need to use a different pickle protocol
# than the default. Otherwise, there are pickling errors.
__slots__ = 'prev', 'next', 'key', '__weakref__' __slots__ = 'prev', 'next', 'key', '__weakref__'
def __getstate__(self): def __getstate__(self):
......
...@@ -1494,11 +1494,11 @@ class GemmOptimizer(Optimizer): ...@@ -1494,11 +1494,11 @@ class GemmOptimizer(Optimizer):
callbacks_before = fgraph.execute_callbacks_times.copy() callbacks_before = fgraph.execute_callbacks_times.copy()
callback_before = fgraph.execute_callbacks_time callback_before = fgraph.execute_callbacks_time
class Updater: def on_import(new_node):
def on_import(self, fgraph, new_node, reason): if new_node is not node:
if new_node is not node: nodelist.append(new_node)
nodelist.append(new_node)
u = Updater() u = theano.gof.opt.Updater(on_import, None, None)
fgraph.attach_feature(u) fgraph.attach_feature(u)
while did_something: while did_something:
nb_iter += 1 nb_iter += 1
......
...@@ -2664,6 +2664,7 @@ def local_useless_tile(node): ...@@ -2664,6 +2664,7 @@ def local_useless_tile(node):
except NotScalarConstantError: except NotScalarConstantError:
return return
################ ################
# Flatten Opts # # Flatten Opts #
################ ################
......
...@@ -53,10 +53,10 @@ def test_gc_never_pickles_temporaries(): ...@@ -53,10 +53,10 @@ def test_gc_never_pickles_temporaries():
len_pre_f = len(cPickle.dumps(f)) len_pre_f = len(cPickle.dumps(f))
len_pre_g = len(cPickle.dumps(g)) len_pre_g = len(cPickle.dumps(g))
# should be no difference at first # We can't compare the content or the length of the string
# In future, FunctionMaker might pickle linker-dependent stuff and make # between f and g. 2 reason, we store some timming information
# this assertion fail. # in float. They won't be the same each time. Different float
assert len_pre_f == len_pre_g # can have different lenght when printed.
def a(fn): def a(fn):
return len(cPickle.dumps(fn.maker)) return len(cPickle.dumps(fn.maker))
......
"""
This script tests the pickling and unpickling of theano functions.
When a compiled theano function has shared vars, their values are also pickled.

Side notes useful for debugging:
The pickling tools theano uses are here:
theano.compile.function_module._pickle_Function()
theano.compile.function_module._pickle_FunctionMaker()

Whether to reoptimize the pickled function graph is handled by
FunctionMaker.__init__()
The config option is in configdefaults.py

This note was written by Li Yao.
"""
import unittest
import numpy
import cPickle
from theano.compat.python2x import DictMixin, OrderedDict
floatX = 'float32'
import theano
import theano.tensor as T
def test_pickle_unpickle_with_reoptimization():
    """A pickled function must compute the same result after unpickling
    with reoptimize_unpickled_function=True (graph re-optimized on load).
    """
    mode = theano.config.mode
    if mode in ["DEBUG_MODE", "DebugMode"]:
        # DebugMode compilation is far too slow for this check.
        mode = "FAST_RUN"

    x1 = T.fmatrix('x1')
    x2 = T.fmatrix('x2')
    x3 = theano.shared(numpy.ones((10, 10), dtype=floatX))
    x4 = theano.shared(numpy.ones((10, 10), dtype=floatX))
    y = T.sum(T.sum(T.sum(x1 ** 2 + x2) + x3) + x4)

    updates = OrderedDict([(x3, x3 + 1), (x4, x4 + 1)])
    fn = theano.function([x1, x2], y, updates=updates, mode=mode)
    # Serialize the compiled function (shared-variable values included).
    pkl = cPickle.dumps(fn, -1)

    in1 = numpy.ones((10, 10), dtype=floatX)
    in2 = numpy.ones((10, 10), dtype=floatX)

    saved = theano.config.reoptimize_unpickled_function
    try:
        # True is the default; be explicit for the test.
        theano.config.reoptimize_unpickled_function = True
        fn2 = cPickle.loads(pkl)
        assert fn(in1, in2) == fn2(in1, in2)
    finally:
        theano.config.reoptimize_unpickled_function = saved
def test_pickle_unpickle_without_reoptimization():
    """Unpickling with reoptimize_unpickled_function=False must reuse the
    stored optimized graph and still compute the same result.
    """
    mode = theano.config.mode
    if mode in ["DEBUG_MODE", "DebugMode"]:
        # DebugMode compilation is far too slow for this check.
        mode = "FAST_RUN"

    x1 = T.fmatrix('x1')
    x2 = T.fmatrix('x2')
    x3 = theano.shared(numpy.ones((10, 10), dtype=floatX))
    x4 = theano.shared(numpy.ones((10, 10), dtype=floatX))
    y = T.sum(T.sum(T.sum(x1 ** 2 + x2) + x3) + x4)

    updates = OrderedDict([(x3, x3 + 1), (x4, x4 + 1)])
    fn = theano.function([x1, x2], y, updates=updates, mode=mode)
    # Serialize the compiled function (shared-variable values included).
    pkl = cPickle.dumps(fn, -1)

    in1 = numpy.ones((10, 10), dtype=floatX)
    in2 = numpy.ones((10, 10), dtype=floatX)

    saved = theano.config.reoptimize_unpickled_function
    try:
        # Skip graph re-optimization on load (the default is True).
        theano.config.reoptimize_unpickled_function = False
        fn2 = cPickle.loads(pkl)
        assert fn(in1, in2) == fn2(in1, in2)
    finally:
        theano.config.reoptimize_unpickled_function = saved
if __name__ == '__main__':
    # Allow running this test module directly, outside of a test runner.
    test_pickle_unpickle_with_reoptimization()
    test_pickle_unpickle_without_reoptimization()
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论