提交 9926e07c authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Fix dtype conversion for scalar Numba Dot results

上级 54e19a68
...@@ -1424,7 +1424,7 @@ def numba_funcify_Dot(op, node, **kwargs): ...@@ -1424,7 +1424,7 @@ def numba_funcify_Dot(op, node, **kwargs):
@numba.njit @numba.njit
def dot(x, y): def dot(x, y):
return np.dot(inputs_cast(x), inputs_cast(y)).astype(out_dtype) return np.asarray(np.dot(inputs_cast(x), inputs_cast(y))).astype(out_dtype)
return dot return dot
......
...@@ -1694,6 +1694,11 @@ def test_BroadcastTo(x, shape, exc): ...@@ -1694,6 +1694,11 @@ def test_BroadcastTo(x, shape, exc):
), ),
None, None,
), ),
(
set_test_value(aet.lvector(), np.random.random(size=(2,)).astype(np.int64)),
set_test_value(aet.lvector(), np.random.random(size=(2,)).astype(np.int64)),
None,
),
], ],
) )
def test_Dot(x, y, exc): def test_Dot(x, y, exc):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论