提交 c96641a4 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Handle functions without outputs in JITLinker

上级 1453ba09
...@@ -656,6 +656,8 @@ class JITLinker(PerformLinker): ...@@ -656,6 +656,8 @@ class JITLinker(PerformLinker):
thunk_outputs = [storage_map[n] for n in self.fgraph.outputs] thunk_outputs = [storage_map[n] for n in self.fgraph.outputs]
fgraph_jit = self.jit_compile(converted_fgraph) fgraph_jit = self.jit_compile(converted_fgraph)
if thunk_outputs:
def thunk( def thunk(
fgraph_jit=fgraph_jit, fgraph_jit=fgraph_jit,
thunk_inputs=thunk_inputs, thunk_inputs=thunk_inputs,
...@@ -672,6 +674,20 @@ class JITLinker(PerformLinker): ...@@ -672,6 +674,20 @@ class JITLinker(PerformLinker):
for o_storage, o_val in zip(thunk_outputs, outputs): for o_storage, o_val in zip(thunk_outputs, outputs):
o_storage[0] = o_val o_storage[0] = o_val
else:
# Edge case - functions without outputs
def thunk(
fgraph_jit=fgraph_jit,
thunk_inputs=thunk_inputs,
thunk_outputs=thunk_outputs,
):
try:
res = fgraph_jit(*(x[0] for x in thunk_inputs))
except Exception:
raise_with_op(self.fgraph, output_nodes[0], thunk)
assert res is None
return thunk_outputs
thunk.inputs = thunk_inputs thunk.inputs = thunk_inputs
thunk.outputs = thunk_outputs thunk.outputs = thunk_outputs
thunk.lazy = False thunk.lazy = False
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论