提交 6c76f283 authored 作者: James Bergstra's avatar James Bergstra

patched the Function pickler to test for aliased storage, and by default issue a warning

上级 07d66494
......@@ -115,6 +115,19 @@ class Function(object):
"""
pickle_aliased_memory_strategy = 'warn'
"""How to deal with pickling finding aliased storage.
Meaningful settings are: 'ignore', 'warn', 'raise'
If the value is 'warn', then a message will be printed to stderr if aliased storage is
dectected during pickle.dump.
If the value is 'raise', then an AliasedMemoryError will be raised if aliased storage is
detected during pickle.dump.
"""
def __init__(self, fn, input_storage, output_storage, indices, outputs, defaults, unpack_single, maker):
"""
fn -> a function returned by some linker's make_thunk method
......@@ -334,9 +347,29 @@ def _pickle_Function(f):
else:
defaults.append(ins[0])
del ins[0]
rval = (_constructor_Function, (f.maker, defaults, [x.data for x in f.input_storage]))
inputs_data = [x.data for x in f.input_storage]
#HACK to detect aliased storage.
# aliased relationships will not be preserved across the pickle operation
if not (f.pickle_aliased_memory_strategy == 'ignore'):
all_data = defaults + inputs_data
for i, d_i in enumerate(all_data):
for j, d_j in enumerate(all_data):
if (i < j) and isinstance(d_i, numpy.ndarray) and isinstance(d_j, numpy.ndarray):
if f.pickle_aliased_memory_strategy == 'warn':
print >> sys.stderr, ('WARNING: '
'aliased relationship between Function arguments '
'will not be preserved by un-pickling operation')
#print >> sys.stderr, d_i, d_j, id(d_i), id(d_j)
else:
raise AliasedMemoryError(d_i, d_j)
rval = (_constructor_Function, (f.maker, defaults, inputs_data))
return rval
class AliasedMemoryError(Exception):pass
def _constructor_Function(maker, defaults, data):
f = maker.create(defaults, trustme = True)
assert len(f.input_storage) == len(data)
......
......@@ -5,6 +5,7 @@ __docformat__ = "restructuredtext en"
import cPickle, numpy, unittest
from theano.compile.module import *
from theano.compile.function_module import AliasedMemoryError
import theano.tensor as T
import sys
import theano
......@@ -587,38 +588,55 @@ def test_pickle():
assert m_dup.y is m_dup.g.input_storage[2].data
def test_pickle_aliased_memory():
M = Module()
M.x = (T.dmatrix())
M.y = (T.dmatrix())
a = T.dmatrix()
M.f = Method([a], a + M.x + M.y)
M.g = Method([a], a * M.x * M.y)
m = M.make(x=numpy.zeros((4,5)), y=numpy.ones((2,3)))
m.y = m.x[:]
#m's x and y memory is aliased....
m.x[0,0] = 3.14
assert m.y[0,0] == 3.14
import StringIO
sio = StringIO.StringIO()
old_stderr = sys.stderr
sys.stderr = sio
m.f.pickle_aliased_memory_strategy = 'warn'
m.g.pickle_aliased_memory_strategy = 'warn'
m_dup = cPickle.loads(cPickle.dumps(m))
sys.stderr = old_stderr
assert sio.getvalue().startswith('WARNING: aliased relat')
try:
M = Module()
M.x = (T.dmatrix())
M.y = (T.dmatrix())
a = T.dmatrix()
M.f = Method([a], a + M.x + M.y)
M.g = Method([a], a * M.x * M.y)
m = M.make(x=numpy.zeros((4,5)), y=numpy.ones((2,3)))
m.y = m.x[:]
m.f.pickle_aliased_memory_strategy = 'raise'
m.g.pickle_aliased_memory_strategy = 'raise'
m_dup = cPickle.loads(cPickle.dumps(m))
except AliasedMemoryError, e:
return
#m's memory is aliased....
m.x[0,0] = 3.14
assert m.y[0,0] == 3.14
assert 0 #should have failed to pickle
#is m_dup's memory aliased?
m_dup.x[0,0] = 3.14
assert m_dup.y[0,0] == 3.14
#is m_dup's memory aliased?
m_dup.x[0,0] = 3.14
assert m_dup.y[0,0] == 3.14
#m's memory is aliased differently....
m.y = m.x[1:2]
m_dup = cPickle.loads(cPickle.dumps(m))
#m's memory is aliased differently....
m.y = m.x[1:2]
m_dup = cPickle.loads(cPickle.dumps(m))
if 0:
#is m_dup's memory aliased the same way?
m.x[1,0] = 3.142
assert m.y[0,0] == 3.142
m_dup.x[1,0] = 3.142
assert m_dup.y[0,0] == 3.142
except Exception, e:
raise Exception('Known Failure: These branch cuts are known to fail', str(e))
if __name__ == '__main__':
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论