提交 95b8f998 authored 作者: Frederic Bastien's avatar Frederic Bastien

Fix the original problem of gh-5736

上级 54c49ef0
......@@ -289,13 +289,6 @@ def function(inputs, outputs=None, mode=None, updates=None, givens=None,
uses_updates = bool(updates)
uses_givens = bool(givens)
# See if we have any mutable / borrow inputs
check_for_aliased_inputs = False
for i in inputs:
if (isinstance(i, In) and ((hasattr(i, 'borrow') and i.borrow) or
(hasattr(i, 'mutable') and i.mutable))):
check_for_aliased_inputs = True
if uses_tuple:
# we must use old semantics in this case.
if profile:
......@@ -323,7 +316,4 @@ def function(inputs, outputs=None, mode=None, updates=None, givens=None,
on_unused_input=on_unused_input,
profile=profile,
output_keys=output_keys)
# We need to add the flag check_aliased inputs if we have any mutable or
# borrowed used defined inputs
fn._check_for_aliased_inputs = check_for_aliased_inputs
return fn
......@@ -375,6 +375,15 @@ class Function(object):
self.nodes_with_inner_function = []
self.output_keys = output_keys
# See if we have any mutable / borrow inputs
# TODO: this only need to be set if there is more then 1 input
self._check_for_aliased_inputs = False
for i in maker.inputs:
if (isinstance(i, In) and ((hasattr(i, 'borrow') and i.borrow) or
(hasattr(i, 'mutable') and i.mutable))):
self._check_for_aliased_inputs = True
break
# We will be popping stuff off this `containers` object. It is a copy.
containers = list(self.input_storage)
finder = {}
......@@ -821,6 +830,7 @@ class Function(object):
self[k] = arg
if (not self.trust_input and
# The getattr is only needed for old pickle
getattr(self, '_check_for_aliased_inputs', True)):
# Collect aliased inputs among the storage space
args_share_memory = []
......
......@@ -609,6 +609,7 @@ class T_picklefunction(unittest.TestCase):
self.assertFalse(x in g.container)
self.assertFalse(x in g.value)
self.assertTrue(len(f.defaults) == len(g.defaults))
self.assertTrue(f._check_for_aliased_inputs is g._check_for_aliased_inputs)
# print 'f.defaults = %s' % (f.defaults, )
# print 'g.defaults = %s' % (g.defaults, )
self.assertTrue(all([f_req == g_req and f_feed == g_feed and
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论