提交 0a4a35c7 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Add output_filter method to JITLinker and use it to convert Numba scalars

上级 2c9ee770
......@@ -12,8 +12,6 @@ from typing import (
Union,
)
from numpy import ndarray
from aesara.configdefaults import config
from aesara.graph.basic import Apply, Variable
from aesara.graph.fg import FunctionGraph
......@@ -23,6 +21,8 @@ from aesara.utils import difference
if TYPE_CHECKING:
from numpy.typing import NDArray
from aesara.compile.profiling import ProfileStats
from aesara.graph.op import (
BasicThunkType,
......@@ -505,7 +505,7 @@ class WrapLinker(Linker):
def pre(
self,
f: "WrapLinker",
inputs: Union[List[ndarray], List[Optional[float]]],
inputs: Union[List["NDArray"], List[Optional[float]]],
order: List[Apply],
thunk_groups: List[Tuple[Callable]],
) -> None:
......@@ -609,6 +609,10 @@ class JITLinker(PerformLinker):
def jit_compile(self, fn: Callable) -> Callable:
"""JIT compile a converted ``FunctionGraph``."""
def output_filter(self, var: Variable, out: Any) -> Any:
"""Apply a filter to the data output by a JITed function call."""
return out
def create_jitable_thunk(
self, compute_map, order, input_storage, output_storage, storage_map
):
......@@ -663,14 +667,9 @@ class JITLinker(PerformLinker):
):
outputs = fgraph_jit(*[x[0] for x in thunk_inputs])
for o_node, o_storage, o_val in zip(fgraph.outputs, thunk_outputs, outputs):
compute_map[o_node][0] = True
if len(o_storage) > 1:
assert len(o_storage) == len(o_val)
for i, o_sub_val in enumerate(o_val):
o_storage[i] = o_sub_val
else:
o_storage[0] = o_val
for o_var, o_storage, o_val in zip(fgraph.outputs, thunk_outputs, outputs):
compute_map[o_var][0] = True
o_storage[0] = self.output_filter(o_var, o_val)
return outputs
thunk.inputs = thunk_inputs
......
from typing import TYPE_CHECKING, Any
import numpy as np
import aesara
from aesara.link.basic import JITLinker
if TYPE_CHECKING:
from aesara.graph.basic import Variable
class NumbaLinker(JITLinker):
"""A `Linker` that JIT-compiles NumPy-based operations using Numba."""
def output_filter(self, var: "Variable", out: Any) -> Any:
if not isinstance(var, np.ndarray) and isinstance(
var.type, aesara.tensor.TensorType
):
return np.asarray(out, dtype=var.type.dtype)
return out
def fgraph_convert(self, fgraph, **kwargs):
from aesara.link.numba.dispatch import numba_funcify
......
......@@ -3575,3 +3575,14 @@ def test_config_options_cached():
aesara_numba_fn = function([x], x * 2, mode=numba_mode)
numba_mul_fn = aesara_numba_fn.vm.jit_fn.py_func.__globals__["mul"]
assert isinstance(numba_mul_fn._dispatcher.cache, numba.core.caching.NullCache)
def test_scalar_return_value_conversion():
r"""Make sure that we convert \"native\" scalars to `ndarray`\s in the graph outputs."""
x = at.scalar(name="x")
x_fn = function(
[x],
2 * x,
mode=numba_mode,
)
assert isinstance(x_fn(1.0), np.ndarray)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论