提交 1a5749d4 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Generalize the create_vectorize_func interface

上级 db9b4bbe
...@@ -2,7 +2,7 @@ import inspect ...@@ -2,7 +2,7 @@ import inspect
from functools import singledispatch from functools import singledispatch
from numbers import Number from numbers import Number
from textwrap import indent from textwrap import indent
from typing import Union from typing import Any, Callable, Optional, Union
import numba import numba
import numpy as np import numpy as np
...@@ -127,8 +127,14 @@ if {res}[{idx}] > {arr}: ...@@ -127,8 +127,14 @@ if {res}[{idx}] > {arr}:
""" """
def create_vectorize_func(op, node, use_signature=False, identity=None, **kwargs): def create_vectorize_func(
scalar_op_fn = numba_funcify(op.scalar_op, node=node, inline="always", **kwargs) scalar_op_fn: Callable,
node: Apply,
use_signature: bool = False,
identity: Optional[Any] = None,
**kwargs,
) -> Callable:
r"""Create a vectorized Numba function from a `Apply`\s Python function."""
if len(node.outputs) > 1: if len(node.outputs) > 1:
raise NotImplementedError( raise NotImplementedError(
...@@ -420,7 +426,8 @@ def create_axis_apply_fn(fn, axis, ndim, dtype): ...@@ -420,7 +426,8 @@ def create_axis_apply_fn(fn, axis, ndim, dtype):
@numba_funcify.register(Elemwise) @numba_funcify.register(Elemwise)
def numba_funcify_Elemwise(op, node, **kwargs): def numba_funcify_Elemwise(op, node, **kwargs):
elemwise_fn = create_vectorize_func(op, node, use_signature=False) scalar_op_fn = numba_funcify(op.scalar_op, node=node, inline="always", **kwargs)
elemwise_fn = create_vectorize_func(scalar_op_fn, node, use_signature=False)
elemwise_fn_name = elemwise_fn.__name__ elemwise_fn_name = elemwise_fn.__name__
if op.inplace_pattern: if op.inplace_pattern:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论