提交 95b8f998 authored 作者: Frederic Bastien's avatar Frederic Bastien

Fix the original problem of gh-5736

上级 54c49ef0
...@@ -289,13 +289,6 @@ def function(inputs, outputs=None, mode=None, updates=None, givens=None, ...@@ -289,13 +289,6 @@ def function(inputs, outputs=None, mode=None, updates=None, givens=None,
uses_updates = bool(updates) uses_updates = bool(updates)
uses_givens = bool(givens) uses_givens = bool(givens)
# See if we have any mutable / borrow inputs
check_for_aliased_inputs = False
for i in inputs:
if (isinstance(i, In) and ((hasattr(i, 'borrow') and i.borrow) or
(hasattr(i, 'mutable') and i.mutable))):
check_for_aliased_inputs = True
if uses_tuple: if uses_tuple:
# we must use old semantics in this case. # we must use old semantics in this case.
if profile: if profile:
...@@ -323,7 +316,4 @@ def function(inputs, outputs=None, mode=None, updates=None, givens=None, ...@@ -323,7 +316,4 @@ def function(inputs, outputs=None, mode=None, updates=None, givens=None,
on_unused_input=on_unused_input, on_unused_input=on_unused_input,
profile=profile, profile=profile,
output_keys=output_keys) output_keys=output_keys)
# We need to add the flag check_aliased inputs if we have any mutable or
# borrowed used defined inputs
fn._check_for_aliased_inputs = check_for_aliased_inputs
return fn return fn
...@@ -375,6 +375,15 @@ class Function(object): ...@@ -375,6 +375,15 @@ class Function(object):
self.nodes_with_inner_function = [] self.nodes_with_inner_function = []
self.output_keys = output_keys self.output_keys = output_keys
# See if we have any mutable / borrow inputs
# TODO: this only need to be set if there is more then 1 input
self._check_for_aliased_inputs = False
for i in maker.inputs:
if (isinstance(i, In) and ((hasattr(i, 'borrow') and i.borrow) or
(hasattr(i, 'mutable') and i.mutable))):
self._check_for_aliased_inputs = True
break
# We will be popping stuff off this `containers` object. It is a copy. # We will be popping stuff off this `containers` object. It is a copy.
containers = list(self.input_storage) containers = list(self.input_storage)
finder = {} finder = {}
...@@ -821,6 +830,7 @@ class Function(object): ...@@ -821,6 +830,7 @@ class Function(object):
self[k] = arg self[k] = arg
if (not self.trust_input and if (not self.trust_input and
# The getattr is only needed for old pickle
getattr(self, '_check_for_aliased_inputs', True)): getattr(self, '_check_for_aliased_inputs', True)):
# Collect aliased inputs among the storage space # Collect aliased inputs among the storage space
args_share_memory = [] args_share_memory = []
......
...@@ -609,6 +609,7 @@ class T_picklefunction(unittest.TestCase): ...@@ -609,6 +609,7 @@ class T_picklefunction(unittest.TestCase):
self.assertFalse(x in g.container) self.assertFalse(x in g.container)
self.assertFalse(x in g.value) self.assertFalse(x in g.value)
self.assertTrue(len(f.defaults) == len(g.defaults)) self.assertTrue(len(f.defaults) == len(g.defaults))
self.assertTrue(f._check_for_aliased_inputs is g._check_for_aliased_inputs)
# print 'f.defaults = %s' % (f.defaults, ) # print 'f.defaults = %s' % (f.defaults, )
# print 'g.defaults = %s' % (g.defaults, ) # print 'g.defaults = %s' % (g.defaults, )
self.assertTrue(all([f_req == g_req and f_feed == g_feed and self.assertTrue(all([f_req == g_req and f_feed == g_feed and
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论