提交 071d3cae authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Allow FunctionGraph arguments in create_numba_signature

上级 4652fa7b
...@@ -3,6 +3,7 @@ import warnings ...@@ -3,6 +3,7 @@ import warnings
from contextlib import contextmanager from contextlib import contextmanager
from functools import singledispatch from functools import singledispatch
from textwrap import dedent from textwrap import dedent
from typing import Union
import numba import numba
import numba.np.unsafe.ndarray as numba_ndarray import numba.np.unsafe.ndarray as numba_ndarray
...@@ -96,11 +97,13 @@ def get_numba_type( ...@@ -96,11 +97,13 @@ def get_numba_type(
def create_numba_signature( def create_numba_signature(
node: Apply, force_scalar: bool = False, reduce_to_scalar: bool = False node_or_fgraph: Union[FunctionGraph, Apply],
force_scalar: bool = False,
reduce_to_scalar: bool = False,
) -> numba.types.Type: ) -> numba.types.Type:
"""Create a Numba type for the signature of an ``Apply`` node.""" """Create a Numba type for the signature of an `Apply` node or `FunctionGraph`."""
input_types = [] input_types = []
for inp in node.inputs: for inp in node_or_fgraph.inputs:
input_types.append( input_types.append(
get_numba_type( get_numba_type(
inp.type, force_scalar=force_scalar, reduce_to_scalar=reduce_to_scalar inp.type, force_scalar=force_scalar, reduce_to_scalar=reduce_to_scalar
...@@ -108,7 +111,7 @@ def create_numba_signature( ...@@ -108,7 +111,7 @@ def create_numba_signature(
) )
output_types = [] output_types = []
for out in node.outputs: for out in node_or_fgraph.outputs:
output_types.append( output_types.append(
get_numba_type( get_numba_type(
out.type, force_scalar=force_scalar, reduce_to_scalar=reduce_to_scalar out.type, force_scalar=force_scalar, reduce_to_scalar=reduce_to_scalar
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论