提交 22f44f74 authored 作者: Frederic Bastien's avatar Frederic Bastien

Test that we reuse the same c module

上级 bdadee94
......@@ -1146,12 +1146,13 @@ class CLinker(link.Linker):
output_storage.append(map[variable])
input_storage = tuple(input_storage)
output_storage = tuple(output_storage)
thunk = self.cthunk_factory(error_storage,
input_storage,
output_storage,
storage_map,
keep_lock=keep_lock)
thunk, module = self.cthunk_factory(error_storage,
input_storage,
output_storage,
storage_map,
keep_lock=keep_lock)
return (thunk,
module,
[link.Container(input, storage) for input, storage in
izip(self.fgraph.inputs, input_storage)],
[link.Container(output, storage, True) for output, storage in
......@@ -1207,11 +1208,11 @@ class CLinker(link.Linker):
first_output = ostor[0].data
"""
init_tasks, tasks = self.get_init_tasks()
cthunk, in_storage, out_storage, error_storage = self.__compile__(
cthunk, module, in_storage, out_storage, error_storage = self.__compile__(
input_storage, output_storage, storage_map,
keep_lock=keep_lock)
res = _CThunk(cthunk, init_tasks, tasks, error_storage)
res = _CThunk(cthunk, init_tasks, tasks, error_storage, module)
res.nodes = self.node_order
return res, in_storage, out_storage
......@@ -1623,8 +1624,7 @@ class CLinker(link.Linker):
ret = module.instantiate(error_storage,
*(in_storage + out_storage + orphd))
return ret
return ret, module
def instantiate_code(self, n_args):
code = StringIO()
......@@ -1669,10 +1669,13 @@ class _CThunk(object):
WRITEME
error_storage
WRITEME
module
The module that was used to compile this cthunk.
Mostly only useful for tests.
"""
def __init__(self, cthunk, init_tasks, tasks, error_storage):
def __init__(self, cthunk, init_tasks, tasks, error_storage, module):
global run_cthunk
if run_cthunk is None:
# Lazy import to avoid compilation when importing theano.
......@@ -1681,6 +1684,7 @@ class _CThunk(object):
self.init_tasks = init_tasks
self.tasks = tasks
self.error_storage = error_storage
self.module = module
def find_task(self, failure_code):
"""
......
......@@ -856,14 +856,15 @@ class Op(utils.object2, PureOp, CLinkerOp):
_logger.debug('Trying CLinker.make_thunk')
outputs = cl.make_thunk(input_storage=node_input_storage,
output_storage=node_output_storage)
fill_storage, node_input_filters, node_output_filters = outputs
thunk, node_input_filters, node_output_filters = outputs
def rval():
fill_storage()
thunk()
for o in node.outputs:
compute_map[o][0] = True
rval.cthunk = fill_storage.cthunk
rval.thunk = thunk
rval.cthunk = thunk.cthunk
rval.inputs = node_input_storage
rval.outputs = node_output_storage
rval.lazy = False
......
......@@ -442,7 +442,14 @@ def test_reallocation():
def test_no_recycling():
x = theano.tensor.vector()
mode = theano.Mode(optimizer='fast_compile')
f = theano.function([x], x + 1, mode=mode)
f2 = theano.function([x], (x + 1) * 2, mode=mode)
theano.printing.debugprint([f, f2])
for lnk in [vm.VM_Linker(use_cloop=True),
vm.VM_Linker(use_cloop=False, lazy=True),
vm.VM_Linker(use_cloop=False, lazy=False, allow_gc=True),
vm.VM_Linker(use_cloop=False, lazy=False, allow_gc=False)]:
mode = theano.Mode(optimizer='fast_compile', linker=lnk)
f = theano.function([x], x + 1, mode=mode)
f2 = theano.function([x], (x + 1) * 2, mode=mode)
m1 = f.fn.thunks[0].thunk.module
m2 = f2.fn.thunks[0].thunk.module
assert m1 is m2
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论