提交 d5237cca authored 作者: Razvan Pascanu's avatar Razvan Pascanu

merge

...@@ -11,7 +11,7 @@ from pfunc import pfunc ...@@ -11,7 +11,7 @@ from pfunc import pfunc
from numpy import any #for to work in python 2.4 from numpy import any #for to work in python 2.4
def function(inputs, outputs=None, mode=None, updates=[], givens=[], def function(inputs, outputs=None, mode=None, updates=[], givens=[],
no_default_updates=False, accept_inplace=False, name=None, no_default_updates=False, accept_inplace=False, name=None,
rebuild_strict = True): rebuild_strict = True):
""" """
Return a callable object that will calculate `outputs` from `inputs`. Return a callable object that will calculate `outputs` from `inputs`.
...@@ -71,19 +71,30 @@ def function(inputs, outputs=None, mode=None, updates=[], givens=[], ...@@ -71,19 +71,30 @@ def function(inputs, outputs=None, mode=None, updates=[], givens=[],
uses_updates = (updates != []) uses_updates = (updates != [])
uses_givens = (givens != []) uses_givens = (givens != [])
# See if we have any mutable / borrow inputs
check_for_aliased_inputs = False
for i in inputs:
if (isinstance(i, In) and ( (hasattr(i,'borrow') and i.borrow) or
(hasattr(i,'mutable') and i.mutable)) ):
check_for_aliased_inputs = True
if uses_In or uses_tuple: if uses_In or uses_tuple:
# we must use old semantics in this case. # we must use old semantics in this case.
if uses_updates or uses_givens: if uses_updates or uses_givens:
raise NotImplementedError("In() instances and tuple inputs triggers the old semantics, which disallow using updates and givens") raise NotImplementedError("In() instances and tuple inputs triggers the old semantics, which disallow using updates and givens")
return orig_function(inputs, outputs, fn = orig_function(inputs, outputs,
mode=mode, mode=mode,
accept_inplace=accept_inplace, name=name) accept_inplace=accept_inplace, name=name)
else: else:
return pfunc(params=inputs, fn = pfunc(params=inputs,
outputs=outputs, outputs=outputs,
mode=mode, mode=mode,
updates=updates, updates=updates,
givens=givens, givens=givens,
no_default_updates=no_default_updates, no_default_updates=no_default_updates,
accept_inplace=accept_inplace,name=name, accept_inplace=accept_inplace,name=name,
rebuild_strict=rebuild_strict) rebuild_strict=rebuild_strict)
# We need to add the flag check_aliased inputs if we have any mutable or
# borrowed used defined inputs
fn._check_for_aliased_inputs = check_for_aliased_inputs
return fn
...@@ -531,32 +531,40 @@ class Function(object): ...@@ -531,32 +531,40 @@ class Function(object):
for k, arg in kwargs.iteritems(): for k, arg in kwargs.iteritems():
self[k] = arg self[k] = arg
## Collect aliased inputs among the storage space
args_share_memory = []
for i in xrange(len(self.input_storage)):
if self.input_storage[i].storage[0] is not None:
is_aliased = False
for j in xrange(len(args_share_memory)):
if numpy.may_share_memory(
self.input_storage[i].storage[0] ,
self.input_storage[args_share_memory[j][0]].storage[0]):
is_aliased = True
args_share_memory[j].append(i)
break
if not is_aliased: if ( not hasattr(self, '_check_for_aliased_inputs') or
args_share_memory.append([i]) self._check_for_aliased_inputs):
## Collect aliased inputs among the storage space
# Check for groups of more than one argument that share memory args_share_memory = []
for group in args_share_memory: for i in xrange(len(self.input_storage)):
if len(group) > 1: if isinstance(self.input_storage[i].storage[0],
# see if any of these arguments are mutable numpy.ndarray):
mutable = numpy.any([self.maker.inputs[idx].mutable is_aliased = False
for idx in group]) for j in xrange(len(args_share_memory)):
# copy all but the first for k in args_share_memory[j]:
for idx in group[1:]: if numpy.may_share_memory(
self.input_storage[i].storage[0] = copy.copy( self.input_storage[i].storage[0] ,
self.input_storage[i].storage[0]) self.input_storage[k].storage[0]):
is_aliased = True
args_share_memory[j].append(i)
break
if is_aliased:
break
if not is_aliased:
args_share_memory.append([i])
# Check for groups of more than one argument that share memory
for group in args_share_memory:
if len(group) > 1:
# see if any of these arguments are mutable
mutable = numpy.any([(self.maker.inputs[idx].mutable or
self.maker.inputs[idx].borrow )
for idx in group ])
# copy all but the first
for idx in group[1:]:
self.input_storage[i].storage[0] = copy.copy(
self.input_storage[i].storage[0])
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论