提交 c96641a4 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Handle functions without outputs in JITLinker

上级 1453ba09
......@@ -656,21 +656,37 @@ class JITLinker(PerformLinker):
thunk_outputs = [storage_map[n] for n in self.fgraph.outputs]
fgraph_jit = self.jit_compile(converted_fgraph)
def thunk(
fgraph_jit=fgraph_jit,
thunk_inputs=thunk_inputs,
thunk_outputs=thunk_outputs,
):
try:
outputs = fgraph_jit(*(x[0] for x in thunk_inputs))
except Exception:
# TODO: Should we add a fake node that combines all outputs,
# since the error may come from any of them?
raise_with_op(self.fgraph, output_nodes[0], thunk)
if thunk_outputs:
# zip strict not specified because we are in a hot loop
for o_storage, o_val in zip(thunk_outputs, outputs):
o_storage[0] = o_val
def thunk(
fgraph_jit=fgraph_jit,
thunk_inputs=thunk_inputs,
thunk_outputs=thunk_outputs,
):
try:
outputs = fgraph_jit(*(x[0] for x in thunk_inputs))
except Exception:
# TODO: Should we add a fake node that combines all outputs,
# since the error may come from any of them?
raise_with_op(self.fgraph, output_nodes[0], thunk)
# zip strict not specified because we are in a hot loop
for o_storage, o_val in zip(thunk_outputs, outputs):
o_storage[0] = o_val
else:
# Edge case - functions without outputs
def thunk(
fgraph_jit=fgraph_jit,
thunk_inputs=thunk_inputs,
thunk_outputs=thunk_outputs,
):
try:
res = fgraph_jit(*(x[0] for x in thunk_inputs))
except Exception:
raise_with_op(self.fgraph, output_nodes[0], thunk)
assert res is None
return thunk_outputs
thunk.inputs = thunk_inputs
thunk.outputs = thunk_outputs
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论