提交 be918f5d authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Add JITed functions as an attribute to the thunk-computing functions

This means that an `aesara.function`-compiled result, e.g. `func`, will expose the underlying JITed function as `func.fn.jit_fn`.
上级 c575da2c
......@@ -676,6 +676,8 @@ class JITLinker(PerformLinker):
A tuple containing the thunks.
output_nodes: list and their
A tuple containing the output nodes.
jit_fn: callable
The JITed function that performs the computations.
"""
output_nodes = [o.owner for o in self.fgraph.outputs]
......@@ -721,7 +723,7 @@ class JITLinker(PerformLinker):
thunks.append(thunk)
# This is a bit hackish, but we only return one of the output nodes
return thunks, output_nodes[:1]
return thunks, output_nodes[:1], fgraph_jit
def make_all(self, input_storage=None, output_storage=None, storage_map=None):
fgraph = self.fgraph
......@@ -736,7 +738,7 @@ class JITLinker(PerformLinker):
for k in storage_map:
compute_map[k] = [k.owner is None]
thunks, nodes = self.create_jitable_thunk(
thunks, nodes, jit_fn = self.create_jitable_thunk(
compute_map, nodes, input_storage, output_storage, storage_map
)
......@@ -770,6 +772,7 @@ class JITLinker(PerformLinker):
fgraph, thunks, nodes, post_thunk_old_storage, no_recycling=no_recycling
)
fn.jit_fn = jit_fn
fn.allow_gc = self.allow_gc
fn.storage_map = storage_map
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论