提交 a6b2b0cb authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Use tuple_setitem from numba_basic to simplify mocking

上级 cb4da6ed
......@@ -11,6 +11,7 @@ import scipy.special
from llvmlite.llvmpy.core import Type as llvm_Type
from numba import types
from numba.core.errors import TypingError
from numba.cpython.unsafe.tuple import tuple_setitem # noqa: F401
from numba.extending import box
from aesara import config
......
......@@ -6,7 +6,6 @@ from typing import Any, Callable, Optional, Union
import numba
import numpy as np
from numba.cpython.unsafe.tuple import tuple_setitem
from aesara import config
from aesara.graph.basic import Apply
......@@ -519,10 +518,10 @@ def numba_funcify_DimShuffle(op, **kwargs):
@numba_basic.numba_njit
def populate_new_shape(i, j, new_shape, shuffle_shape):
if i in augment:
new_shape = tuple_setitem(new_shape, i, 1)
new_shape = numba_basic.tuple_setitem(new_shape, i, 1)
return j, new_shape
else:
new_shape = tuple_setitem(new_shape, i, shuffle_shape[j])
new_shape = numba_basic.tuple_setitem(new_shape, i, shuffle_shape[j])
return j + 1, new_shape
else:
......@@ -532,7 +531,7 @@ def numba_funcify_DimShuffle(op, **kwargs):
# To avoid this compile-time error, we omit the expression altogether.
@numba_basic.numba_njit(inline="always")
def populate_new_shape(i, j, new_shape, shuffle_shape):
return j, tuple_setitem(new_shape, i, 1)
return j, numba_basic.tuple_setitem(new_shape, i, 1)
if ndim_new_shape > 0:
create_zeros_tuple = numba_basic.create_tuple_creator(
......
......@@ -156,17 +156,15 @@ def eval_python_only(fn_inputs, fgraph, inputs):
mocks = [
mock.patch("numba.njit", njit_noop),
mock.patch("numba.vectorize", vectorize_noop),
mock.patch(
"aesara.link.numba.dispatch.elemwise.tuple_setitem", py_tuple_setitem
),
mock.patch("aesara.link.numba.dispatch.basic.tuple_setitem", py_tuple_setitem),
mock.patch("aesara.link.numba.dispatch.basic.numba_njit", njit_noop),
mock.patch("aesara.link.numba.dispatch.basic.numba_vectorize", vectorize_noop),
mock.patch("aesara.link.numba.dispatch.basic.direct_cast", lambda x, dtype: x),
mock.patch("aesara.link.numba.dispatch.basic.to_scalar", py_to_scalar),
mock.patch(
"aesara.link.numba.dispatch.basic.numba.np.numpy_support.from_dtype",
lambda dtype: dtype,
),
mock.patch("aesara.link.numba.dispatch.basic.to_scalar", py_to_scalar),
mock.patch("numba.np.unsafe.ndarray.to_fixed_tuple", lambda x, n: tuple(x)),
]
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论