提交 ea889f4a authored 作者: Frederic's avatar Frederic

Fix GpuAdvancedIncSubtensor1_dev20 with negative index in new back-end

上级 11ebaeb5
...@@ -489,7 +489,7 @@ class GpuAdvancedIncSubtensor1_dev20(GpuAdvancedIncSubtensor1): ...@@ -489,7 +489,7 @@ class GpuAdvancedIncSubtensor1_dev20(GpuAdvancedIncSubtensor1):
return gof.Apply(self, [x_, y_, ilist_], [x_.type()]) return gof.Apply(self, [x_, y_, ilist_], [x_.type()])
def c_code_cache_version(self): def c_code_cache_version(self):
return (3,) return (4,)
def c_headers(self): def c_headers(self):
return ['cuda.h', '<gpuarray/extension.h>', '<numpy_compat.h>', return ['cuda.h', '<gpuarray/extension.h>', '<numpy_compat.h>',
...@@ -587,6 +587,8 @@ __device__ npy_float16 atomicAdd(npy_float16 *addr, npy_float16 val) { ...@@ -587,6 +587,8 @@ __device__ npy_float16 atomicAdd(npy_float16 *addr, npy_float16 val) {
for(int j = (threadIdx.x); j < numColsX;j += blockDim.x) for(int j = (threadIdx.x); j < numColsX;j += blockDim.x)
{ {
int x_row = indices_arr[i * stridesIndices]; int x_row = indices_arr[i * stridesIndices];
if(x_row < 0)
x_row += numRowsX;
int y_row = i; int y_row = i;
atomicAdd(&X[(x_row * stridesX0) + (j * stridesX1)], Y[(y_row * stridesY0) + (j * stridesY1)]); atomicAdd(&X[(x_row * stridesX0) + (j * stridesX1)], Y[(y_row * stridesY0) + (j * stridesY1)]);
} }
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论