提交 a6b2b0cb authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Use tuple_setitem from numba_basic to simplify mocking

上级 cb4da6ed
...@@ -11,6 +11,7 @@ import scipy.special ...@@ -11,6 +11,7 @@ import scipy.special
from llvmlite.llvmpy.core import Type as llvm_Type from llvmlite.llvmpy.core import Type as llvm_Type
from numba import types from numba import types
from numba.core.errors import TypingError from numba.core.errors import TypingError
from numba.cpython.unsafe.tuple import tuple_setitem # noqa: F401
from numba.extending import box from numba.extending import box
from aesara import config from aesara import config
......
...@@ -6,7 +6,6 @@ from typing import Any, Callable, Optional, Union ...@@ -6,7 +6,6 @@ from typing import Any, Callable, Optional, Union
import numba import numba
import numpy as np import numpy as np
from numba.cpython.unsafe.tuple import tuple_setitem
from aesara import config from aesara import config
from aesara.graph.basic import Apply from aesara.graph.basic import Apply
...@@ -519,10 +518,10 @@ def numba_funcify_DimShuffle(op, **kwargs): ...@@ -519,10 +518,10 @@ def numba_funcify_DimShuffle(op, **kwargs):
@numba_basic.numba_njit @numba_basic.numba_njit
def populate_new_shape(i, j, new_shape, shuffle_shape): def populate_new_shape(i, j, new_shape, shuffle_shape):
if i in augment: if i in augment:
new_shape = tuple_setitem(new_shape, i, 1) new_shape = numba_basic.tuple_setitem(new_shape, i, 1)
return j, new_shape return j, new_shape
else: else:
new_shape = tuple_setitem(new_shape, i, shuffle_shape[j]) new_shape = numba_basic.tuple_setitem(new_shape, i, shuffle_shape[j])
return j + 1, new_shape return j + 1, new_shape
else: else:
...@@ -532,7 +531,7 @@ def numba_funcify_DimShuffle(op, **kwargs): ...@@ -532,7 +531,7 @@ def numba_funcify_DimShuffle(op, **kwargs):
# To avoid this compile-time error, we omit the expression altogether. # To avoid this compile-time error, we omit the expression altogether.
@numba_basic.numba_njit(inline="always") @numba_basic.numba_njit(inline="always")
def populate_new_shape(i, j, new_shape, shuffle_shape): def populate_new_shape(i, j, new_shape, shuffle_shape):
return j, tuple_setitem(new_shape, i, 1) return j, numba_basic.tuple_setitem(new_shape, i, 1)
if ndim_new_shape > 0: if ndim_new_shape > 0:
create_zeros_tuple = numba_basic.create_tuple_creator( create_zeros_tuple = numba_basic.create_tuple_creator(
......
...@@ -156,17 +156,15 @@ def eval_python_only(fn_inputs, fgraph, inputs): ...@@ -156,17 +156,15 @@ def eval_python_only(fn_inputs, fgraph, inputs):
mocks = [ mocks = [
mock.patch("numba.njit", njit_noop), mock.patch("numba.njit", njit_noop),
mock.patch("numba.vectorize", vectorize_noop), mock.patch("numba.vectorize", vectorize_noop),
mock.patch( mock.patch("aesara.link.numba.dispatch.basic.tuple_setitem", py_tuple_setitem),
"aesara.link.numba.dispatch.elemwise.tuple_setitem", py_tuple_setitem
),
mock.patch("aesara.link.numba.dispatch.basic.numba_njit", njit_noop), mock.patch("aesara.link.numba.dispatch.basic.numba_njit", njit_noop),
mock.patch("aesara.link.numba.dispatch.basic.numba_vectorize", vectorize_noop), mock.patch("aesara.link.numba.dispatch.basic.numba_vectorize", vectorize_noop),
mock.patch("aesara.link.numba.dispatch.basic.direct_cast", lambda x, dtype: x), mock.patch("aesara.link.numba.dispatch.basic.direct_cast", lambda x, dtype: x),
mock.patch("aesara.link.numba.dispatch.basic.to_scalar", py_to_scalar),
mock.patch( mock.patch(
"aesara.link.numba.dispatch.basic.numba.np.numpy_support.from_dtype", "aesara.link.numba.dispatch.basic.numba.np.numpy_support.from_dtype",
lambda dtype: dtype, lambda dtype: dtype,
), ),
mock.patch("aesara.link.numba.dispatch.basic.to_scalar", py_to_scalar),
mock.patch("numba.np.unsafe.ndarray.to_fixed_tuple", lambda x, n: tuple(x)), mock.patch("numba.np.unsafe.ndarray.to_fixed_tuple", lambda x, n: tuple(x)),
] ]
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论