提交 af1f4045 authored 作者: ChienliMa's avatar ChienliMa

Add ```storage_map```argument for Profile_Maker;CLinker new support…

Add ```storage_map```argument for Profile_Maker;CLinker new support ```storage_map```(without testing)
上级 af1e743d
...@@ -43,7 +43,9 @@ AddConfigVar('ProfileMode.profile_memory', ...@@ -43,7 +43,9 @@ AddConfigVar('ProfileMode.profile_memory',
class Profile_Maker(FunctionMaker): class Profile_Maker(FunctionMaker):
def create(self, input_storage=None, trustme=False): # storage_map does not work in Profile_Maker
# I just add this argument to fit the interface. -- ChienliMa 03.June.2015
def create(self, input_storage=None, trustme=False, storage_map=None):
ret = super(Profile_Maker, self).create(input_storage, trustme) ret = super(Profile_Maker, self).create(input_storage, trustme)
if (hasattr(theano, 'sandbox') and if (hasattr(theano, 'sandbox') and
......
...@@ -1069,11 +1069,9 @@ class CLinker(link.Linker): ...@@ -1069,11 +1069,9 @@ class CLinker(link.Linker):
pass pass
return utils.uniq(ret) return utils.uniq(ret)
def __compile__(self, input_storage=None, def __compile__(self, input_storage=None, output_storage=None,
output_storage=None, keep_lock=False): storage_map=None, keep_lock=False):
""" """WRITEME
WRITEME
Compiles this linker's fgraph. Compiles this linker's fgraph.
Parameters Parameters
...@@ -1111,6 +1109,7 @@ class CLinker(link.Linker): ...@@ -1111,6 +1109,7 @@ class CLinker(link.Linker):
thunk = self.cthunk_factory(error_storage, thunk = self.cthunk_factory(error_storage,
input_storage, input_storage,
output_storage, output_storage,
storage_map,
keep_lock=keep_lock) keep_lock=keep_lock)
return (thunk, return (thunk,
[link.Container(input, storage) for input, storage in [link.Container(input, storage) for input, storage in
...@@ -1143,10 +1142,8 @@ class CLinker(link.Linker): ...@@ -1143,10 +1142,8 @@ class CLinker(link.Linker):
return init_tasks, tasks return init_tasks, tasks
def make_thunk(self, input_storage=None, output_storage=None, def make_thunk(self, input_storage=None, output_storage=None,
keep_lock=False): storage_map=None, keep_lock=False):
""" """WRITEME
WRITEME
Compiles this linker's fgraph and returns a function to perform the Compiles this linker's fgraph and returns a function to perform the
computations, as well as lists of storage cells for both the inputs computations, as well as lists of storage cells for both the inputs
and outputs. and outputs.
...@@ -1157,25 +1154,24 @@ class CLinker(link.Linker): ...@@ -1157,25 +1154,24 @@ class CLinker(link.Linker):
List of lists of length 1. In order to use List of lists of length 1. In order to use
the thunk returned by __compile__, the inputs must be put in the thunk returned by __compile__, the inputs must be put in
that storage. If None, storage will be allocated. that storage. If None, storage will be allocated.
output_storage: list of lists of length 1 @param output_storage: list of lists of length 1. The thunk returned
The thunk returned by __compile__ will put the variables of the by __compile__ will put the variables of the computation in these
computation in these lists. If None, storage will be allocated. lists. If None, storage will be allocated.
@param storage_map: dict that map variables to storages. This is used
Returns when you need to customize the storage of this thunk.
-------
object Returns: thunk, input_storage, output_storage
Thunk, input_storage, output_storage.
The return values can be used as follows: The return values can be used as follows:
f, istor, ostor = clinker.make_thunk() f, istor, ostor = clinker.make_thunk()
istor[0].data = first_input istor[0].data = first_input
istor[1].data = second_input istor[1].data = second_input
f() f()
first_output = ostor[0].data first_output = ostor[0].data
""" """
init_tasks, tasks = self.get_init_tasks() init_tasks, tasks = self.get_init_tasks()
cthunk, in_storage, out_storage, error_storage = self.__compile__( cthunk, in_storage, out_storage, error_storage = self.__compile__(
input_storage, output_storage, input_storage, output_storage, storage_map,
keep_lock=keep_lock) keep_lock=keep_lock)
res = _CThunk(cthunk, init_tasks, tasks, error_storage) res = _CThunk(cthunk, init_tasks, tasks, error_storage)
...@@ -1529,25 +1525,17 @@ class CLinker(link.Linker): ...@@ -1529,25 +1525,17 @@ class CLinker(link.Linker):
return self._mod return self._mod
def cthunk_factory(self, error_storage, in_storage, out_storage, def cthunk_factory(self, error_storage, in_storage, out_storage,
keep_lock=False): storage_map=None, keep_lock=False):
""" """WRITEME
WRITEME error_storage -> list of length 3
in_storage -> list of lists of length 1, one per input
Parameters out_storage -> list of lists of length 1, one per output
----------
error_storage : list of length 3 Returns a thunk that points to an instance of a C struct that
in_storage : list of lists of length 1, one per input can carry on the computation of this linker's fgraph. That thunk,
out_storage : list of lists of length 1, one per output when executed, will fetch its inputs from in_storage, put its
outputs in out_storage and if an error occurs will put the
Returns type, value and traceback of the exception in error_storage.
-------
object
A thunk that points to an instance of a C struct that
can carry on the computation of this linker's fgraph. That thunk,
when executed, will fetch its inputs from in_storage, put its
outputs in out_storage and if an error occurs will put the
type, value and traceback of the exception in error_storage.
""" """
try: try:
key = self.cmodule_key() key = self.cmodule_key()
...@@ -1569,7 +1557,10 @@ class CLinker(link.Linker): ...@@ -1569,7 +1557,10 @@ class CLinker(link.Linker):
out_storage = [x for i, x in enumerate(out_storage) out_storage = [x for i, x in enumerate(out_storage)
if (i + len(in_storage)) not in dupidx] if (i + len(in_storage)) not in dupidx]
in_storage = [x for i, x in enumerate(in_storage) if i not in dupidx] in_storage = [x for i, x in enumerate(in_storage) if i not in dupidx]
orphd = [[orphan.data] for orphan in self.orphans] if storage_map is None:
orphd = [storage_map[orphan] for orphan in self.orphans]
else:
orphd = [[orphan.data] for orphan in self.orphans]
ret = module.instantiate(error_storage, ret = module.instantiate(error_storage,
*(in_storage + out_storage + orphd)) *(in_storage + out_storage + orphd))
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论