提交 071d3cae authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Allow FunctionGraph arguments in create_numba_signature

上级 4652fa7b
......@@ -3,6 +3,7 @@ import warnings
from contextlib import contextmanager
from functools import singledispatch
from textwrap import dedent
from typing import Union
import numba
import numba.np.unsafe.ndarray as numba_ndarray
......@@ -96,11 +97,13 @@ def get_numba_type(
def create_numba_signature(
node: Apply, force_scalar: bool = False, reduce_to_scalar: bool = False
node_or_fgraph: Union[FunctionGraph, Apply],
force_scalar: bool = False,
reduce_to_scalar: bool = False,
) -> numba.types.Type:
"""Create a Numba type for the signature of an ``Apply`` node."""
"""Create a Numba type for the signature of an `Apply` node or `FunctionGraph`."""
input_types = []
for inp in node.inputs:
for inp in node_or_fgraph.inputs:
input_types.append(
get_numba_type(
inp.type, force_scalar=force_scalar, reduce_to_scalar=reduce_to_scalar
......@@ -108,7 +111,7 @@ def create_numba_signature(
)
output_types = []
for out in node.outputs:
for out in node_or_fgraph.outputs:
output_types.append(
get_numba_type(
out.type, force_scalar=force_scalar, reduce_to_scalar=reduce_to_scalar
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论