提交 5e70f3ac authored 作者: Li's avatar Li 提交者: Frederic

finished ticket on pickle and unpickle theano fn with/without reoptimization

上级 d4e7e4f1
...@@ -703,7 +703,6 @@ class Function(object): ...@@ -703,7 +703,6 @@ class Function(object):
# pickling/deepcopy support for Function # pickling/deepcopy support for Function
def _pickle_Function(f): def _pickle_Function(f):
print 'pickling Function..'
#copy of the input storage list #copy of the input storage list
ins = list(f.input_storage) ins = list(f.input_storage)
input_storage = [] input_storage = []
...@@ -741,7 +740,6 @@ def _pickle_Function(f): ...@@ -741,7 +740,6 @@ def _pickle_Function(f):
return rval return rval
def _constructor_Function(maker, input_storage, inputs_data): def _constructor_Function(maker, input_storage, inputs_data):
print 'unpickling Function...'
if not theano.config.unpickle_function: if not theano.config.unpickle_function:
return None return None
f = maker.create(input_storage, trustme = True) f = maker.create(input_storage, trustme = True)
...@@ -1245,7 +1243,6 @@ class FunctionMaker(object): ...@@ -1245,7 +1243,6 @@ class FunctionMaker(object):
if need_opt: if need_opt:
# optimize the fgraph # optimize the fgraph
print 'fgraph is optimized'
try: try:
theano.config.compute_test_value = theano.config.compute_test_value_opt theano.config.compute_test_value = theano.config.compute_test_value_opt
gof.Op.add_stack_trace_on_call = False gof.Op.add_stack_trace_on_call = False
...@@ -1267,7 +1264,6 @@ class FunctionMaker(object): ...@@ -1267,7 +1264,6 @@ class FunctionMaker(object):
else: else:
# fgraph is already optimized # fgraph is already optimized
print 'fgraph is not optimized'
theano.config.compute_test_value = compute_test_value_orig theano.config.compute_test_value = compute_test_value_orig
gof.Op.add_stack_trace_on_call = add_stack_trace_on_call gof.Op.add_stack_trace_on_call = add_stack_trace_on_call
...@@ -1430,7 +1426,6 @@ class FunctionMaker(object): ...@@ -1430,7 +1426,6 @@ class FunctionMaker(object):
def _pickle_FunctionMaker(self): def _pickle_FunctionMaker(self):
'picking FunctionMaker'
kwargs = dict( kwargs = dict(
inputs=self.inputs, inputs=self.inputs,
outputs=self.orig_outputs, outputs=self.orig_outputs,
...@@ -1445,7 +1440,6 @@ def _pickle_FunctionMaker(self): ...@@ -1445,7 +1440,6 @@ def _pickle_FunctionMaker(self):
def _constructor_FunctionMaker(kwargs): def _constructor_FunctionMaker(kwargs):
print 'unpickling FunctionMaker...'
if theano.config.unpickle_function: if theano.config.unpickle_function:
if theano.config.reoptimize_unpickled_function: if theano.config.reoptimize_unpickled_function:
del kwargs['fgraph'] del kwargs['fgraph']
......
...@@ -468,7 +468,7 @@ AddConfigVar('unpickle_function', ...@@ -468,7 +468,7 @@ AddConfigVar('unpickle_function',
AddConfigVar('reoptimize_unpickled_function', AddConfigVar('reoptimize_unpickled_function',
"Re-optimize the graph when a theano function is unpickled from the disk.", "Re-optimize the graph when a theano function is unpickled from the disk.",
BoolParam(False, allow_override=False), BoolParam(True, allow_override=True),
in_c_key=False) in_c_key=False)
......
"""
This script tests the pickle and unpickle of theano functions.
When a compiled theano function has shared vars, their values are also pickled along with it.
Side notes useful for debugging:
The pickling tools theano uses are here:
theano.compile.function_module._pickle_Function()
theano.compile.function_module._pickle_FunctionMaker()
Whether to reoptimize the pickled function graph is handled by
FunctionMaker.__init__()
The config option is in configdefaults.py
This note is written by Li Yao.
"""
import unittest import unittest
import theano
import theano.tensor as T
import numpy import numpy
import cPickle import cPickle
from collections import OrderedDict from collections import OrderedDict
floatX = 'float32' floatX = 'float32'
import theano
def test_pickle_unpickle(): import theano.tensor as T
# Test if pick and unpickling a theano function with
# shared variables should be pickled properly def test_pickle_unpickle_with_reoptimization():
x1 = T.fmatrix('x1') x1 = T.fmatrix('x1')
x2 = T.fmatrix('x2') x2 = T.fmatrix('x2')
x3 = theano.shared(numpy.ones((10,10),dtype=floatX)) x3 = theano.shared(numpy.ones((10,10),dtype=floatX))
...@@ -19,18 +31,50 @@ def test_pickle_unpickle(): ...@@ -19,18 +31,50 @@ def test_pickle_unpickle():
updates[x3] = x3 + 1 updates[x3] = x3 + 1
updates[x4] = x4 + 1 updates[x4] = x4 + 1
f = theano.function([x1,x2],y, updates=updates) f = theano.function([x1,x2],y, updates=updates)
# now pickle the compiled theano fn
pkl_path = open('thean_fn.pkl','wb') pkl_path = open('thean_fn.pkl','wb')
cPickle.dump(f, pkl_path, -1) cPickle.dump(f, pkl_path, -1)
in1 = numpy.ones((10, 10), dtype=floatX)
in2 = numpy.ones((10, 10), dtype=floatX)
print 'the desired value is ',f(in1, in2)
# test unpickle with optimization
theano.config.reoptimize_unpickled_function=True # the default is True
pkl_path = open('thean_fn.pkl','r') pkl_path = open('thean_fn.pkl','r')
f_ = cPickle.load(pkl_path) f_ = cPickle.load(pkl_path)
import ipdb; ipdb.set_trace() print 'got value ', f_(in1, in2)
assert f(in1, in2) == f_(in1, in2)
def test_pickle_unpickle_without_reoptimization():
x1 = T.fmatrix('x1')
x2 = T.fmatrix('x2')
x3 = theano.shared(numpy.ones((10,10),dtype=floatX))
x4 = theano.shared(numpy.ones((10,10),dtype=floatX))
y = T.sum(T.sum(T.sum(x1**2+x2) + x3) + x4)
updates = OrderedDict()
updates[x3] = x3 + 1
updates[x4] = x4 + 1
f = theano.function([x1,x2],y, updates=updates)
# now pickle the compiled theano fn
pkl_path = open('thean_fn.pkl','wb')
cPickle.dump(f, pkl_path, -1)
# compute f value
in1 = numpy.ones((10, 10), dtype=floatX) in1 = numpy.ones((10, 10), dtype=floatX)
in2 = numpy.ones((10, 10), dtype=floatX) in2 = numpy.ones((10, 10), dtype=floatX)
print 'the desired value is ',f(in1, in2)
assert f(in1, in2) == f_(in1, in2)
print f(in1, in2) # test unpickle without optimization
theano.config.reoptimize_unpickled_function=False # the default is True
pkl_path = open('thean_fn.pkl','r')
f_ = cPickle.load(pkl_path)
print 'got value ', f_(in1, in2)
assert f(in1, in2) == f_(in1, in2)
# Script entry point: run both pickle/unpickle round-trip tests.
# (Reconstructed from the garbled side-by-side diff text.)
if __name__ == '__main__':
    test_pickle_unpickle_with_reoptimization()
    test_pickle_unpickle_without_reoptimization()
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论