提交 43829581 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Extract source compilation into aesara.link.utils.compile_function_src helper

上级 3b552872
import ast
from functools import reduce, singledispatch
from tempfile import NamedTemporaryFile
import numba
import numpy as np
......@@ -10,7 +8,7 @@ import scipy.special
from aesara.compile.ops import DeepCopyOp
from aesara.graph.fg import FunctionGraph
from aesara.graph.type import Type
from aesara.link.utils import fgraph_to_python
from aesara.link.utils import compile_function_src, fgraph_to_python
from aesara.scalar.basic import Composite, ScalarOp
from aesara.tensor.elemwise import Elemwise
from aesara.tensor.subtensor import AdvancedSubtensor, AdvancedSubtensor1, Subtensor
......@@ -154,19 +152,12 @@ def numba_funcify_Subtensor(op, node, **kwargs):
node, idx_list, objmode=isinstance(op, AdvancedSubtensor)
)
subtensor_def_ast = ast.parse(subtensor_def_src)
global_env = {}
global_env["objmode"] = numba.objmode
with NamedTemporaryFile(delete=False) as f:
filename = f.name
f.write(subtensor_def_src.encode())
subtensor_fn = compile_function_src(subtensor_def_src, "subtensor", global_env)
local_env = {}
mod_code = compile(subtensor_def_ast, filename, mode="exec")
exec(mod_code, {"objmode": numba.objmode}, local_env)
subtensor_def = local_env["subtensor"]
return numba.njit(subtensor_def)
return numba.njit(subtensor_fn)
@numba_funcify.register(DeepCopyOp)
......
......@@ -573,6 +573,25 @@ def register_thunk_trace_excepthook(handler: io.TextIOWrapper = sys.stdout):
register_thunk_trace_excepthook()
def compile_function_src(src, function_name, global_env=None, local_env=None):
src_ast = ast.parse(src)
with NamedTemporaryFile(delete=False) as f:
filename = f.name
f.write(src.encode())
if global_env is None:
global_env = {}
if local_env is None:
local_env = {}
mod_code = compile(src_ast, filename, mode="exec")
exec(mod_code, global_env, local_env)
return local_env[function_name]
def fgraph_to_python(
fgraph: FunctionGraph,
op_conversion_fn: Callable,
......@@ -624,9 +643,6 @@ def fgraph_to_python(
fgraph, order, input_storage, output_storage, storage_map
)
if global_env is None:
global_env = {}
def unique_name(x, names_counter=Counter([fgraph_name]), obj_to_names={}):
if x in obj_to_names:
return obj_to_names[x]
......@@ -649,6 +665,9 @@ def fgraph_to_python(
return local_name
if global_env is None:
global_env = {}
body_assigns = []
for node in order:
jax_func = op_conversion_fn(node.op, node=node, **kwargs)
......@@ -690,20 +709,11 @@ def {fgraph_name}({", ".join(fgraph_input_names)}):
return {fgraph_return_src}
"""
fgraph_def_ast = ast.parse(fgraph_def_src)
# Create source code to be (at least temporarily) associated with the
# compiled function (e.g. for easier debugging)
with NamedTemporaryFile(delete=False) as f:
filename = f.name
f.write(fgraph_def_src.encode())
if local_env is None:
local_env = locals()
mod_code = compile(fgraph_def_ast, filename, mode="exec")
exec(mod_code, global_env, local_env)
fgraph_def = local_env[fgraph_name]
fgraph_def = compile_function_src(
fgraph_def_src, fgraph_name, global_env, local_env
)
return fgraph_def
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论