提交 ab3704b3 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Remove unused numba dispatch function

上级 8cc489b1
from collections.abc import Callable
from functools import singledispatch
from textwrap import dedent, indent
from typing import Any
import numba
import numpy as np
......@@ -9,7 +7,6 @@ from numba.core.extending import overload
from numpy.core.numeric import normalize_axis_index, normalize_axis_tuple
from pytensor import config
from pytensor.graph.basic import Apply
from pytensor.graph.op import Op
from pytensor.link.numba.dispatch import basic as numba_basic
from pytensor.link.numba.dispatch.basic import (
......@@ -124,42 +121,6 @@ if {res}[{idx}] > {arr}:
"""
def create_vectorize_func(
scalar_op_fn: Callable,
node: Apply,
use_signature: bool = False,
identity: Any | None = None,
**kwargs,
) -> Callable:
r"""Create a vectorized Numba function from a `Apply`\s Python function."""
if len(node.outputs) > 1:
raise NotImplementedError(
"Multi-output Elemwise Ops are not supported by the Numba backend"
)
if use_signature:
signature = [create_numba_signature(node, force_scalar=True)]
else:
signature = []
target = (
getattr(node.tag, "numba__vectorize_target", None)
or config.numba__vectorize_target
)
numba_vectorized_fn = numba_basic.numba_vectorize(
signature, identity=identity, target=target, fastmath=config.numba__fastmath
)
py_scalar_func = getattr(scalar_op_fn, "py_func", scalar_op_fn)
elemwise_fn = numba_vectorized_fn(scalar_op_fn)
elemwise_fn.py_scalar_func = py_scalar_func
return elemwise_fn
def create_multiaxis_reducer(
scalar_op,
identity,
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论