提交 6c76f283 authored 作者: James Bergstra's avatar James Bergstra

patched the Function pickler to test for aliased storage, and by default issue a warning

上级 07d66494
...@@ -115,6 +115,19 @@ class Function(object): ...@@ -115,6 +115,19 @@ class Function(object):
""" """
pickle_aliased_memory_strategy = 'warn'
"""How to deal with pickling finding aliased storage.
Meaningful settings are: 'ignore', 'warn', 'raise'
If the value is 'warn', then a message will be printed to stderr if aliased storage is
dectected during pickle.dump.
If the value is 'raise', then an AliasedMemoryError will be raised if aliased storage is
detected during pickle.dump.
"""
def __init__(self, fn, input_storage, output_storage, indices, outputs, defaults, unpack_single, maker): def __init__(self, fn, input_storage, output_storage, indices, outputs, defaults, unpack_single, maker):
""" """
fn -> a function returned by some linker's make_thunk method fn -> a function returned by some linker's make_thunk method
...@@ -334,9 +347,29 @@ def _pickle_Function(f): ...@@ -334,9 +347,29 @@ def _pickle_Function(f):
else: else:
defaults.append(ins[0]) defaults.append(ins[0])
del ins[0] del ins[0]
rval = (_constructor_Function, (f.maker, defaults, [x.data for x in f.input_storage]))
inputs_data = [x.data for x in f.input_storage]
#HACK to detect aliased storage.
# aliased relationships will not be preserved across the pickle operation
if not (f.pickle_aliased_memory_strategy == 'ignore'):
all_data = defaults + inputs_data
for i, d_i in enumerate(all_data):
for j, d_j in enumerate(all_data):
if (i < j) and isinstance(d_i, numpy.ndarray) and isinstance(d_j, numpy.ndarray):
if f.pickle_aliased_memory_strategy == 'warn':
print >> sys.stderr, ('WARNING: '
'aliased relationship between Function arguments '
'will not be preserved by un-pickling operation')
#print >> sys.stderr, d_i, d_j, id(d_i), id(d_j)
else:
raise AliasedMemoryError(d_i, d_j)
rval = (_constructor_Function, (f.maker, defaults, inputs_data))
return rval return rval
class AliasedMemoryError(Exception):pass
def _constructor_Function(maker, defaults, data): def _constructor_Function(maker, defaults, data):
f = maker.create(defaults, trustme = True) f = maker.create(defaults, trustme = True)
assert len(f.input_storage) == len(data) assert len(f.input_storage) == len(data)
......
...@@ -5,6 +5,7 @@ __docformat__ = "restructuredtext en" ...@@ -5,6 +5,7 @@ __docformat__ = "restructuredtext en"
import cPickle, numpy, unittest import cPickle, numpy, unittest
from theano.compile.module import * from theano.compile.module import *
from theano.compile.function_module import AliasedMemoryError
import theano.tensor as T import theano.tensor as T
import sys import sys
import theano import theano
...@@ -587,7 +588,6 @@ def test_pickle(): ...@@ -587,7 +588,6 @@ def test_pickle():
assert m_dup.y is m_dup.g.input_storage[2].data assert m_dup.y is m_dup.g.input_storage[2].data
def test_pickle_aliased_memory(): def test_pickle_aliased_memory():
try:
M = Module() M = Module()
M.x = (T.dmatrix()) M.x = (T.dmatrix())
M.y = (T.dmatrix()) M.y = (T.dmatrix())
...@@ -597,12 +597,32 @@ def test_pickle_aliased_memory(): ...@@ -597,12 +597,32 @@ def test_pickle_aliased_memory():
m = M.make(x=numpy.zeros((4,5)), y=numpy.ones((2,3))) m = M.make(x=numpy.zeros((4,5)), y=numpy.ones((2,3)))
m.y = m.x[:] m.y = m.x[:]
m_dup = cPickle.loads(cPickle.dumps(m))
#m's memory is aliased.... #m's x and y memory is aliased....
m.x[0,0] = 3.14 m.x[0,0] = 3.14
assert m.y[0,0] == 3.14 assert m.y[0,0] == 3.14
import StringIO
sio = StringIO.StringIO()
old_stderr = sys.stderr
sys.stderr = sio
m.f.pickle_aliased_memory_strategy = 'warn'
m.g.pickle_aliased_memory_strategy = 'warn'
m_dup = cPickle.loads(cPickle.dumps(m))
sys.stderr = old_stderr
assert sio.getvalue().startswith('WARNING: aliased relat')
try:
m.f.pickle_aliased_memory_strategy = 'raise'
m.g.pickle_aliased_memory_strategy = 'raise'
m_dup = cPickle.loads(cPickle.dumps(m))
except AliasedMemoryError, e:
return
assert 0 #should have failed to pickle
#is m_dup's memory aliased? #is m_dup's memory aliased?
m_dup.x[0,0] = 3.14 m_dup.x[0,0] = 3.14
assert m_dup.y[0,0] == 3.14 assert m_dup.y[0,0] == 3.14
...@@ -611,14 +631,12 @@ def test_pickle_aliased_memory(): ...@@ -611,14 +631,12 @@ def test_pickle_aliased_memory():
m.y = m.x[1:2] m.y = m.x[1:2]
m_dup = cPickle.loads(cPickle.dumps(m)) m_dup = cPickle.loads(cPickle.dumps(m))
if 0:
#is m_dup's memory aliased the same way? #is m_dup's memory aliased the same way?
m.x[1,0] = 3.142 m.x[1,0] = 3.142
assert m.y[0,0] == 3.142 assert m.y[0,0] == 3.142
m_dup.x[1,0] = 3.142 m_dup.x[1,0] = 3.142
assert m_dup.y[0,0] == 3.142 assert m_dup.y[0,0] == 3.142
except Exception, e:
raise Exception('Known Failure: These branch cuts are known to fail', str(e))
if __name__ == '__main__': if __name__ == '__main__':
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论