提交 263a7c71 authored 作者: kc611's avatar kc611 提交者: Brandon T. Willard

Fixed Dimshuffle for scalar result cases

上级 0f8c81c3
......@@ -324,7 +324,6 @@ def numba_funcify_DimShuffle(op, **kwargs):
inplace = op.inplace
ndim_new_shape = len(shuffle) + len(augment)
create_zeros_tuple = numba_basic.create_tuple_creator(lambda _: 0, ndim_new_shape)
if len(shuffle) > 0:
......@@ -346,6 +345,11 @@ def numba_funcify_DimShuffle(op, **kwargs):
def populate_new_shape(i, j, new_shape, shuffle_shape):
return j, tuple_setitem(new_shape, i, 1)
if ndim_new_shape > 0:
create_zeros_tuple = numba_basic.create_tuple_creator(
lambda _: 0, ndim_new_shape
)
@numba.njit
def dimshuffle_inner(x, shuffle):
res = np.transpose(x, shuffle + drop)
......@@ -365,6 +369,12 @@ def numba_funcify_DimShuffle(op, **kwargs):
else:
return res_reshape
else:
@numba.njit
def dimshuffle_inner(x, shuffle):
return x.item()
# Without the following wrapper function we would see this error:
# E No implementation of function Function(<built-in function getitem>) found for signature:
# E
......
......@@ -691,6 +691,14 @@ def test_AllocDiag(v, offset):
(0,),
True,
),
(
set_test_value(
aet.tensor(config.floatX, [True, True, True], name="a"),
np.array([[[1.0]]], dtype=config.floatX),
),
(),
True,
),
],
)
def test_Dimshuffle(v, new_order, inplace):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论