提交 15edfb76 authored 作者: James Bergstra's avatar James Bergstra

minor edits to function_module

上级 56830cb5
...@@ -11,7 +11,7 @@ from functools import partial ...@@ -11,7 +11,7 @@ from functools import partial
import numpy import numpy
from .. import gof from .. import gof
import sys import sys
from copy import copy import copy
from mode import * from mode import *
from io import * from io import *
...@@ -168,10 +168,10 @@ class Function(object): ...@@ -168,10 +168,10 @@ class Function(object):
input.distribute(value, indices, cs) input.distribute(value, indices, cs)
for c in cs: for c in cs:
c.provided += 1 c.provided += 1
def assign(c, v): #def assign(c, v):
c.data = v #c.data = v
setters = [] #setters = []
# Initialize the storage # Initialize the storage
for i, ((input, indices, sinputs), (required, refeed, value)) in enumerate(zip(self.indices, defaults)): for i, ((input, indices, sinputs), (required, refeed, value)) in enumerate(zip(self.indices, defaults)):
if indices is None: # this is true iff input is not a SymbolicInputKit if indices is None: # this is true iff input is not a SymbolicInputKit
...@@ -193,7 +193,7 @@ class Function(object): ...@@ -193,7 +193,7 @@ class Function(object):
finder[input.name] = c if input.name not in finder else DUPLICATE finder[input.name] = c if input.name not in finder else DUPLICATE
# inv_finder maps the container to the input (useful for one error message) # inv_finder maps the container to the input (useful for one error message)
inv_finder[c] = input inv_finder[c] = input
setters.append(partial(assign, c)) #setters.append(partial(assign, c))
containers[:1] = [] containers[:1] = []
else: else:
# The input is a SymbolicInputKit, so we take as many containers as the Kit provides inputs # The input is a SymbolicInputKit, so we take as many containers as the Kit provides inputs
...@@ -207,7 +207,7 @@ class Function(object): ...@@ -207,7 +207,7 @@ class Function(object):
finder[i] = f finder[i] = f
finder[input] = f finder[input] = f
finder[input.name] = f if input.name not in finder else DUPLICATE finder[input.name] = f if input.name not in finder else DUPLICATE
setters.append(f) #setters.append(f)
# For each input in the kit and its corresponding container, we put an entry in finder. # For each input in the kit and its corresponding container, we put an entry in finder.
# This allows the user to micro-manage elements of the kit if need be. # This allows the user to micro-manage elements of the kit if need be.
# All containers inherit the required field and have their own "provided" counter # All containers inherit the required field and have their own "provided" counter
...@@ -278,7 +278,7 @@ class Function(object): ...@@ -278,7 +278,7 @@ class Function(object):
cpy = self.maker.create(defaults, trustme = True) cpy = self.maker.create(defaults, trustme = True)
for (input,_1,_2), here, there in zip(self.indices, self.input_storage, cpy.input_storage): for (input,_1,_2), here, there in zip(self.indices, self.input_storage, cpy.input_storage):
if input.mutable and here is not None: if input.mutable and here is not None:
there.data = copy(here.data) there.data = copy.copy(here.data)
else: else:
there.data = here.data there.data = here.data
return cpy return cpy
...@@ -369,13 +369,14 @@ def _pickle_Function(f): ...@@ -369,13 +369,14 @@ def _pickle_Function(f):
for i, d_i in enumerate(all_data): for i, d_i in enumerate(all_data):
for j, d_j in enumerate(all_data): for j, d_j in enumerate(all_data):
if (i < j) and isinstance(d_i, numpy.ndarray) and isinstance(d_j, numpy.ndarray): if (i < j) and isinstance(d_i, numpy.ndarray) and isinstance(d_j, numpy.ndarray):
if f.pickle_aliased_memory_strategy == 'warn': if numpy.may_share_memory(d_i, d_j):
print >> sys.stderr, ('WARNING: ' if f.pickle_aliased_memory_strategy == 'warn':
'aliased relationship between Function arguments ' print >> sys.stderr, ('WARNING: '
'will not be preserved by un-pickling operation') 'aliased relationship between Function arguments '
#print >> sys.stderr, d_i, d_j, id(d_i), id(d_j) 'will not be preserved by un-pickling operation')
else: #print >> sys.stderr, d_i, d_j, id(d_i), id(d_j)
raise AliasedMemoryError(d_i, d_j) else:
raise AliasedMemoryError(d_i, d_j)
rval = (_constructor_Function, (f.maker, defaults, inputs_data)) rval = (_constructor_Function, (f.maker, defaults, inputs_data))
return rval return rval
...@@ -413,11 +414,11 @@ class SanityCheckFunction(Function): ...@@ -413,11 +414,11 @@ class SanityCheckFunction(Function):
for fn in self.others: for fn in self.others:
for stor1, stor2 in zip(self.input_storage, fn.input_storage): for stor1, stor2 in zip(self.input_storage, fn.input_storage):
stor2.value = copy(stor1.value) stor2.value = copy.copy(stor1.value)
variables = super(SanityCheckFunction, self).__call__(*args, **kwargs) variables = super(SanityCheckFunction, self).__call__(*args, **kwargs)
all_outputs = [copy(c.value) for c in self.output_storage] # we keep a copy to make sure it's not overwritten all_outputs = [copy.copy(c.value) for c in self.output_storage] # we keep a copy to make sure it's not overwritten
for fn in self.others: for fn in self.others:
fn(*args, **kwargs) fn(*args, **kwargs)
...@@ -536,7 +537,7 @@ class FunctionMaker(object): ...@@ -536,7 +537,7 @@ class FunctionMaker(object):
# Fetch the mode and then the optimizer and linker # Fetch the mode and then the optimizer and linker
mode = predefined_modes.get(mode, mode) mode = predefined_modes.get(mode, mode)
optimizer, linker = mode.optimizer, copy(mode.linker) optimizer, linker = mode.optimizer, copy.copy(mode.linker)
# optimize the env # optimize the env
optimizer(env) optimizer(env)
...@@ -755,7 +756,8 @@ def function(inputs, outputs, mode=default_mode, accept_inplace = False): ...@@ -755,7 +756,8 @@ def function(inputs, outputs, mode=default_mode, accept_inplace = False):
else: else:
#return a different kind of function #return a different kind of function
def dup_defaults(): def dup_defaults():
return [copy(default.value) if isinstance(default, gof.Container) else copy(default) return [copy.copy(default.value) if isinstance(default, gof.Container) else
copy.copy(default)
for default in defaults] for default in defaults]
makers = [FunctionMaker(inputs, outputs, m, accept_inplace = accept_inplace) for m in mode[1:]] makers = [FunctionMaker(inputs, outputs, m, accept_inplace = accept_inplace) for m in mode[1:]]
fns = [maker.create(dup_defaults(), trustme = True) for maker in makers] fns = [maker.create(dup_defaults(), trustme = True) for maker in makers]
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论