提交 40a42060 作者: Arnaud Bergeron

Fix GpuAdvancedIncSubtensor1_dev20 when the data to set is broadcast (a dimension of size 1 now uses stride 0).

上级 6264d52b
...@@ -489,7 +489,7 @@ class GpuAdvancedIncSubtensor1_dev20(GpuAdvancedIncSubtensor1): ...@@ -489,7 +489,7 @@ class GpuAdvancedIncSubtensor1_dev20(GpuAdvancedIncSubtensor1):
return gof.Apply(self, [x_, y_, ilist_], [x_.type()]) return gof.Apply(self, [x_, y_, ilist_], [x_.type()])
def c_code_cache_version(self): def c_code_cache_version(self):
return (2,) return (3,)
def c_headers(self): def c_headers(self):
return ['cuda.h', '<gpuarray/extension.h>', '<numpy_compat.h>', return ['cuda.h', '<gpuarray/extension.h>', '<numpy_compat.h>',
...@@ -583,17 +583,17 @@ class GpuAdvancedIncSubtensor1_dev20(GpuAdvancedIncSubtensor1): ...@@ -583,17 +583,17 @@ class GpuAdvancedIncSubtensor1_dev20(GpuAdvancedIncSubtensor1):
dim3 n_threads(num_threads_per_block); dim3 n_threads(num_threads_per_block);
k_vector_add_fast<<<n_blocks, n_threads>>>( k_vector_add_fast<<<n_blocks, n_threads>>>(
PyGpuArray_DIMS(py_self)[0], PyGpuArray_DIM(py_self, 0),
PyGpuArray_DIMS(py_self)[1], PyGpuArray_DIM(py_self, 1),
PyGpuArray_STRIDES(py_self)[0] / %(itemsize_x)s, PyGpuArray_STRIDE(py_self, 0) / %(itemsize_x)s,
PyGpuArray_STRIDES(py_self)[1] / %(itemsize_x)s, PyGpuArray_STRIDE(py_self, 1) / %(itemsize_x)s,
(npy_%(dtype_x)s*)( (npy_%(dtype_x)s*)(
((char *)cuda_get_ptr(py_self->ga.data)) + ((char *)cuda_get_ptr(py_self->ga.data)) +
py_self->ga.offset), py_self->ga.offset),
PyGpuArray_DIMS(py_other)[0], PyGpuArray_DIM(py_other, 0),
PyGpuArray_DIMS(py_other)[1], PyGpuArray_DIM(py_other, 1),
PyGpuArray_STRIDES(py_other)[0] / %(itemsize_y)s, PyGpuArray_DIM(py_other, 0) == 1 ? 0 : PyGpuArray_STRIDE(py_other, 0) / %(itemsize_y)s,
PyGpuArray_STRIDES(py_other)[1] / %(itemsize_y)s, PyGpuArray_DIM(py_other, 1) == 1 ? 0 : PyGpuArray_STRIDE(py_other, 1) / %(itemsize_y)s,
(npy_%(dtype_x)s*)( (npy_%(dtype_x)s*)(
((char *)cuda_get_ptr(py_other->ga.data)) + ((char *)cuda_get_ptr(py_other->ga.data)) +
py_other->ga.offset), py_other->ga.offset),
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论