提交 f6ebb5ca authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Numba ExtractDiag: respect output dtype

上级 0f0e67ad
...@@ -160,7 +160,7 @@ def numba_funcify_ExtractDiag(op, node, **kwargs): ...@@ -160,7 +160,7 @@ def numba_funcify_ExtractDiag(op, node, **kwargs):
diag_len = min(x.shape[axis2], max(0, x.shape[axis1] + offset)) diag_len = min(x.shape[axis2], max(0, x.shape[axis1] + offset))
base_shape = x.shape[:axis1] + x.shape[axis1p1:axis2] + x.shape[axis2p1:] base_shape = x.shape[:axis1] + x.shape[axis1p1:axis2] + x.shape[axis2p1:]
out_shape = (*base_shape, diag_len) out_shape = (*base_shape, diag_len)
out = np.empty(out_shape) out = np.empty(out_shape, dtype=x.dtype)
for i in range(diag_len): for i in range(diag_len):
if offset >= 0: if offset >= 0:
...@@ -170,7 +170,8 @@ def numba_funcify_ExtractDiag(op, node, **kwargs): ...@@ -170,7 +170,8 @@ def numba_funcify_ExtractDiag(op, node, **kwargs):
out[..., i] = new_entry out[..., i] = new_entry
return out return out
return extract_diag cache_key = 1
return extract_diag, cache_key
@register_funcify_default_op_cache_key(Eye) @register_funcify_default_op_cache_key(Eye)
......
...@@ -260,11 +260,21 @@ def test_Split_view(): ...@@ -260,11 +260,21 @@ def test_Split_view():
(pt.vector(), np.arange(10, dtype=config.floatX)), (pt.vector(), np.arange(10, dtype=config.floatX)),
0, 0,
), ),
(
(
pt.tensor3(dtype="int8"),
np.arange(3 * 5 * 5, dtype="int8").reshape((3, 5, 5)),
),
1,
),
], ],
) )
def test_ExtractDiag(val, offset): def test_ExtractDiag(val, offset):
val, val_test = val val, val_test = val
g = pt.diag(val, offset) if val.ndim <= 2:
g = pt.diag(val, offset)
else:
g = pt.diagonal(val, offset)
compare_numba_and_py( compare_numba_and_py(
[val], [val],
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论