提交 aa1b7c83 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Remove unused numba_vectorize

上级 022f1890
...@@ -87,13 +87,6 @@ def numba_njit(*args, fastmath=None, **kwargs): ...@@ -87,13 +87,6 @@ def numba_njit(*args, fastmath=None, **kwargs):
return numba.njit(*args, fastmath=fastmath, **kwargs) return numba.njit(*args, fastmath=fastmath, **kwargs)
def numba_vectorize(*args, **kwargs):
if len(args) > 0 and callable(args[0]):
return numba.vectorize(*args[1:], cache=config.numba__cache, **kwargs)(args[0])
return numba.vectorize(*args, cache=config.numba__cache, **kwargs)
def get_numba_type( def get_numba_type(
pytensor_type: Type, pytensor_type: Type,
layout: str = "A", layout: str = "A",
......
import contextlib import contextlib
import inspect
from collections.abc import Callable, Iterable from collections.abc import Callable, Iterable
from typing import TYPE_CHECKING, Any from typing import TYPE_CHECKING, Any
from unittest import mock from unittest import mock
...@@ -151,30 +150,6 @@ def eval_python_only(fn_inputs, fn_outputs, inputs, mode=numba_mode): ...@@ -151,30 +150,6 @@ def eval_python_only(fn_inputs, fn_outputs, inputs, mode=numba_mode):
else: else:
return lambda x: x return lambda x: x
def vectorize_noop(*args, **kwargs):
def wrap(fn):
# `numba.vectorize` allows an `out` positional argument. We need
# to account for that
sig = inspect.signature(fn)
nparams = len(sig.parameters)
def inner_vec(*args):
if len(args) > nparams:
# An `out` argument has been specified for an in-place
# operation
out = args[-1]
out[...] = np.vectorize(fn)(*args[:nparams])
return out
else:
return np.vectorize(fn)(*args)
return inner_vec
if len(args) == 1 and callable(args[0]):
return wrap(args[0], **kwargs)
else:
return wrap
def py_global_numba_func(func): def py_global_numba_func(func):
if hasattr(func, "py_func"): if hasattr(func, "py_func"):
return func.py_func return func.py_func
...@@ -182,7 +157,6 @@ def eval_python_only(fn_inputs, fn_outputs, inputs, mode=numba_mode): ...@@ -182,7 +157,6 @@ def eval_python_only(fn_inputs, fn_outputs, inputs, mode=numba_mode):
mocks = [ mocks = [
mock.patch("numba.njit", njit_noop), mock.patch("numba.njit", njit_noop),
mock.patch("numba.vectorize", vectorize_noop),
mock.patch( mock.patch(
"pytensor.link.numba.dispatch.basic.global_numba_func", "pytensor.link.numba.dispatch.basic.global_numba_func",
py_global_numba_func, py_global_numba_func,
...@@ -191,9 +165,6 @@ def eval_python_only(fn_inputs, fn_outputs, inputs, mode=numba_mode): ...@@ -191,9 +165,6 @@ def eval_python_only(fn_inputs, fn_outputs, inputs, mode=numba_mode):
"pytensor.link.numba.dispatch.basic.tuple_setitem", py_tuple_setitem "pytensor.link.numba.dispatch.basic.tuple_setitem", py_tuple_setitem
), ),
mock.patch("pytensor.link.numba.dispatch.basic.numba_njit", njit_noop), mock.patch("pytensor.link.numba.dispatch.basic.numba_njit", njit_noop),
mock.patch(
"pytensor.link.numba.dispatch.basic.numba_vectorize", vectorize_noop
),
mock.patch( mock.patch(
"pytensor.link.numba.dispatch.basic.direct_cast", lambda x, dtype: x "pytensor.link.numba.dispatch.basic.direct_cast", lambda x, dtype: x
), ),
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论