提交 67a5eba9 authored 作者: Adrian Seyboldt's avatar Adrian Seyboldt 提交者: Adrian Seyboldt

Improve numba DimShuffle compile time

上级 1827703c
......@@ -539,44 +539,54 @@ def numba_funcify_DimShuffle(op, **kwargs):
ndim_new_shape = len(shuffle) + len(augment)
no_transpose = all(i == j for i, j in enumerate(transposition))
if no_transpose:
@numba_basic.numba_njit
def transpose(x):
return x
else:
@numba_basic.numba_njit
def transpose(x):
return np.transpose(x, transposition)
shape_template = (1,) * ndim_new_shape
# When `len(shuffle) == 0`, the `shuffle_shape[j]` expression above is
# is typed as `getitem(Tuple(), int)`, which has no implementation
# (since getting an item from an empty sequence doesn't make sense).
# To avoid this compile-time error, we omit the expression altogether.
if len(shuffle) > 0:
@numba_basic.numba_njit
def populate_new_shape(i, j, new_shape, shuffle_shape):
if i in augment:
new_shape = numba_basic.tuple_setitem(new_shape, i, 1)
return j, new_shape
else:
new_shape = numba_basic.tuple_setitem(new_shape, i, shuffle_shape[j])
return j + 1, new_shape
def find_shape(array_shape):
shape = shape_template
j = 0
for i in range(ndim_new_shape):
if i not in augment:
length = array_shape[j]
shape = numba_basic.tuple_setitem(shape, i, length)
j = j + 1
return shape
else:
# When `len(shuffle) == 0`, the `shuffle_shape[j]` expression above is
# is typed as `getitem(Tuple(), int)`, which has no implementation
# (since getting an item from an empty sequence doesn't make sense).
# To avoid this compile-time error, we omit the expression altogether.
@numba_basic.numba_njit(inline="always")
def populate_new_shape(i, j, new_shape, shuffle_shape):
return j, numba_basic.tuple_setitem(new_shape, i, 1)
@numba_basic.numba_njit
def find_shape(array_shape):
return shape_template
if ndim_new_shape > 0:
create_zeros_tuple = numba_basic.create_tuple_creator(
lambda _: 0, ndim_new_shape
)
@numba_basic.numba_njit
def dimshuffle_inner(x, shuffle):
res = np.transpose(x, transposition)
shuffle_shape = res.shape[: len(shuffle)]
new_shape = create_zeros_tuple()
j = 0
for i in range(len(new_shape)):
j, new_shape = populate_new_shape(i, j, new_shape, shuffle_shape)
x = transpose(x)
shuffle_shape = x.shape[: len(shuffle)]
new_shape = find_shape(shuffle_shape)
# FIXME: Numba's `array.reshape` only accepts C arrays.
res_reshape = np.reshape(np.ascontiguousarray(res), new_shape)
res_reshape = np.reshape(np.ascontiguousarray(x), new_shape)
if not inplace:
return res_reshape.copy()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论