提交 3a3a8f08 authored 作者: ChienliMa's avatar ChienliMa

Add argument ```storage_map``` to the rest of theano linker

上级 e2f3cebf
...@@ -1827,7 +1827,7 @@ class _Linker(gof.link.LocalLinker): ...@@ -1827,7 +1827,7 @@ class _Linker(gof.link.LocalLinker):
return self return self
def make_all(self, profiler=None, input_storage=None, def make_all(self, profiler=None, input_storage=None,
output_storage=None): output_storage=None, storage_map=None):
# can't import at toplevel because of circular import TODO: # can't import at toplevel because of circular import TODO:
# don't do this ugly hacky way of setting the # don't do this ugly hacky way of setting the
# filter_checks_isfinite # filter_checks_isfinite
...@@ -1857,7 +1857,7 @@ class _Linker(gof.link.LocalLinker): ...@@ -1857,7 +1857,7 @@ class _Linker(gof.link.LocalLinker):
no_recycling = [] no_recycling = []
input_storage, output_storage, storage_map = link.map_storage( input_storage, output_storage, storage_map = link.map_storage(
fgraph, order, input_storage_, output_storage_) fgraph, order, input_storage_, output_storage_, storage_map)
thunks_py = [] # python thunks thunks_py = [] # python thunks
thunks_c = [] # c thunks thunks_c = [] # c thunks
......
...@@ -1727,7 +1727,8 @@ class OpWiseCLinker(link.LocalLinker): ...@@ -1727,7 +1727,8 @@ class OpWiseCLinker(link.LocalLinker):
self.no_recycling = no_recycling self.no_recycling = no_recycling
return self return self
def make_all(self, profiler=None, input_storage=None, output_storage=None): def make_all(self, profiler=None, input_storage=None, output_storage=None,
storage_map=None):
# The lock will be acquired when we compile the first # The lock will be acquired when we compile the first
# C code. We will keep the lock untill all the function # C code. We will keep the lock untill all the function
...@@ -1741,7 +1742,7 @@ class OpWiseCLinker(link.LocalLinker): ...@@ -1741,7 +1742,7 @@ class OpWiseCLinker(link.LocalLinker):
no_recycling = self.no_recycling no_recycling = self.no_recycling
input_storage, output_storage, storage_map = link.map_storage( input_storage, output_storage, storage_map = link.map_storage(
fgraph, order, input_storage, output_storage) fgraph, order, input_storage, output_storage, storage_map)
if self.allow_gc: if self.allow_gc:
computed, last_user = link.gc_helper(order) computed, last_user = link.gc_helper(order)
post_thunk_old_storage = [] post_thunk_old_storage = []
......
...@@ -1000,14 +1000,14 @@ class VM_Linker(link.LocalLinker): ...@@ -1000,14 +1000,14 @@ class VM_Linker(link.LocalLinker):
return vm return vm
def make_all(self, profiler=None, input_storage=None, def make_all(self, profiler=None, input_storage=None,
output_storage=None, output_storage=None, storage_map=None,
): ):
fgraph = self.fgraph fgraph = self.fgraph
order = self.schedule(fgraph) order = self.schedule(fgraph)
no_recycling = self.no_recycling no_recycling = self.no_recycling
input_storage, output_storage, storage_map = link.map_storage( input_storage, output_storage, storage_map = link.map_storage(
fgraph, order, input_storage, output_storage) fgraph, order, input_storage, output_storage, storage_map)
compute_map = {} compute_map = {}
for k in storage_map: for k in storage_map:
compute_map[k] = [k.owner is None] compute_map[k] = [k.owner is None]
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论