提交 165eb4e6 authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Merge pull request #2201 from nouiz/yaoli-pickle_theano_function-cp

Allow Pickle/Unpickle of Theano function without recompiling
......@@ -4,7 +4,6 @@
==========================
Frequently Asked Questions
==========================
TypeError: object of type 'TensorVariable' has no len()
-------------------------------------------------------
......@@ -63,6 +62,13 @@ compilation but it will also use more memory because
``optimizer_excluding=inplace`` excludes inplace optimizations resulting
in a trade off between speed of compilation and memory usage.
Theano flag `reoptimize_unpickled_function` controls whether an unpickled theano
function should reoptimize its graph. Theano users can use the standard python pickle
tools to save a compiled theano function. When pickling, both the graph before and
the graph after the optimization are saved, including shared variables. When set to True,
the graph is reoptimized when it is unpickled. Otherwise, the graph optimization is skipped
and the optimized graph from the pickled file is used directly.
Faster Theano function
----------------------
......
......@@ -683,6 +683,16 @@ import theano and print the config variable, as in:
optimization phase. Theano users do not need to use this. It exists
to help debug shape errors in Theano optimizations.
.. attribute:: config.reoptimize_unpickled_function
Bool value, default: True
Theano users can use the standard python pickle tools to save a compiled
theano function. When pickling, both the graph before and the graph after
the optimization are saved, including shared variables. When set to True, the
graph is reoptimized when it is unpickled. Otherwise, the graph optimization
is skipped and the stored optimized graph is used directly.
.. attribute:: config.exception_verbosity
String Value: ``'low'``, ``'high'``.
......
......@@ -19,10 +19,9 @@ import theano.compile.mode
from theano.compile.io import (
In, SymbolicInput, SymbolicInputKit, SymbolicOutput)
from theano.compile.ops import deep_copy_op, view_op
from theano.gof.graph import is_same_graph
from theano.gof.op import ops_with_inner_function
import logging
_logger = logging.getLogger('theano.compile.function_module')
......@@ -737,7 +736,6 @@ def _pickle_Function(f):
' operation') %(str(d_i), str(d_j)))
else:
raise AliasedMemoryError(d_i, d_j)
rval = (_constructor_Function, (f.maker, input_storage, inputs_data))
return rval
......@@ -970,10 +968,180 @@ class FunctionMaker(object):
return SymbolicOutput(output)
else:
raise TypeError("Unknown output type: %s (%s)", type(output), output)
def retrieve_fgraph_from_opt_cache():
    # NOTE(review): this helper references names that are not defined in it
    # (`inputs`, `outputs`, `fgraph`, `self`, `optimizer`, `cPickle`, `gof`,
    # `time`, `is_same_graph`) -- presumably it was extracted from, or is
    # nested inside, FunctionMaker.__init__ and relies on that enclosing
    # scope. Confirm before calling it standalone.
    # This function is not finished
    # Everything below the raise is dead code, kept as a sketch of the
    # planned on-disk optimization cache.
    raise NotImplementedError('optimization cache is not finished! Should not be called.')
    from theano.gof.compilelock import get_lock, release_lock
    import os.path
    # One shared pickle file per compiledir holds the whole cache.
    graph_db_file = os.path.join(theano.config.compiledir, 'optimized_graphs.pkl')
    # the inputs, outputs, and size of the graph to be optimized
    inputs_new = [inp.variable for inp in inputs]
    outputs_new = [out.variable for out in outputs]
    size_new = len(fgraph.apply_nodes)
    need_optimize = False
    # Serialize access to the cache file across processes.
    get_lock()
    key = None
    #Beginning of cache optimizations.
    #Could be refactored in different functions.
    if theano.config.cache_optimizations: #set to false by default
        '''
        graph_db and need_optimize
        '''
        if os.path.isfile(graph_db_file):
            print 'graph_db exists'
        else:
            # create graph_db
            f = open(graph_db_file, 'wb')
            print 'created new graph_db %s' % graph_db_file
            #file needs to be open and closed for every pickle
            f.close()
        # load the graph_db dictionary
        try:
            f = open(graph_db_file, 'rb')
            #Temporary hack to allow theano.scan_module.tests.test_scan.T_Scan
            #to finish. Should be changed in definitive version.
            tmp = theano.config.unpickle_function
            theano.config.unpickle_function = False
            graph_db = cPickle.load(f)
            theano.config.unpickle_function = tmp
            #hack end
            f.close()
            print 'graph_db is not empty'
        except EOFError, e:
            # the file has nothing in it
            print e
            print 'graph_db is empty'
            graph_db = {}
            need_optimize = True
        print 'loaded graph_db from %s, size=%d' % (graph_db_file, len(graph_db))
        # the sole purpose of this loop is to set 'need_optimize'
        for i, graph_old in enumerate(graph_db.keys()):
            inputs_old = graph_old.inputs
            outputs_old = graph_old.outputs
            size_old = len(graph_old.apply_nodes)
            print 'looping through graph_db %d/%d' % (i + 1, len(graph_db))
            # Some heuristics to check is the same graphs have
            # already been optimized before.
            # Cheap structural checks first; full graph comparison only
            # when every cheap check passes.
            if len(inputs_new) != len(inputs_old):
                # If the inputs are of different size,
                # two graphs are for sure different
                print 'need to optimize, because input size is different'
                continue
            elif len(outputs_new) != len(outputs_old):
                # If the inputs are of different size,
                # two graphs are for sure different
                print 'need to optimize, because output size is different'
                continue
            elif not all(input_new.type == input_old.type for
                         input_new, input_old in zip(inputs_new, inputs_old)):
                print 'need to optimize, because inputs are of different types'
                continue
            elif not all(output_new.type == output_old.type for
                         output_new, output_old in zip(outputs_new, outputs_old)):
                print 'need to optimize, because outputs are of different types'
                continue
            elif not size_old == size_new:
                print 'need to optimize, because numbers of nodes in graph are different'
                continue
            else:
                flags = []
                # NOTE(review): `i` here rebinds the enclosing graph_db loop
                # index; harmless only because of the `break` below.
                for output_new, output_old, i in zip(outputs_new, outputs_old, range(len(outputs_new))):
                    print 'loop through outputs node for both graphs'
                    graph_old.variables = set(gof.graph.variables(graph_old.inputs, graph_old.outputs))
                    #using clone allowed to avoid a lot of errors
                    #deep copy seemed to had.
                    f2 = graph_old.clone(check_integrity=False)
                    t1 = output_new
                    t2 = f2.outputs[i]
                    #Used to remove "already used by another graph error
                    def removeAllFgraph(remove):
                        # Recursively strip `fgraph` back-references from a
                        # variable and everything upstream of it, so the
                        # subgraph can be compared outside its FunctionGraph.
                        if hasattr(remove, 'fgraph'):
                            del remove.fgraph
                        if hasattr(remove, 'owner'):
                            if remove.owner == None:
                                pass
                            else:
                                if hasattr(remove.owner, 'fgraph'):
                                    del remove.owner.fgraph
                                if hasattr(remove.owner, 'inputs'):
                                    remove.owner.inputs = [removeAllFgraph(
                                        i) for i in remove.owner.inputs]
                                for o in remove.owner.outputs:
                                    if hasattr(o, 'fgraph'):
                                        del o.fgraph
                        return remove
                    t2 = removeAllFgraph(t2)
                    givens = dict(zip(gof.graph.inputs([t1]),
                                      gof.graph.inputs([t2])))
                    temp = dict(zip(gof.graph.inputs([t1]),
                                    gof.graph.inputs([t2])))
                    #hack to remove inconstent entry in givens
                    #seems to work that but source of inconsistency
                    #could be worth investigating.
                    # NOTE(review): `key` here shadows the outer cache key;
                    # it only works because the match branch below reassigns
                    # `key` before it is read.
                    for key, value in temp.iteritems():
                        if key.type != value.type:
                            del givens[key]
                    flag = is_same_graph(t1, t2, givens=givens)
                    flags.append(flag)
                is_same = all(flags)
                if is_same:
                    # found the match
                    # NOTE(review): garbled message, presumably "found the
                    # match, no need to optimize".
                    print 'found #TODO: he match, no need to optimize'
                    need_optimize = False
                    key = graph_old
                    break
        if need_optimize:
            # this is a brand new graph, optimize it, save it to graph_db
            print 'optimizing the graph'
            fgraph.variables = set(gof.graph.variables(fgraph.inputs, fgraph.outputs))
            #check_integrity parameters was added to ignore
            #"excess cached variables" errors. Works that way
            #but once again the error couldbe worth
            #investigating.
            before_opt = fgraph.clone(check_integrity=False)
            start_optimizer = time.time()
            optimizer_profile = optimizer(fgraph)
            end_optimizer = time.time()
            opt_time = end_optimizer - start_optimizer
            # Key the cache on the *unoptimized* clone; value is the
            # optimized fgraph.
            graph_db.update({before_opt:fgraph})
            f = open(graph_db_file, 'wb')
            cPickle.dump(graph_db, f, -1)
            f.close()
            print 'saved into graph_db'
        else:
            print 'no opt, get graph from graph_db'
            # just read the optmized graph from graph_db
            opt_time = 0
            #"Naive" insertion. It's seems to work, but there may
            #be some problems inserting it like that.
            # NOTE(review): if the cache is non-empty but no graph matched,
            # need_optimize is still False and key is still None, so
            # graph_db[None] would raise KeyError here.
            self.fgraph = graph_db[key]
            fgraph = self.fgraph
    # release stuff
    release_lock()
def __init__(self, inputs, outputs,
mode=None, accept_inplace=False, function_builder=Function,
profile=None, on_unused_input=None):
profile=None, on_unused_input=None, fgraph=None):
"""
:type inputs: a list of SymbolicInput instances
......@@ -1040,7 +1208,6 @@ class FunctionMaker(object):
inputs = [inputs]
# Wrap them in In or Out instances if needed.
#import pudb; pudb.set_trace()
inputs, outputs = map(self.wrap_in, inputs), map(self.wrap_out, outputs)
_inputs = gof.graph.inputs([o.variable for o in outputs] + [i.update
for i in inputs if getattr(i, 'update', False)])
......@@ -1052,37 +1219,44 @@ class FunctionMaker(object):
# tuple for each input. (See Function.indices for more details)
indices = [[input] + self.expand_in(input, _inputs) for input in inputs]
# make the fgraph (copies the graph, creates NEW INPUT AND OUTPUT VARIABLES)
fgraph, additional_outputs = std_fgraph(inputs, outputs, accept_inplace)
fgraph.profile = profile
if fgraph is None:
need_opt = True
# make the fgraph (copies the graph, creates NEW INPUT AND OUTPUT VARIABLES)
fgraph, additional_outputs = std_fgraph(inputs, outputs, accept_inplace)
fgraph.profile = profile
else:
# fgraph is already an optimized one
need_opt = False
_, additional_outputs = std_fgraph(inputs, outputs, accept_inplace)
pass
self.fgraph = fgraph
# Fetch the optimizer and linker
optimizer, linker = mode.optimizer, copy.copy(mode.linker)
# optimize the fgraph
compute_test_value_orig = theano.config.compute_test_value
add_stack_trace_on_call = gof.Op.add_stack_trace_on_call
try:
theano.config.compute_test_value = theano.config.compute_test_value_opt
gof.Op.add_stack_trace_on_call = False
start_optimizer = time.time()
optimizer_profile = optimizer(fgraph)
end_optimizer = time.time()
opt_time = end_optimizer - start_optimizer
if profile:
profile.optimizer_time += opt_time
if theano.config.profile_optimizer:
profile.optimizer_profile = (optimizer, optimizer_profile)
_logger.debug('Optimizing took %f seconds', opt_time)
#Add deep copy to respect the memory interface
insert_deepcopy(fgraph, inputs, outputs + additional_outputs)
finally:
theano.config.compute_test_value = compute_test_value_orig
gof.Op.add_stack_trace_on_call = add_stack_trace_on_call
if need_opt:
compute_test_value_orig = theano.config.compute_test_value
add_stack_trace_on_call_orig = gof.Op.add_stack_trace_on_call
try:
# optimize the fgraph
theano.config.compute_test_value = theano.config.compute_test_value_opt
gof.Op.add_stack_trace_on_call = False
start_optimizer = time.time()
optimizer_profile = optimizer(fgraph)
end_optimizer = time.time()
opt_time = end_optimizer - start_optimizer
if profile:
profile.optimizer_time += opt_time
if theano.config.profile_optimizer:
profile.optimizer_profile = (optimizer, optimizer_profile)
_logger.debug('Optimizing took %f seconds', opt_time)
#Add deep copy to respect the memory interface
insert_deepcopy(fgraph, inputs, outputs + additional_outputs)
finally:
theano.config.compute_test_value = compute_test_value_orig
gof.Op.add_stack_trace_on_call = add_stack_trace_on_call_orig
# initialize the linker
if not hasattr(linker, 'accept'):
raise ValueError("'linker' parameter of FunctionMaker should be a Linker with an accept method " \
......@@ -1245,6 +1419,7 @@ def _pickle_FunctionMaker(self):
kwargs = dict(
inputs=self.inputs,
outputs=self.orig_outputs,
fgraph=self.fgraph,
mode=self.mode,
accept_inplace=self.accept_inplace,
function_builder=self.function_builder,
......@@ -1256,6 +1431,8 @@ def _pickle_FunctionMaker(self):
def _constructor_FunctionMaker(kwargs):
    """Rebuild a FunctionMaker from the kwargs stored by _pickle_FunctionMaker.

    Returns None (i.e. skips reconstruction entirely) when the
    `unpickle_function` config flag is off.
    """
    if not theano.config.unpickle_function:
        return None
    if theano.config.reoptimize_unpickled_function:
        # Drop the stored pre-optimized graph so FunctionMaker.__init__
        # rebuilds and re-optimizes it from scratch. pop() with a default
        # tolerates pickles produced before 'fgraph' was stored.
        kwargs.pop('fgraph', None)
    return FunctionMaker(**kwargs)
......
......@@ -118,6 +118,7 @@ AddConfigVar('print_active_device',
BoolParam(True, allow_override=False),
in_c_key=False)
# Do not add FAST_RUN_NOGC to this list (nor any other ALL CAPS shortcut).
# The way to get FAST_RUN_NOGC is with the flag 'linker=c|py_nogc'.
# The old all capital letter way of working is deprecated as it is not
......@@ -465,6 +466,12 @@ AddConfigVar('unpickle_function',
BoolParam(True),
in_c_key=False)
# When this flag is False, an unpickled theano function reuses the
# optimized graph stored in the pickle instead of re-running the graph
# optimizations (faster load, same compiled behavior).
AddConfigVar('reoptimize_unpickled_function',
             "Re-optimize the graph when a theano function is unpickled from the disk.",
             BoolParam(True, allow_override=True),
             in_c_key=False)
"""Note to developers:
Generally your exceptions should use an apply node's __str__
method when exception_verbosity == 'low'. When exception_verbosity
......@@ -538,3 +545,11 @@ AddConfigVar('check_input',
"(particularly for scalars) and reduce the number of generated C "
"files.",
BoolParam(True))
# Experimental: on-disk cache of optimized graphs, keyed by the
# unoptimized graph. See FunctionMaker for the (unfinished) user.
# Fix: the original help string was built from adjacent literals with no
# separating spaces ("yet.Specify", "willany", "lotthe", "bugs.Use") and
# was missing the verb "contain".
AddConfigVar('cache_optimizations',
             "WARNING: work in progress, does not work yet. "
             "Specify if the optimization cache should be used. This cache "
             "will contain any optimized graph and its optimization. "
             "Actually slows down a lot the first optimization, and could "
             "possibly still contain some bugs. Use at your own risks.",
             BoolParam(False))
......@@ -662,6 +662,7 @@ class DestroyHandler(toolbox.Bookkeeper):
The following data structures remain to be converted:
<unknown>
"""
pickle_rm_attr = ["destroyers"]
def __init__(self, do_imports_on_attach=True):
self.fgraph = None
......@@ -720,15 +721,7 @@ class DestroyHandler(toolbox.Bookkeeper):
" or in conflict with another plugin.")
####### Annotate the FunctionGraph ############
def get_destroyers_of(r):
droot, impact, root_destroyer = self.refresh_droot_impact()
try:
return [root_destroyer[droot[r]]]
except Exception:
return []
fgraph.destroyers = get_destroyers_of
self.unpickle(fgraph)
fgraph.destroy_handler = self
self.fgraph = fgraph
......@@ -743,6 +736,15 @@ class DestroyHandler(toolbox.Bookkeeper):
if self.do_imports_on_attach:
toolbox.Bookkeeper.on_attach(self, fgraph)
def unpickle(self, fgraph):
    """Reinstall ``fgraph.destroyers`` after unpickling.

    ``destroyers`` is a closure and therefore unpicklable; it is listed
    in ``pickle_rm_attr`` so it gets stripped before pickling, and this
    hook recreates it on the restored graph.
    """
    handler = self

    def get_destroyers_of(r):
        # Best-effort lookup: any failure is treated as "no known
        # destroyer" rather than an error.
        droot, impact, root_destroyer = handler.refresh_droot_impact()
        try:
            return [root_destroyer[droot[r]]]
        except Exception:
            return []

    fgraph.destroyers = get_destroyers_of
def refresh_droot_impact(self):
"""
Makes sure self.droot, self.impact, and self.root_destroyer are
......
......@@ -87,6 +87,11 @@ class FunctionGraph(utils.object2):
#TODO: document what variables are[not] set in the FunctionGraph when a feature
is added via the constructor. How constructed is the FunctionGraph?
Note: the intermediate nodes between 'inputs' and 'outputs' are not explicitly
passed.
:param inputs: inputs nodes of the graph, usually declared by the user
:param outputs: outputs nodes of the graph.
:param clone: If true, we will clone the graph. This is
useful to remove the constant cache problem.
......@@ -724,17 +729,42 @@ class FunctionGraph(utils.object2):
return self.__str__()
### clone ###
def clone(self, check_integrity=True):
    """Return a copy of this FunctionGraph.

    :param check_integrity: forwarded to clone_get_equiv(); pass False
        to skip the integrity checks (needed when cloning graphs that
        carry extra cached variables).
    """
    # Fix: the pasted diff left both the old and the new `def`/`return`
    # lines in place; this is the post-merge version only.
    return self.clone_get_equiv(check_integrity)[0]
def clone_get_equiv(self, check_integrity=True):
    """Clone the graph and return ``(clone, equiv)``.

    ``equiv`` maps each original variable to its counterpart in the
    cloned graph.

    :param check_integrity: when True, run check_integrity() on both the
        original and the cloned graph. Skipping it is useful when the
        graph is known to carry extra cached variables.
    """
    # Fix: the pasted diff kept both the unconditional check_integrity()
    # calls and the new guarded ones; this is the post-merge version only.
    equiv = graph.clone_get_equiv(self.inputs, self.outputs)
    if check_integrity:
        self.check_integrity()
    e = FunctionGraph([equiv[i] for i in self.inputs],
                      [equiv[o] for o in self.outputs])
    if check_integrity:
        e.check_integrity()
    for feature in self._features:
        e.attach_feature(feature)
    return e, equiv
def __getstate__(self):
"""This is needed as some feature introduce instancemethod and
this is not pickable.
"""
d = self.__dict__.copy()
for feature in self._features:
for attr in getattr(feature, "pickle_rm_attr", []):
del d[attr]
# The class Updater take fct as parameter and they are lambda function, so unpicklable.
# execute_callbacks_times have reference to optimizer, and they can't
# be pickled as the decorators with parameters aren't pickable.
if "execute_callbacks_times" in d:
del d["execute_callbacks_times"]
return d
def __setstate__(self, dct):
    """Restore pickled state, then let each feature reinstall the
    unpicklable attributes it stripped in ``__getstate__``."""
    self.__dict__.update(dct)
    for feature in self._features:
        unpickle_hook = getattr(feature, "unpickle", None)
        if unpickle_hook is not None:
            unpickle_hook(self)
......@@ -878,6 +878,7 @@ def is_same_graph(var1, var2, givens=None, debug=False):
# Get result from the merge-based function.
rval1 = is_same_graph_with_merge(var1=var1, var2=var2, givens=givens)
# Get result from the function `equal_computations` from scan_utils.
use_equal_computations = True
if givens:
# We need to build the `in_xs` and `in_ys` lists. To do this, we need
......
......@@ -22,8 +22,6 @@ import theano
from theano import config
from theano.gof.python25 import any, all, deque
#if sys.version_info[:2] >= (2,5):
# from collections import defaultdict
_logger = logging.getLogger('theano.gof.opt')
......@@ -1241,6 +1239,30 @@ class PatternSub(LocalOptimizer):
# Use the following classes to apply LocalOptimizers
class Updater:
    """Fgraph feature that forwards import/prune/change-input events to
    optional callbacks.

    Any callback may be None, in which case the matching event is simply
    ignored. Defined at module level (instead of as an inner class
    capturing closures) so that instances remain picklable; ``on_detach``
    clears the callbacks for the same reason.
    """

    def __init__(self, importer, pruner, chin):
        self.importer = importer
        self.pruner = pruner
        self.chin = chin

    def on_import(self, fgraph, node, reason):
        if not self.importer:
            return
        self.importer(node)

    def on_prune(self, fgraph, node, reason):
        if not self.pruner:
            return
        self.pruner(node)

    def on_change_input(self, fgraph, node, i, r, new_r, reason):
        if not self.chin:
            return
        self.chin(node, i, r, new_r, reason)

    def on_detach(self, fgraph):
        # Drop the (possibly unpicklable) callbacks so a detached
        # Updater can be pickled with the graph.
        self.importer = None
        self.pruner = None
        self.chin = None
class NavigatorOptimizer(Optimizer):
"""Abstract class
......@@ -1329,18 +1351,7 @@ class NavigatorOptimizer(Optimizer):
if importer is None and pruner is None:
return None
class Updater:
if importer is not None:
def on_import(self, fgraph, node, reason):
importer(node)
if pruner is not None:
def on_prune(self, fgraph, node, reason):
pruner(node)
if chin is not None:
def on_change_input(self, fgraph, node, i, r, new_r, reason):
chin(node, i, r, new_r, reason)
u = Updater()
u = Updater(importer, pruner, chin)
fgraph.attach_feature(u)
return u
......
import pickle
import unittest
import theano
from theano.gof import CachedConstantError, FunctionGraph
from theano import tensor as tt
class TFunctionGraph(unittest.TestCase):
......@@ -15,3 +17,10 @@ class TFunctionGraph(unittest.TestCase):
v = theano.tensor.constant(1)
assert v.cached
FunctionGraph([], [v + 1])
def test_pickle(self):
    """A FunctionGraph must survive a pickle round-trip."""
    v = tt.vector()
    func = theano.gof.FunctionGraph([v], [v + 1])
    s = pickle.dumps(func)
    func2 = pickle.loads(s)
    # Fix: the original loaded the pickle but asserted nothing, so a
    # silently-corrupted round-trip would still pass. Check that the
    # restored graph has the same interface size as the original.
    assert len(func2.inputs) == len(func.inputs)
    assert len(func2.outputs) == len(func.outputs)
......@@ -104,7 +104,32 @@ class Bookkeeper(Feature):
self.on_prune(fgraph, node, 'Bookkeeper.detach')
class GetCheckpoint:
    """Picklable callable returning the current checkpoint of an fgraph.

    Stands in for the former ``lambda: len(history[fgraph])`` attached as
    ``fgraph.checkpoint`` — lambdas cannot be pickled, instances of this
    class can.
    """

    def __init__(self, history, fgraph):
        self.h = history
        self.fgraph = fgraph

    def __call__(self):
        # The checkpoint is simply how many revert-actions have been
        # recorded for this graph so far.
        recorded = self.h.history[self.fgraph]
        return len(recorded)
class LambdExtract:
    """Picklable revert-action: undo one ``change_input`` on an fgraph.

    Replaces the lambda the History feature used to store per change —
    lambdas cannot be pickled, instances of this class can.
    """

    def __init__(self, fgraph, node, i, r, reason=None):
        self.fgraph = fgraph
        self.node = node
        self.i = i
        self.r = r
        self.reason = reason

    def __call__(self):
        revert_reason = ("Revert", self.reason)
        return self.fgraph.change_input(self.node, self.i, self.r,
                                        reason=revert_reason)
class History(Feature):
pickle_rm_attr = ["checkpoint", "revert"]
def __init__(self):
self.history = {}
......@@ -114,7 +139,14 @@ class History(Feature):
raise AlreadyThere("History feature is already present or in"
" conflict with another plugin.")
self.history[fgraph] = []
fgraph.checkpoint = lambda: len(self.history[fgraph])
# Don't call unpickle here, as ReplaceValidate.on_attach()
# call to History.on_attach() will call the
# ReplaceValidate.unpickle and not History.unpickle
fgraph.checkpoint = GetCheckpoint(self, fgraph)
fgraph.revert = partial(self.revert, fgraph)
def unpickle(self, fgraph):
    # Recreate the picklable stand-ins stripped by pickle_rm_attr before
    # pickling: `checkpoint` (a GetCheckpoint instance) and `revert`
    # (a functools.partial bound to this History).
    fgraph.checkpoint = GetCheckpoint(self, fgraph)
    fgraph.revert = partial(self.revert, fgraph)
def on_detach(self, fgraph):
......@@ -126,8 +158,7 @@ class History(Feature):
if self.history[fgraph] is None:
return
h = self.history[fgraph]
h.append(lambda: fgraph.change_input(node, i, r,
reason=("Revert", reason)))
h.append(LambdExtract(fgraph, node, i, r, reason))
def revert(self, fgraph, checkpoint):
"""
......@@ -144,47 +175,66 @@ class History(Feature):
class Validator(Feature):
    """Expose ``fgraph.validate()`` and ``fgraph.consistent()``.

    The callables attached to the fgraph are ``functools.partial``
    objects bound to picklable methods (``validate_`` / ``consistent_``)
    rather than closures, so the graph can be pickled; the partials are
    stripped via ``pickle_rm_attr`` and recreated by ``unpickle``.

    Fix: the pasted diff left the removed closure-based ``validate`` /
    ``consistent`` definitions in place after the new partial-based
    assignments, which would have overwritten them with unpicklable
    closures again; only the post-merge implementation is kept.
    """
    # Attribute names FunctionGraph.__getstate__ must drop before pickling.
    pickle_rm_attr = ["validate", "consistent"]

    def on_attach(self, fgraph):
        for attr in ('validate', 'validate_time'):
            if hasattr(fgraph, attr):
                raise AlreadyThere("Validator feature is already present or in"
                                   " conflict with another plugin.")
        # Don't call self.unpickle here: from ReplaceValidate.on_attach()
        # it would dispatch to the overriding ReplaceValidate.unpickle,
        # not Validator's.
        fgraph.validate = partial(self.validate_, fgraph)
        fgraph.consistent = partial(self.consistent_, fgraph)

    def unpickle(self, fgraph):
        # Recreate the attributes stripped by pickle_rm_attr.
        fgraph.validate = partial(self.validate_, fgraph)
        fgraph.consistent = partial(self.consistent_, fgraph)

    def on_detach(self, fgraph):
        del fgraph.validate
        del fgraph.consistent

    def validate_(self, fgraph):
        """Run all 'validate' callbacks, adding the elapsed time to the
        graph's profile when one is attached."""
        t0 = time.time()
        ret = fgraph.execute_callbacks('validate')
        t1 = time.time()
        if fgraph.profile:
            fgraph.profile.validate_time += t1 - t0
        return ret

    def consistent_(self, fgraph):
        """Return True iff validation passes (any exception means False)."""
        try:
            fgraph.validate()
            return True
        except Exception:
            return False
class ReplaceValidate(History, Validator):
pickle_rm_attr = ["replace_validate", "replace_all_validate",
"replace_all_validate_remove"] + \
History.pickle_rm_attr + Validator.pickle_rm_attr
def on_attach(self, fgraph):
    """Attach to *fgraph*, refusing a graph that already carries a
    ReplaceValidate-like feature.

    Fix: the pasted diff duplicated the ``History.on_attach`` /
    ``Validator.on_attach`` calls (once before and once after the check)
    and kept both the old and the new ``for attr`` tuples; only the
    post-merge version is kept.
    """
    # Check before mutating anything, so AlreadyThere is raised without
    # having attached the parent features.
    for attr in ('replace_validate', 'replace_all_validate',
                 'replace_all_validate_remove'):
        if hasattr(fgraph, attr):
            raise AlreadyThere("ReplaceValidate feature is already present"
                               " or in conflict with another plugin.")
    History.on_attach(self, fgraph)
    Validator.on_attach(self, fgraph)
    self.unpickle(fgraph)
def unpickle(self, fgraph):
    """Reinstall the partial-based fgraph attributes after unpickling.

    Fix: the pasted diff contained the ``replace_all_validate``
    assignment twice (the removed one-line form and the new wrapped
    form); only one assignment is kept.
    """
    History.unpickle(self, fgraph)
    Validator.unpickle(self, fgraph)
    fgraph.replace_validate = partial(self.replace_validate, fgraph)
    fgraph.replace_all_validate = partial(self.replace_all_validate,
                                          fgraph)
    fgraph.replace_all_validate_remove = partial(
        self.replace_all_validate_remove, fgraph)
......@@ -247,6 +297,12 @@ class ReplaceValidate(History, Validator):
print >> out, reason, replacements
raise ReplacementDidntRemovedError()
def __getstate__(self):
    """Drop the per-fgraph history mapping before pickling; its entries
    (revert callables keyed by graph) are not meant to survive a
    round-trip."""
    state = dict(self.__dict__)
    state.pop("history", None)
    return state
class NodeFinder(Bookkeeper):
......
......@@ -694,7 +694,7 @@ class VM_Linker(link.LocalLinker):
if k.owner and k.clients:
ls = []
for cl in k.clients:
if cl[0] is not 'output':
if cl[0] != 'output':
ls += cl[0].outputs
dependencies[k] += ls
return dependencies
......
......@@ -44,6 +44,8 @@ if MutableSet is not None:
import weakref
class Link(object):
# This make that we need to use a different pickle protocol
# then the default. Othewise, there is pickling errors
__slots__ = 'prev', 'next', 'key', '__weakref__'
def __getstate__(self):
......
......@@ -1494,11 +1494,11 @@ class GemmOptimizer(Optimizer):
callbacks_before = fgraph.execute_callbacks_times.copy()
callback_before = fgraph.execute_callbacks_time
class Updater:
def on_import(self, fgraph, new_node, reason):
if new_node is not node:
nodelist.append(new_node)
u = Updater()
def on_import(new_node):
if new_node is not node:
nodelist.append(new_node)
u = theano.gof.opt.Updater(on_import, None, None)
fgraph.attach_feature(u)
while did_something:
nb_iter += 1
......
......@@ -2664,6 +2664,7 @@ def local_useless_tile(node):
except NotScalarConstantError:
return
################
# Flatten Opts #
################
......
......@@ -53,10 +53,10 @@ def test_gc_never_pickles_temporaries():
len_pre_f = len(cPickle.dumps(f))
len_pre_g = len(cPickle.dumps(g))
# should be no difference at first
# In future, FunctionMaker might pickle linker-dependent stuff and make
# this assertion fail.
assert len_pre_f == len_pre_g
# We can't compare the content or the length of the pickled strings
# for f and g, for two reasons: we store some timing information as
# floats, which won't be the same on each run, and different floats
# can have different lengths when printed.
def a(fn):
    # Helper: size in bytes of the pickled FunctionMaker of `fn`;
    # used to compare f and g without comparing pickle contents.
    return len(cPickle.dumps(fn.maker))
......
"""
This script tests the pickle and unpickle of theano functions.
When a compiled theano has shared vars, their values are also being pickled.
Side notes useful for debugging:
The pickling tools theano uses is here:
theano.compile.function_module._pickle_Function()
theano.compile.function_module._pickle_FunctionMaker()
Whether reoptimize the pickled function graph is handled by
FunctionMaker.__init__()
The config option is in configdefaults.py
This note is written by Li Yao.
"""
import unittest
import numpy
import cPickle
from theano.compat.python2x import DictMixin, OrderedDict
floatX = 'float32'
import theano
import theano.tensor as T
def _pickle_unpickle_check(reoptimize):
    """Shared body for the two round-trip tests below.

    Compile a small function with shared-variable updates, pickle it,
    unpickle it under the given `reoptimize_unpickled_function` setting,
    and check that the original and the round-tripped function compute
    the same value.

    Fix: the two public tests were near-identical copies differing only
    in the flag value; the duplication is factored out here.

    :param reoptimize: value assigned to
        theano.config.reoptimize_unpickled_function while unpickling.
    """
    mode = theano.config.mode
    if mode in ["DEBUG_MODE", "DebugMode"]:
        # DebugMode is far too slow for this test; fall back to FAST_RUN.
        mode = "FAST_RUN"
    x1 = T.fmatrix('x1')
    x2 = T.fmatrix('x2')
    x3 = theano.shared(numpy.ones((10, 10), dtype=floatX))
    x4 = theano.shared(numpy.ones((10, 10), dtype=floatX))
    y = T.sum(T.sum(T.sum(x1 ** 2 + x2) + x3) + x4)
    updates = OrderedDict()
    updates[x3] = x3 + 1
    updates[x4] = x4 + 1
    f = theano.function([x1, x2], y, updates=updates, mode=mode)
    # Pickle the compiled theano function (protocol -1 == highest).
    string_pkl = cPickle.dumps(f, -1)
    in1 = numpy.ones((10, 10), dtype=floatX)
    in2 = numpy.ones((10, 10), dtype=floatX)
    # Unpickle under the requested flag value; always restore the
    # previous config afterwards.
    default = theano.config.reoptimize_unpickled_function
    try:
        theano.config.reoptimize_unpickled_function = reoptimize
        f_ = cPickle.loads(string_pkl)
        # Both functions start from identical shared values, so their
        # first calls must agree.
        assert f(in1, in2) == f_(in1, in2)
    finally:
        theano.config.reoptimize_unpickled_function = default


def test_pickle_unpickle_with_reoptimization():
    """Unpickling with reoptimization (the default) preserves results."""
    _pickle_unpickle_check(True)


def test_pickle_unpickle_without_reoptimization():
    """Reusing the stored optimized graph directly preserves results."""
    _pickle_unpickle_check(False)
if __name__ == '__main__':
    # Allow running the round-trip tests directly as a script,
    # outside of a test runner.
    test_pickle_unpickle_with_reoptimization()
    test_pickle_unpickle_without_reoptimization()
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论