提交 43829581 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Extract source compilation into aesara.link.utils.compile_function_src helper

上级 3b552872
import ast
from functools import reduce, singledispatch from functools import reduce, singledispatch
from tempfile import NamedTemporaryFile
import numba import numba
import numpy as np import numpy as np
...@@ -10,7 +8,7 @@ import scipy.special ...@@ -10,7 +8,7 @@ import scipy.special
from aesara.compile.ops import DeepCopyOp from aesara.compile.ops import DeepCopyOp
from aesara.graph.fg import FunctionGraph from aesara.graph.fg import FunctionGraph
from aesara.graph.type import Type from aesara.graph.type import Type
from aesara.link.utils import fgraph_to_python from aesara.link.utils import compile_function_src, fgraph_to_python
from aesara.scalar.basic import Composite, ScalarOp from aesara.scalar.basic import Composite, ScalarOp
from aesara.tensor.elemwise import Elemwise from aesara.tensor.elemwise import Elemwise
from aesara.tensor.subtensor import AdvancedSubtensor, AdvancedSubtensor1, Subtensor from aesara.tensor.subtensor import AdvancedSubtensor, AdvancedSubtensor1, Subtensor
...@@ -154,19 +152,12 @@ def numba_funcify_Subtensor(op, node, **kwargs): ...@@ -154,19 +152,12 @@ def numba_funcify_Subtensor(op, node, **kwargs):
node, idx_list, objmode=isinstance(op, AdvancedSubtensor) node, idx_list, objmode=isinstance(op, AdvancedSubtensor)
) )
subtensor_def_ast = ast.parse(subtensor_def_src) global_env = {}
global_env["objmode"] = numba.objmode
with NamedTemporaryFile(delete=False) as f: subtensor_fn = compile_function_src(subtensor_def_src, "subtensor", global_env)
filename = f.name
f.write(subtensor_def_src.encode())
local_env = {} return numba.njit(subtensor_fn)
mod_code = compile(subtensor_def_ast, filename, mode="exec")
exec(mod_code, {"objmode": numba.objmode}, local_env)
subtensor_def = local_env["subtensor"]
return numba.njit(subtensor_def)
@numba_funcify.register(DeepCopyOp) @numba_funcify.register(DeepCopyOp)
......
...@@ -573,6 +573,25 @@ def register_thunk_trace_excepthook(handler: io.TextIOWrapper = sys.stdout): ...@@ -573,6 +573,25 @@ def register_thunk_trace_excepthook(handler: io.TextIOWrapper = sys.stdout):
register_thunk_trace_excepthook() register_thunk_trace_excepthook()
def compile_function_src(src, function_name, global_env=None, local_env=None):
src_ast = ast.parse(src)
with NamedTemporaryFile(delete=False) as f:
filename = f.name
f.write(src.encode())
if global_env is None:
global_env = {}
if local_env is None:
local_env = {}
mod_code = compile(src_ast, filename, mode="exec")
exec(mod_code, global_env, local_env)
return local_env[function_name]
def fgraph_to_python( def fgraph_to_python(
fgraph: FunctionGraph, fgraph: FunctionGraph,
op_conversion_fn: Callable, op_conversion_fn: Callable,
...@@ -624,9 +643,6 @@ def fgraph_to_python( ...@@ -624,9 +643,6 @@ def fgraph_to_python(
fgraph, order, input_storage, output_storage, storage_map fgraph, order, input_storage, output_storage, storage_map
) )
if global_env is None:
global_env = {}
def unique_name(x, names_counter=Counter([fgraph_name]), obj_to_names={}): def unique_name(x, names_counter=Counter([fgraph_name]), obj_to_names={}):
if x in obj_to_names: if x in obj_to_names:
return obj_to_names[x] return obj_to_names[x]
...@@ -649,6 +665,9 @@ def fgraph_to_python( ...@@ -649,6 +665,9 @@ def fgraph_to_python(
return local_name return local_name
if global_env is None:
global_env = {}
body_assigns = [] body_assigns = []
for node in order: for node in order:
jax_func = op_conversion_fn(node.op, node=node, **kwargs) jax_func = op_conversion_fn(node.op, node=node, **kwargs)
...@@ -690,20 +709,11 @@ def {fgraph_name}({", ".join(fgraph_input_names)}): ...@@ -690,20 +709,11 @@ def {fgraph_name}({", ".join(fgraph_input_names)}):
return {fgraph_return_src} return {fgraph_return_src}
""" """
fgraph_def_ast = ast.parse(fgraph_def_src)
# Create source code to be (at least temporarily) associated with the
# compiled function (e.g. for easier debugging)
with NamedTemporaryFile(delete=False) as f:
filename = f.name
f.write(fgraph_def_src.encode())
if local_env is None: if local_env is None:
local_env = locals() local_env = locals()
mod_code = compile(fgraph_def_ast, filename, mode="exec") fgraph_def = compile_function_src(
exec(mod_code, global_env, local_env) fgraph_def_src, fgraph_name, global_env, local_env
)
fgraph_def = local_env[fgraph_name]
return fgraph_def return fgraph_def
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论