提交 d5237cca authored 作者: Razvan Pascanu's avatar Razvan Pascanu

merge

...@@ -71,15 +71,22 @@ def function(inputs, outputs=None, mode=None, updates=[], givens=[], ...@@ -71,15 +71,22 @@ def function(inputs, outputs=None, mode=None, updates=[], givens=[],
uses_updates = (updates != []) uses_updates = (updates != [])
uses_givens = (givens != []) uses_givens = (givens != [])
# See if we have any mutable / borrow inputs
check_for_aliased_inputs = False
for i in inputs:
if (isinstance(i, In) and ( (hasattr(i,'borrow') and i.borrow) or
(hasattr(i,'mutable') and i.mutable)) ):
check_for_aliased_inputs = True
if uses_In or uses_tuple: if uses_In or uses_tuple:
# we must use old semantics in this case. # we must use old semantics in this case.
if uses_updates or uses_givens: if uses_updates or uses_givens:
raise NotImplementedError("In() instances and tuple inputs triggers the old semantics, which disallow using updates and givens") raise NotImplementedError("In() instances and tuple inputs triggers the old semantics, which disallow using updates and givens")
return orig_function(inputs, outputs, fn = orig_function(inputs, outputs,
mode=mode, mode=mode,
accept_inplace=accept_inplace, name=name) accept_inplace=accept_inplace, name=name)
else: else:
return pfunc(params=inputs, fn = pfunc(params=inputs,
outputs=outputs, outputs=outputs,
mode=mode, mode=mode,
updates=updates, updates=updates,
...@@ -87,3 +94,7 @@ def function(inputs, outputs=None, mode=None, updates=[], givens=[], ...@@ -87,3 +94,7 @@ def function(inputs, outputs=None, mode=None, updates=[], givens=[],
no_default_updates=no_default_updates, no_default_updates=no_default_updates,
accept_inplace=accept_inplace,name=name, accept_inplace=accept_inplace,name=name,
rebuild_strict=rebuild_strict) rebuild_strict=rebuild_strict)
# We need to add the flag check_aliased inputs if we have any mutable or
# borrowed used defined inputs
fn._check_for_aliased_inputs = check_for_aliased_inputs
return fn
...@@ -531,18 +531,25 @@ class Function(object): ...@@ -531,18 +531,25 @@ class Function(object):
for k, arg in kwargs.iteritems(): for k, arg in kwargs.iteritems():
self[k] = arg self[k] = arg
if ( not hasattr(self, '_check_for_aliased_inputs') or
self._check_for_aliased_inputs):
## Collect aliased inputs among the storage space ## Collect aliased inputs among the storage space
args_share_memory = [] args_share_memory = []
for i in xrange(len(self.input_storage)): for i in xrange(len(self.input_storage)):
if self.input_storage[i].storage[0] is not None: if isinstance(self.input_storage[i].storage[0],
numpy.ndarray):
is_aliased = False is_aliased = False
for j in xrange(len(args_share_memory)): for j in xrange(len(args_share_memory)):
for k in args_share_memory[j]:
if numpy.may_share_memory( if numpy.may_share_memory(
self.input_storage[i].storage[0] , self.input_storage[i].storage[0] ,
self.input_storage[args_share_memory[j][0]].storage[0]): self.input_storage[k].storage[0]):
is_aliased = True is_aliased = True
args_share_memory[j].append(i) args_share_memory[j].append(i)
break break
if is_aliased:
break
if not is_aliased: if not is_aliased:
args_share_memory.append([i]) args_share_memory.append([i])
...@@ -551,8 +558,9 @@ class Function(object): ...@@ -551,8 +558,9 @@ class Function(object):
for group in args_share_memory: for group in args_share_memory:
if len(group) > 1: if len(group) > 1:
# see if any of these arguments are mutable # see if any of these arguments are mutable
mutable = numpy.any([self.maker.inputs[idx].mutable mutable = numpy.any([(self.maker.inputs[idx].mutable or
for idx in group]) self.maker.inputs[idx].borrow )
for idx in group ])
# copy all but the first # copy all but the first
for idx in group[1:]: for idx in group[1:]:
self.input_storage[i].storage[0] = copy.copy( self.input_storage[i].storage[0] = copy.copy(
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论