提交 263a7c71 authored 作者: kc611's avatar kc611 提交者: Brandon T. Willard

Fixed Dimshuffle for scalar result cases

上级 0f8c81c3
...@@ -324,7 +324,6 @@ def numba_funcify_DimShuffle(op, **kwargs): ...@@ -324,7 +324,6 @@ def numba_funcify_DimShuffle(op, **kwargs):
inplace = op.inplace inplace = op.inplace
ndim_new_shape = len(shuffle) + len(augment) ndim_new_shape = len(shuffle) + len(augment)
create_zeros_tuple = numba_basic.create_tuple_creator(lambda _: 0, ndim_new_shape)
if len(shuffle) > 0: if len(shuffle) > 0:
...@@ -346,24 +345,35 @@ def numba_funcify_DimShuffle(op, **kwargs): ...@@ -346,24 +345,35 @@ def numba_funcify_DimShuffle(op, **kwargs):
def populate_new_shape(i, j, new_shape, shuffle_shape): def populate_new_shape(i, j, new_shape, shuffle_shape):
return j, tuple_setitem(new_shape, i, 1) return j, tuple_setitem(new_shape, i, 1)
@numba.njit if ndim_new_shape > 0:
def dimshuffle_inner(x, shuffle): create_zeros_tuple = numba_basic.create_tuple_creator(
res = np.transpose(x, shuffle + drop) lambda _: 0, ndim_new_shape
shuffle_shape = res.shape[: len(shuffle)] )
new_shape = create_zeros_tuple() @numba.njit
def dimshuffle_inner(x, shuffle):
res = np.transpose(x, shuffle + drop)
shuffle_shape = res.shape[: len(shuffle)]
j = 0 new_shape = create_zeros_tuple()
for i in range(len(new_shape)):
j, new_shape = populate_new_shape(i, j, new_shape, shuffle_shape)
# FIXME: Numba's `array.reshape` only accepts C arrays. j = 0
res_reshape = np.reshape(np.ascontiguousarray(res), new_shape) for i in range(len(new_shape)):
j, new_shape = populate_new_shape(i, j, new_shape, shuffle_shape)
if not inplace: # FIXME: Numba's `array.reshape` only accepts C arrays.
return res_reshape.copy() res_reshape = np.reshape(np.ascontiguousarray(res), new_shape)
else:
return res_reshape if not inplace:
return res_reshape.copy()
else:
return res_reshape
else:
@numba.njit
def dimshuffle_inner(x, shuffle):
return x.item()
# Without the following wrapper function we would see this error: # Without the following wrapper function we would see this error:
# E No implementation of function Function(<built-in function getitem>) found for signature: # E No implementation of function Function(<built-in function getitem>) found for signature:
......
...@@ -691,6 +691,14 @@ def test_AllocDiag(v, offset): ...@@ -691,6 +691,14 @@ def test_AllocDiag(v, offset):
(0,), (0,),
True, True,
), ),
(
set_test_value(
aet.tensor(config.floatX, [True, True, True], name="a"),
np.array([[[1.0]]], dtype=config.floatX),
),
(),
True,
),
], ],
) )
def test_Dimshuffle(v, new_order, inplace): def test_Dimshuffle(v, new_order, inplace):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论