提交 98bae725 authored 作者: Frederic Bastien's avatar Frederic Bastien

Pickle/Unpickle trust_input. We use False (the default) for old pickle that…

Pickle/Unpickle trust_input. We use False (the default) for old pickle that don't have the information.
上级 95b8f998
...@@ -1057,19 +1057,21 @@ def _pickle_Function(f): ...@@ -1057,19 +1057,21 @@ def _pickle_Function(f):
(str(d_i), str(d_j))) (str(d_i), str(d_j)))
else: else:
raise AliasedMemoryError(d_i, d_j) raise AliasedMemoryError(d_i, d_j)
rval = (_constructor_Function, (f.maker, input_storage, inputs_data)) rval = (_constructor_Function, (f.maker, input_storage, inputs_data, f.trust_input))
return rval return rval
def _constructor_Function(maker, input_storage, inputs_data): def _constructor_Function(maker, input_storage, inputs_data, trust_input=False):
if not theano.config.unpickle_function: if not theano.config.unpickle_function:
return None return None
f = maker.create(input_storage, trustme=True) f = maker.create(input_storage, trustme=True)
assert len(f.input_storage) == len(inputs_data) assert len(f.input_storage) == len(inputs_data)
for container, x in zip(f.input_storage, inputs_data): for container, x in zip(f.input_storage, inputs_data):
assert (container.data is x) or \ assert (container.data is x) or \
(isinstance(x, np.ndarray) and (container.data == x).all()) or \ (isinstance(x, np.ndarray) and (container.data == x).all()) or \
(container.data == x) (container.data == x)
f.trust_input = trust_input
return f return f
copyreg.pickle(Function, _pickle_Function) copyreg.pickle(Function, _pickle_Function)
......
...@@ -590,8 +590,8 @@ class T_picklefunction(unittest.TestCase): ...@@ -590,8 +590,8 @@ class T_picklefunction(unittest.TestCase):
x, s = T.scalars('xs') x, s = T.scalars('xs')
f = function([x, In(a, value=1.0, name='a'), f = function([x, In(a, value=1.0, name='a'),
In(s, value=0.0, update=s + a * x, mutable=True)], s + a * x) In(s, value=0.0, update=s + a * x, mutable=True)],
s + a * x)
try: try:
g = copy.deepcopy(f) g = copy.deepcopy(f)
except NotImplementedError as e: except NotImplementedError as e:
...@@ -628,6 +628,27 @@ class T_picklefunction(unittest.TestCase): ...@@ -628,6 +628,27 @@ class T_picklefunction(unittest.TestCase):
g(1, 2) # put them back in sync g(1, 2) # put them back in sync
self.assertTrue(f(3) == g(3)) # They should be in sync again. self.assertTrue(f(3) == g(3)) # They should be in sync again.
def test_deepcopy_trust_input(self):
a = T.dscalar() # the a is for 'anonymous' (un-named).
x, s = T.dscalars('xs')
f = function([x, In(a, value=1.0, name='a'),
In(s, value=0.0, update=s + a * x, mutable=True)],
s + a * x)
f.trust_input = True
try:
g = copy.deepcopy(f)
except NotImplementedError as e:
if e[0].startswith('DebugMode is not picklable'):
return
else:
raise
self.assertTrue(f.trust_input is g.trust_input)
f(np.asarray(2.))
self.assertRaises((ValueError, AttributeError), f, 2.)
g(np.asarray(2.))
self.assertRaises((ValueError, AttributeError), g, 2.)
def test_deepcopy_shared_container(self): def test_deepcopy_shared_container(self):
# Ensure that shared containers remain shared after a deep copy. # Ensure that shared containers remain shared after a deep copy.
a, x = T.scalars('ax') a, x = T.scalars('ax')
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论