提交 ee382960 authored 作者: Shawn Tan's avatar Shawn Tan

Far more understandable version using strides.

上级 5f558708
...@@ -1394,11 +1394,8 @@ class GpuAllocDiag(AllocDiag): ...@@ -1394,11 +1394,8 @@ class GpuAllocDiag(AllocDiag):
context=x.context) context=x.context)
# Slice out a view of the diagonals # Slice out a view of the diagonals
row_size = abs(offset) + x.shape[-1] if offset < 0: # diag in the lower triangle
if offset <= 0: # diag in the lower triangle diag_view = result_buffer[:, abs(offset):, 0]
diag_view = result_buffer[:, abs(offset), :x.shape[-1]]
diag_view.strides = (diag_view.strides[0],
(row_size + 1) * x.dtype.itemsize)
else: # diag in the upper triangle else: # diag in the upper triangle
diag_view = result_buffer[:, :x.shape[-1], abs(offset)] diag_view = result_buffer[:, :x.shape[-1], abs(offset)]
diag_view.strides = (diag_view.strides[0], diag_view.strides = (diag_view.strides[0],
...@@ -1409,7 +1406,6 @@ class GpuAllocDiag(AllocDiag): ...@@ -1409,7 +1406,6 @@ class GpuAllocDiag(AllocDiag):
# Unflatten buffer into output size # Unflatten buffer into output size
result = result_buffer.reshape(result_shape) result = result_buffer.reshape(result_shape)
print(result)
if len(x.shape) > 1: if len(x.shape) > 1:
# Re-order axes so they correspond to diagonals at axis1, axis2 # Re-order axes so they correspond to diagonals at axis1, axis2
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论