提交 f6ebb5ca authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Numba ExtractDiag: respect output dtype

上级 0f0e67ad
......@@ -160,7 +160,7 @@ def numba_funcify_ExtractDiag(op, node, **kwargs):
diag_len = min(x.shape[axis2], max(0, x.shape[axis1] + offset))
base_shape = x.shape[:axis1] + x.shape[axis1p1:axis2] + x.shape[axis2p1:]
out_shape = (*base_shape, diag_len)
out = np.empty(out_shape)
out = np.empty(out_shape, dtype=x.dtype)
for i in range(diag_len):
if offset >= 0:
......@@ -170,7 +170,8 @@ def numba_funcify_ExtractDiag(op, node, **kwargs):
out[..., i] = new_entry
return out
return extract_diag
cache_key = 1
return extract_diag, cache_key
@register_funcify_default_op_cache_key(Eye)
......
......@@ -260,11 +260,21 @@ def test_Split_view():
(pt.vector(), np.arange(10, dtype=config.floatX)),
0,
),
(
(
pt.tensor3(dtype="int8"),
np.arange(3 * 5 * 5, dtype="int8").reshape((3, 5, 5)),
),
1,
),
],
)
def test_ExtractDiag(val, offset):
val, val_test = val
if val.ndim <= 2:
g = pt.diag(val, offset)
else:
g = pt.diagonal(val, offset)
compare_numba_and_py(
[val],
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论