提交 263a7c71 authored 作者: kc611's avatar kc611 提交者: Brandon T. Willard

Fixed Dimshuffle for scalar result cases

上级 0f8c81c3
...@@ -324,7 +324,6 @@ def numba_funcify_DimShuffle(op, **kwargs): ...@@ -324,7 +324,6 @@ def numba_funcify_DimShuffle(op, **kwargs):
inplace = op.inplace inplace = op.inplace
ndim_new_shape = len(shuffle) + len(augment) ndim_new_shape = len(shuffle) + len(augment)
create_zeros_tuple = numba_basic.create_tuple_creator(lambda _: 0, ndim_new_shape)
if len(shuffle) > 0: if len(shuffle) > 0:
...@@ -346,6 +345,11 @@ def numba_funcify_DimShuffle(op, **kwargs): ...@@ -346,6 +345,11 @@ def numba_funcify_DimShuffle(op, **kwargs):
def populate_new_shape(i, j, new_shape, shuffle_shape): def populate_new_shape(i, j, new_shape, shuffle_shape):
return j, tuple_setitem(new_shape, i, 1) return j, tuple_setitem(new_shape, i, 1)
if ndim_new_shape > 0:
create_zeros_tuple = numba_basic.create_tuple_creator(
lambda _: 0, ndim_new_shape
)
@numba.njit @numba.njit
def dimshuffle_inner(x, shuffle): def dimshuffle_inner(x, shuffle):
res = np.transpose(x, shuffle + drop) res = np.transpose(x, shuffle + drop)
...@@ -365,6 +369,12 @@ def numba_funcify_DimShuffle(op, **kwargs): ...@@ -365,6 +369,12 @@ def numba_funcify_DimShuffle(op, **kwargs):
else: else:
return res_reshape return res_reshape
else:
@numba.njit
def dimshuffle_inner(x, shuffle):
return x.item()
# Without the following wrapper function we would see this error: # Without the following wrapper function we would see this error:
# E No implementation of function Function(<built-in function getitem>) found for signature: # E No implementation of function Function(<built-in function getitem>) found for signature:
# E # E
......
...@@ -691,6 +691,14 @@ def test_AllocDiag(v, offset): ...@@ -691,6 +691,14 @@ def test_AllocDiag(v, offset):
(0,), (0,),
True, True,
), ),
(
set_test_value(
aet.tensor(config.floatX, [True, True, True], name="a"),
np.array([[[1.0]]], dtype=config.floatX),
),
(),
True,
),
], ],
) )
def test_Dimshuffle(v, new_order, inplace): def test_Dimshuffle(v, new_order, inplace):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论