提交 22f44f74 authored 作者: Frederic Bastien's avatar Frederic Bastien

Test that we reuse the same c module

上级 bdadee94
...@@ -1146,12 +1146,13 @@ class CLinker(link.Linker): ...@@ -1146,12 +1146,13 @@ class CLinker(link.Linker):
output_storage.append(map[variable]) output_storage.append(map[variable])
input_storage = tuple(input_storage) input_storage = tuple(input_storage)
output_storage = tuple(output_storage) output_storage = tuple(output_storage)
thunk = self.cthunk_factory(error_storage, thunk, module = self.cthunk_factory(error_storage,
input_storage, input_storage,
output_storage, output_storage,
storage_map, storage_map,
keep_lock=keep_lock) keep_lock=keep_lock)
return (thunk, return (thunk,
module,
[link.Container(input, storage) for input, storage in [link.Container(input, storage) for input, storage in
izip(self.fgraph.inputs, input_storage)], izip(self.fgraph.inputs, input_storage)],
[link.Container(output, storage, True) for output, storage in [link.Container(output, storage, True) for output, storage in
...@@ -1207,11 +1208,11 @@ class CLinker(link.Linker): ...@@ -1207,11 +1208,11 @@ class CLinker(link.Linker):
first_output = ostor[0].data first_output = ostor[0].data
""" """
init_tasks, tasks = self.get_init_tasks() init_tasks, tasks = self.get_init_tasks()
cthunk, in_storage, out_storage, error_storage = self.__compile__( cthunk, module, in_storage, out_storage, error_storage = self.__compile__(
input_storage, output_storage, storage_map, input_storage, output_storage, storage_map,
keep_lock=keep_lock) keep_lock=keep_lock)
res = _CThunk(cthunk, init_tasks, tasks, error_storage) res = _CThunk(cthunk, init_tasks, tasks, error_storage, module)
res.nodes = self.node_order res.nodes = self.node_order
return res, in_storage, out_storage return res, in_storage, out_storage
...@@ -1623,8 +1624,7 @@ class CLinker(link.Linker): ...@@ -1623,8 +1624,7 @@ class CLinker(link.Linker):
ret = module.instantiate(error_storage, ret = module.instantiate(error_storage,
*(in_storage + out_storage + orphd)) *(in_storage + out_storage + orphd))
return ret, module
return ret
def instantiate_code(self, n_args): def instantiate_code(self, n_args):
code = StringIO() code = StringIO()
...@@ -1669,10 +1669,13 @@ class _CThunk(object): ...@@ -1669,10 +1669,13 @@ class _CThunk(object):
WRITEME WRITEME
error_storage error_storage
WRITEME WRITEME
module
The module that was used to compile this cthunk.
Mostly only useful for tests.
""" """
def __init__(self, cthunk, init_tasks, tasks, error_storage): def __init__(self, cthunk, init_tasks, tasks, error_storage, module):
global run_cthunk global run_cthunk
if run_cthunk is None: if run_cthunk is None:
# Lazy import to avoid compilation when importing theano. # Lazy import to avoid compilation when importing theano.
...@@ -1681,6 +1684,7 @@ class _CThunk(object): ...@@ -1681,6 +1684,7 @@ class _CThunk(object):
self.init_tasks = init_tasks self.init_tasks = init_tasks
self.tasks = tasks self.tasks = tasks
self.error_storage = error_storage self.error_storage = error_storage
self.module = module
def find_task(self, failure_code): def find_task(self, failure_code):
""" """
......
...@@ -856,14 +856,15 @@ class Op(utils.object2, PureOp, CLinkerOp): ...@@ -856,14 +856,15 @@ class Op(utils.object2, PureOp, CLinkerOp):
_logger.debug('Trying CLinker.make_thunk') _logger.debug('Trying CLinker.make_thunk')
outputs = cl.make_thunk(input_storage=node_input_storage, outputs = cl.make_thunk(input_storage=node_input_storage,
output_storage=node_output_storage) output_storage=node_output_storage)
fill_storage, node_input_filters, node_output_filters = outputs thunk, node_input_filters, node_output_filters = outputs
def rval(): def rval():
fill_storage() thunk()
for o in node.outputs: for o in node.outputs:
compute_map[o][0] = True compute_map[o][0] = True
rval.cthunk = fill_storage.cthunk rval.thunk = thunk
rval.cthunk = thunk.cthunk
rval.inputs = node_input_storage rval.inputs = node_input_storage
rval.outputs = node_output_storage rval.outputs = node_output_storage
rval.lazy = False rval.lazy = False
......
...@@ -442,7 +442,14 @@ def test_reallocation(): ...@@ -442,7 +442,14 @@ def test_reallocation():
def test_no_recycling(): def test_no_recycling():
x = theano.tensor.vector() x = theano.tensor.vector()
mode = theano.Mode(optimizer='fast_compile') for lnk in [vm.VM_Linker(use_cloop=True),
f = theano.function([x], x + 1, mode=mode) vm.VM_Linker(use_cloop=False, lazy=True),
f2 = theano.function([x], (x + 1) * 2, mode=mode) vm.VM_Linker(use_cloop=False, lazy=False, allow_gc=True),
theano.printing.debugprint([f, f2]) vm.VM_Linker(use_cloop=False, lazy=False, allow_gc=False)]:
mode = theano.Mode(optimizer='fast_compile', linker=lnk)
f = theano.function([x], x + 1, mode=mode)
f2 = theano.function([x], (x + 1) * 2, mode=mode)
m1 = f.fn.thunks[0].thunk.module
m2 = f2.fn.thunks[0].thunk.module
assert m1 is m2
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论