提交 d5237cca authored 作者: Razvan Pascanu's avatar Razvan Pascanu

merge

......@@ -11,7 +11,7 @@ from pfunc import pfunc
from numpy import any #for to work in python 2.4
def function(inputs, outputs=None, mode=None, updates=[], givens=[],
no_default_updates=False, accept_inplace=False, name=None,
no_default_updates=False, accept_inplace=False, name=None,
rebuild_strict = True):
"""
Return a callable object that will calculate `outputs` from `inputs`.
......@@ -71,19 +71,30 @@ def function(inputs, outputs=None, mode=None, updates=[], givens=[],
uses_updates = (updates != [])
uses_givens = (givens != [])
# See if we have any mutable / borrow inputs
check_for_aliased_inputs = False
for i in inputs:
if (isinstance(i, In) and ( (hasattr(i,'borrow') and i.borrow) or
(hasattr(i,'mutable') and i.mutable)) ):
check_for_aliased_inputs = True
if uses_In or uses_tuple:
# we must use old semantics in this case.
if uses_updates or uses_givens:
raise NotImplementedError("In() instances and tuple inputs triggers the old semantics, which disallow using updates and givens")
return orig_function(inputs, outputs,
fn = orig_function(inputs, outputs,
mode=mode,
accept_inplace=accept_inplace, name=name)
else:
return pfunc(params=inputs,
fn = pfunc(params=inputs,
outputs=outputs,
mode=mode,
updates=updates,
mode=mode,
updates=updates,
givens=givens,
no_default_updates=no_default_updates,
accept_inplace=accept_inplace,name=name,
rebuild_strict=rebuild_strict)
# We need to add the flag check_aliased inputs if we have any mutable or
# borrowed used defined inputs
fn._check_for_aliased_inputs = check_for_aliased_inputs
return fn
......@@ -531,32 +531,40 @@ class Function(object):
for k, arg in kwargs.iteritems():
self[k] = arg
## Collect aliased inputs among the storage space
args_share_memory = []
for i in xrange(len(self.input_storage)):
if self.input_storage[i].storage[0] is not None:
is_aliased = False
for j in xrange(len(args_share_memory)):
if numpy.may_share_memory(
self.input_storage[i].storage[0] ,
self.input_storage[args_share_memory[j][0]].storage[0]):
is_aliased = True
args_share_memory[j].append(i)
break
if not is_aliased:
args_share_memory.append([i])
# Check for groups of more than one argument that share memory
for group in args_share_memory:
if len(group) > 1:
# see if any of these arguments are mutable
mutable = numpy.any([self.maker.inputs[idx].mutable
for idx in group])
# copy all but the first
for idx in group[1:]:
self.input_storage[i].storage[0] = copy.copy(
self.input_storage[i].storage[0])
if ( not hasattr(self, '_check_for_aliased_inputs') or
self._check_for_aliased_inputs):
## Collect aliased inputs among the storage space
args_share_memory = []
for i in xrange(len(self.input_storage)):
if isinstance(self.input_storage[i].storage[0],
numpy.ndarray):
is_aliased = False
for j in xrange(len(args_share_memory)):
for k in args_share_memory[j]:
if numpy.may_share_memory(
self.input_storage[i].storage[0] ,
self.input_storage[k].storage[0]):
is_aliased = True
args_share_memory[j].append(i)
break
if is_aliased:
break
if not is_aliased:
args_share_memory.append([i])
# Check for groups of more than one argument that share memory
for group in args_share_memory:
if len(group) > 1:
# see if any of these arguments are mutable
mutable = numpy.any([(self.maker.inputs[idx].mutable or
self.maker.inputs[idx].borrow )
for idx in group ])
# copy all but the first
for idx in group[1:]:
self.input_storage[i].storage[0] = copy.copy(
self.input_storage[i].storage[0])
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论