提交 5dc35990 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Simplify Reshape's tuple creation during Numba conversion

上级 abf6026c
......@@ -14,6 +14,7 @@ from numba import types
from numba.core.errors import TypingError
from numba.cpython.unsafe.tuple import tuple_setitem
from numba.extending import box
from numba.np.unsafe.ndarray import to_fixed_tuple
from numpy.core.multiarray import normalize_axis_index
from aesara.compile.ops import DeepCopyOp, ViewOp
......@@ -953,20 +954,11 @@ def numba_funcify_Cast(op, node, **kwargs):
@numba_funcify.register(Reshape)
def numba_funcify_Reshape(op, **kwargs):
ndim = op.ndim
# TODO: It might be possible/better to use
# `numba.np.unsafe.ndarray.to_fixed_tuple` here instead
create_zeros_tuple = create_tuple_creator(lambda _: 0, ndim)
@numba.njit
def reshape(x, shape):
new_shape = create_zeros_tuple()
for i in numba.literal_unroll(range(ndim)):
new_shape = tuple_setitem(new_shape, i, shape[i])
new_shape = to_fixed_tuple(shape, ndim)
return np.reshape(x, new_shape)
return reshape
......
......@@ -138,6 +138,9 @@ def eval_python_only(fn_inputs, fgraph, inputs):
lambda dtype: dtype,
), mock.patch(
"aesara.link.numba.dispatch.to_scalar", py_to_scalar
), mock.patch(
"aesara.link.numba.dispatch.to_fixed_tuple",
lambda x, n: tuple(x),
):
aesara_numba_fn = function(
fn_inputs,
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论