提交 83139b8c，作者：Arnaud Bergeron

Fix bug in GpuCrossentropy...1Hot...Dx

上级 3fce8613
...@@ -369,7 +369,7 @@ class GpuCrossentropySoftmax1HotWithBiasDx(GpuKernelBase, Op): ...@@ -369,7 +369,7 @@ class GpuCrossentropySoftmax1HotWithBiasDx(GpuKernelBase, Op):
return node.inputs[0].type.context return node.inputs[0].type.context
def c_code_cache_version(self): def c_code_cache_version(self):
return (11,) return (12,)
def c_headers(self): def c_headers(self):
return ['<numpy_compat.h>', '<gpuarray/types.h>'] return ['<numpy_compat.h>', '<gpuarray/types.h>']
...@@ -499,7 +499,8 @@ class GpuCrossentropySoftmax1HotWithBiasDx(GpuKernelBase, Op): ...@@ -499,7 +499,8 @@ class GpuCrossentropySoftmax1HotWithBiasDx(GpuKernelBase, Op):
load_sm = load_w(dtype_sm) load_sm = load_w(dtype_sm)
write_dx = write_w(dtype_dx) write_dx = write_w(dtype_dx)
flags = Kernel.get_flags(dtype_dnll, dtype_sm, dtype_y_idx, dtype_dx) flags = Kernel.get_flags(dtype_dnll, dtype_sm, dtype_y_idx, dtype_dx)
type_dnll = gpuarray.dtype_to_ctype(work_dnll) wtype_dnll = gpuarray.dtype_to_ctype(work_dnll)
type_dnll = gpuarray.dtype_to_ctype(dtype_dnll)
type_sm = gpuarray.dtype_to_ctype(dtype_sm) type_sm = gpuarray.dtype_to_ctype(dtype_sm)
type_y_idx = gpuarray.dtype_to_ctype(dtype_y_idx) type_y_idx = gpuarray.dtype_to_ctype(dtype_y_idx)
type_dx = gpuarray.dtype_to_ctype(dtype_dx) type_dx = gpuarray.dtype_to_ctype(dtype_dx)
...@@ -525,7 +526,7 @@ class GpuCrossentropySoftmax1HotWithBiasDx(GpuKernelBase, Op): ...@@ -525,7 +526,7 @@ class GpuCrossentropySoftmax1HotWithBiasDx(GpuKernelBase, Op):
for (int i = blockIdx.x; i < N; i += gridDim.x) for (int i = blockIdx.x; i < N; i += gridDim.x)
{ {
%(type_dnll)s dnll_i = %(load_dnll)s(dnll[i * dnll_s0]); %(wtype_dnll)s dnll_i = %(load_dnll)s(dnll[i * dnll_s0]);
%(type_y_idx)s y_i = y_idx[i * y_idx_s0]; %(type_y_idx)s y_i = y_idx[i * y_idx_s0];
for (int j = threadIdx.x; j < K; j += blockDim.x) for (int j = threadIdx.x; j < K; j += blockDim.x)
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论