提交 0c1f0f33 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Make sure Numba DimShuffle input is an ndarray

上级 1bee8434
......@@ -926,7 +926,7 @@ def numba_funcify_DimShuffle(op, **kwargs):
# E shuffle_shape = res.shape[: len(shuffle)]
@numba.njit
def dimshuffle(x):
return dimshuffle_inner(x, shuffle)
return dimshuffle_inner(np.asarray(x), shuffle)
return dimshuffle
......
......@@ -903,16 +903,27 @@ def test_ARange(start, stop, step, dtype):
@pytest.mark.parametrize(
"careduce_fn, axis, v",
"careduce_fn, axis, v, keepdims",
[
(aet.sum, 0, set_test_value(aet.vector(), np.arange(3, dtype=config.floatX))),
(aet.all, 0, set_test_value(aet.vector(), np.arange(3, dtype=config.floatX))),
(
aet.sum,
0,
set_test_value(aet.vector(), np.arange(3, dtype=config.floatX)),
False,
),
(
aet.all,
0,
set_test_value(aet.vector(), np.arange(3, dtype=config.floatX)),
False,
),
(
aet.sum,
0,
set_test_value(
aet.matrix(), np.arange(3 * 2, dtype=config.floatX).reshape((3, 2))
),
False,
),
(
aet.sum,
......@@ -920,6 +931,7 @@ def test_ARange(start, stop, step, dtype):
set_test_value(
aet.matrix(), np.arange(3 * 2, dtype=config.floatX).reshape((3, 2))
),
False,
),
(
aet.sum,
......@@ -927,6 +939,7 @@ def test_ARange(start, stop, step, dtype):
set_test_value(
aet.matrix(), np.arange(3 * 2, dtype=config.floatX).reshape((3, 2))
),
False,
),
(
aet.sum,
......@@ -934,6 +947,7 @@ def test_ARange(start, stop, step, dtype):
set_test_value(
aet.matrix(), np.arange(3 * 2, dtype=config.floatX).reshape((3, 2))
),
False,
),
(
aet.sum,
......@@ -941,14 +955,21 @@ def test_ARange(start, stop, step, dtype):
set_test_value(
aet.matrix(), np.arange(3 * 2, dtype=config.floatX).reshape((3, 2))
),
False,
),
(
aet.prod,
0,
set_test_value(aet.vector(), np.arange(3, dtype=config.floatX)),
False,
),
(aet.prod, 0, set_test_value(aet.vector(), np.arange(3, dtype=config.floatX))),
(
aet.prod,
0,
set_test_value(
aet.matrix(), np.arange(3 * 2, dtype=config.floatX).reshape((3, 2))
),
False,
),
(
aet.prod,
......@@ -956,11 +977,20 @@ def test_ARange(start, stop, step, dtype):
set_test_value(
aet.matrix(), np.arange(3 * 2, dtype=config.floatX).reshape((3, 2))
),
False,
),
(
aet.max,
None,
set_test_value(
aet.matrix(), np.arange(3 * 2, dtype=config.floatX).reshape((3, 2))
),
True,
),
],
)
def test_CAReduce(careduce_fn, axis, v):
g = careduce_fn(v, axis=axis)
def test_CAReduce(careduce_fn, axis, v, keepdims):
g = careduce_fn(v, axis=axis, keepdims=keepdims)
g_fg = FunctionGraph(outputs=[g])
compare_numba_and_py(
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论