提交 5f558708 authored 作者: Shawn Tan's avatar Shawn Tan

Re-implemented using strides.

上级 f89ec97d
...@@ -1388,30 +1388,28 @@ class GpuAllocDiag(AllocDiag): ...@@ -1388,30 +1388,28 @@ class GpuAllocDiag(AllocDiag):
# Initialise a buffer the same size as the output # Initialise a buffer the same size as the output
result_shape = x.shape[:-1] + (x.shape[-1] + abs(offset),) * 2 result_shape = x.shape[:-1] + (x.shape[-1] + abs(offset),) * 2
result_buffer_shape = ((np.prod(x.shape[:-1]).astype(np.int64),) + result_buffer_shape = ((np.prod(x.shape[:-1]).astype(np.int64),) +
((x.shape[-1] + abs(offset)) ** 2,)) (x.shape[-1] + abs(offset),) * 2)
result_buffer = gpuarray.zeros(result_buffer_shape, result_buffer = gpuarray.zeros(result_buffer_shape,
dtype=x.dtype, dtype=x.dtype,
context=x.context) context=x.context)
# Slice out a view of the diagonals # Slice out a view of the diagonals
if offset != 0: row_size = abs(offset) + x.shape[-1]
row_size = x.shape[-1] + abs(offset) if offset <= 0: # diag in the lower triangle
if offset >= 0: diag_view = result_buffer[:, abs(offset), :x.shape[-1]]
start_flattened_offset = abs(offset) diag_view.strides = (diag_view.strides[0],
end_flattened_offset = row_size * abs(offset) (row_size + 1) * x.dtype.itemsize)
else: else: # diag in the upper triangle
start_flattened_offset = row_size * abs(offset) diag_view = result_buffer[:, :x.shape[-1], abs(offset)]
end_flattened_offset = abs(offset) diag_view.strides = (diag_view.strides[0],
diag_view.strides[1] + x.dtype.itemsize)
diag_view = result_buffer[:, start_flattened_offset:-end_flattened_offset:row_size + 1]
else:
diag_view = result_buffer[:, ::x.shape[-1] + 1]
# Fill view with flattened array of diagonals # Fill view with flattened array of diagonals
diag_view[:] = x.reshape(diag_view.shape)[:] diag_view[:] = x.reshape(diag_view.shape)[:]
# Unflatten buffer into output size # Unflatten buffer into output size
result = result_buffer.reshape(result_shape) result = result_buffer.reshape(result_shape)
print(result)
if len(x.shape) > 1: if len(x.shape) > 1:
# Re-order axes so they correspond to diagonals at axis1, axis2 # Re-order axes so they correspond to diagonals at axis1, axis2
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论