提交 be918f5d authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Add JITed functions as an attribute to the thunk-computing functions

This means that an `aesara.function`-compiled result, e.g. `func`, will expose the underlying JITed function as `func.fn.jit_fn`.
上级 c575da2c
...@@ -676,6 +676,8 @@ class JITLinker(PerformLinker): ...@@ -676,6 +676,8 @@ class JITLinker(PerformLinker):
A tuple containing the thunks. A tuple containing the thunks.
output_nodes: list and their output_nodes: list and their
A tuple containing the output nodes. A tuple containing the output nodes.
jit_fn: callable
The JITed function that performs the computations.
""" """
output_nodes = [o.owner for o in self.fgraph.outputs] output_nodes = [o.owner for o in self.fgraph.outputs]
...@@ -721,7 +723,7 @@ class JITLinker(PerformLinker): ...@@ -721,7 +723,7 @@ class JITLinker(PerformLinker):
thunks.append(thunk) thunks.append(thunk)
# This is a bit hackish, but we only return one of the output nodes # This is a bit hackish, but we only return one of the output nodes
return thunks, output_nodes[:1] return thunks, output_nodes[:1], fgraph_jit
def make_all(self, input_storage=None, output_storage=None, storage_map=None): def make_all(self, input_storage=None, output_storage=None, storage_map=None):
fgraph = self.fgraph fgraph = self.fgraph
...@@ -736,7 +738,7 @@ class JITLinker(PerformLinker): ...@@ -736,7 +738,7 @@ class JITLinker(PerformLinker):
for k in storage_map: for k in storage_map:
compute_map[k] = [k.owner is None] compute_map[k] = [k.owner is None]
thunks, nodes = self.create_jitable_thunk( thunks, nodes, jit_fn = self.create_jitable_thunk(
compute_map, nodes, input_storage, output_storage, storage_map compute_map, nodes, input_storage, output_storage, storage_map
) )
...@@ -770,6 +772,7 @@ class JITLinker(PerformLinker): ...@@ -770,6 +772,7 @@ class JITLinker(PerformLinker):
fgraph, thunks, nodes, post_thunk_old_storage, no_recycling=no_recycling fgraph, thunks, nodes, post_thunk_old_storage, no_recycling=no_recycling
) )
fn.jit_fn = jit_fn
fn.allow_gc = self.allow_gc fn.allow_gc = self.allow_gc
fn.storage_map = storage_map fn.storage_map = storage_map
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论