提交 5e70f3ac authored 作者: Li's avatar Li 提交者: Frederic

finished ticket on pickle and unpickle theano fn with/without reoptimization

上级 d4e7e4f1
......@@ -703,7 +703,6 @@ class Function(object):
# pickling/deepcopy support for Function
def _pickle_Function(f):
print 'pickling Function..'
#copy of the input storage list
ins = list(f.input_storage)
input_storage = []
......@@ -741,7 +740,6 @@ def _pickle_Function(f):
return rval
def _constructor_Function(maker, input_storage, inputs_data):
print 'unpickling Function...'
if not theano.config.unpickle_function:
return None
f = maker.create(input_storage, trustme = True)
......@@ -1245,7 +1243,6 @@ class FunctionMaker(object):
if need_opt:
# optimize the fgraph
print 'fgraph is optimized'
try:
theano.config.compute_test_value = theano.config.compute_test_value_opt
gof.Op.add_stack_trace_on_call = False
......@@ -1267,7 +1264,6 @@ class FunctionMaker(object):
else:
# fgraph is already optimized
print 'fgraph is not optimized'
theano.config.compute_test_value = compute_test_value_orig
gof.Op.add_stack_trace_on_call = add_stack_trace_on_call
......@@ -1430,7 +1426,6 @@ class FunctionMaker(object):
def _pickle_FunctionMaker(self):
'picking FunctionMaker'
kwargs = dict(
inputs=self.inputs,
outputs=self.orig_outputs,
......@@ -1445,7 +1440,6 @@ def _pickle_FunctionMaker(self):
def _constructor_FunctionMaker(kwargs):
print 'unpickling FunctionMaker...'
if theano.config.unpickle_function:
if theano.config.reoptimize_unpickled_function:
del kwargs['fgraph']
......
......@@ -468,7 +468,7 @@ AddConfigVar('unpickle_function',
AddConfigVar('reoptimize_unpickled_function',
"Re-optimize the graph when a theano function is unpickled from the disk.",
BoolParam(False, allow_override=False),
BoolParam(True, allow_override=True),
in_c_key=False)
......
"""
This script tests the pickle and unpickle of theano functions.
When a compiled theano has shared vars, their values are also being pickled.
Side notes useful for debugging:
The pickling tools theano uses is here:
theano.compile.function_module._pickle_Function()
theano.compile.function_module._pickle_FunctionMaker()
Whether reoptimize the pickled function graph is handled by
FunctionMaker.__init__()
The config option is in configdefaults.py
This note is written by Li Yao.
"""
import unittest
import theano
import theano.tensor as T
import numpy
import cPickle
from collections import OrderedDict
floatX = 'float32'
def test_pickle_unpickle():
# Test if pick and unpickling a theano function with
# shared variables should be pickled properly
import theano
import theano.tensor as T
def test_pickle_unpickle_with_reoptimization():
x1 = T.fmatrix('x1')
x2 = T.fmatrix('x2')
x3 = theano.shared(numpy.ones((10,10),dtype=floatX))
......@@ -19,18 +31,50 @@ def test_pickle_unpickle():
updates[x3] = x3 + 1
updates[x4] = x4 + 1
f = theano.function([x1,x2],y, updates=updates)
# now pickle the compiled theano fn
pkl_path = open('thean_fn.pkl','wb')
cPickle.dump(f, pkl_path, -1)
in1 = numpy.ones((10, 10), dtype=floatX)
in2 = numpy.ones((10, 10), dtype=floatX)
print 'the desired value is ',f(in1, in2)
# test unpickle with optimization
theano.config.reoptimize_unpickled_function=True # the default is True
pkl_path = open('thean_fn.pkl','r')
f_ = cPickle.load(pkl_path)
import ipdb; ipdb.set_trace()
print 'got value ', f_(in1, in2)
assert f(in1, in2) == f_(in1, in2)
def test_pickle_unpickle_without_reoptimization():
x1 = T.fmatrix('x1')
x2 = T.fmatrix('x2')
x3 = theano.shared(numpy.ones((10,10),dtype=floatX))
x4 = theano.shared(numpy.ones((10,10),dtype=floatX))
y = T.sum(T.sum(T.sum(x1**2+x2) + x3) + x4)
updates = OrderedDict()
updates[x3] = x3 + 1
updates[x4] = x4 + 1
f = theano.function([x1,x2],y, updates=updates)
# now pickle the compiled theano fn
pkl_path = open('thean_fn.pkl','wb')
cPickle.dump(f, pkl_path, -1)
# compute f value
in1 = numpy.ones((10, 10), dtype=floatX)
in2 = numpy.ones((10, 10), dtype=floatX)
assert f(in1, in2) == f_(in1, in2)
print 'the desired value is ',f(in1, in2)
print f(in1, in2)
# test unpickle without optimization
theano.config.reoptimize_unpickled_function=False # the default is True
pkl_path = open('thean_fn.pkl','r')
f_ = cPickle.load(pkl_path)
print 'got value ', f_(in1, in2)
assert f(in1, in2) == f_(in1, in2)
if __name__ == '__main__':
test_pickle_unpickle()
test_pickle_unpickle_with_reoptimization()
test_pickle_unpickle_without_reoptimization()
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论