提交 0a4a35c7 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Add output_filter method to JITLinker and use it to convert Numba scalars

上级 2c9ee770
...@@ -12,8 +12,6 @@ from typing import ( ...@@ -12,8 +12,6 @@ from typing import (
Union, Union,
) )
from numpy import ndarray
from aesara.configdefaults import config from aesara.configdefaults import config
from aesara.graph.basic import Apply, Variable from aesara.graph.basic import Apply, Variable
from aesara.graph.fg import FunctionGraph from aesara.graph.fg import FunctionGraph
...@@ -23,6 +21,8 @@ from aesara.utils import difference ...@@ -23,6 +21,8 @@ from aesara.utils import difference
if TYPE_CHECKING: if TYPE_CHECKING:
from numpy.typing import NDArray
from aesara.compile.profiling import ProfileStats from aesara.compile.profiling import ProfileStats
from aesara.graph.op import ( from aesara.graph.op import (
BasicThunkType, BasicThunkType,
...@@ -505,7 +505,7 @@ class WrapLinker(Linker): ...@@ -505,7 +505,7 @@ class WrapLinker(Linker):
def pre( def pre(
self, self,
f: "WrapLinker", f: "WrapLinker",
inputs: Union[List[ndarray], List[Optional[float]]], inputs: Union[List["NDArray"], List[Optional[float]]],
order: List[Apply], order: List[Apply],
thunk_groups: List[Tuple[Callable]], thunk_groups: List[Tuple[Callable]],
) -> None: ) -> None:
...@@ -609,6 +609,10 @@ class JITLinker(PerformLinker): ...@@ -609,6 +609,10 @@ class JITLinker(PerformLinker):
def jit_compile(self, fn: Callable) -> Callable: def jit_compile(self, fn: Callable) -> Callable:
"""JIT compile a converted ``FunctionGraph``.""" """JIT compile a converted ``FunctionGraph``."""
def output_filter(self, var: Variable, out: Any) -> Any:
"""Apply a filter to the data output by a JITed function call."""
return out
def create_jitable_thunk( def create_jitable_thunk(
self, compute_map, order, input_storage, output_storage, storage_map self, compute_map, order, input_storage, output_storage, storage_map
): ):
...@@ -663,14 +667,9 @@ class JITLinker(PerformLinker): ...@@ -663,14 +667,9 @@ class JITLinker(PerformLinker):
): ):
outputs = fgraph_jit(*[x[0] for x in thunk_inputs]) outputs = fgraph_jit(*[x[0] for x in thunk_inputs])
for o_node, o_storage, o_val in zip(fgraph.outputs, thunk_outputs, outputs): for o_var, o_storage, o_val in zip(fgraph.outputs, thunk_outputs, outputs):
compute_map[o_node][0] = True compute_map[o_var][0] = True
if len(o_storage) > 1: o_storage[0] = self.output_filter(o_var, o_val)
assert len(o_storage) == len(o_val)
for i, o_sub_val in enumerate(o_val):
o_storage[i] = o_sub_val
else:
o_storage[0] = o_val
return outputs return outputs
thunk.inputs = thunk_inputs thunk.inputs = thunk_inputs
......
from typing import TYPE_CHECKING, Any
import numpy as np
import aesara
from aesara.link.basic import JITLinker from aesara.link.basic import JITLinker
if TYPE_CHECKING:
from aesara.graph.basic import Variable
class NumbaLinker(JITLinker): class NumbaLinker(JITLinker):
"""A `Linker` that JIT-compiles NumPy-based operations using Numba.""" """A `Linker` that JIT-compiles NumPy-based operations using Numba."""
def output_filter(self, var: "Variable", out: Any) -> Any:
if not isinstance(var, np.ndarray) and isinstance(
var.type, aesara.tensor.TensorType
):
return np.asarray(out, dtype=var.type.dtype)
return out
def fgraph_convert(self, fgraph, **kwargs): def fgraph_convert(self, fgraph, **kwargs):
from aesara.link.numba.dispatch import numba_funcify from aesara.link.numba.dispatch import numba_funcify
......
...@@ -3575,3 +3575,14 @@ def test_config_options_cached(): ...@@ -3575,3 +3575,14 @@ def test_config_options_cached():
aesara_numba_fn = function([x], x * 2, mode=numba_mode) aesara_numba_fn = function([x], x * 2, mode=numba_mode)
numba_mul_fn = aesara_numba_fn.vm.jit_fn.py_func.__globals__["mul"] numba_mul_fn = aesara_numba_fn.vm.jit_fn.py_func.__globals__["mul"]
assert isinstance(numba_mul_fn._dispatcher.cache, numba.core.caching.NullCache) assert isinstance(numba_mul_fn._dispatcher.cache, numba.core.caching.NullCache)
def test_scalar_return_value_conversion():
r"""Make sure that we convert \"native\" scalars to `ndarray`\s in the graph outputs."""
x = at.scalar(name="x")
x_fn = function(
[x],
2 * x,
mode=numba_mode,
)
assert isinstance(x_fn(1.0), np.ndarray)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论