提交 d58444e1 作者: Pierre Luc Carrier

Replace content of c_init_code by reference to function setup_ext_cuda() in…

Replace the content of c_init_code with a call to the function setup_ext_cuda() in GpuSoftmax and GpuSoftmaxWithBias
上级 a123a224
...@@ -466,16 +466,17 @@ class GpuSoftmax (Op): ...@@ -466,16 +466,17 @@ class GpuSoftmax (Op):
return shape return shape
def c_code_cache_version(self): def c_code_cache_version(self):
return (10,) + inline_softmax.code_version return (11,) + inline_softmax.code_version
def c_headers(self): def c_headers(self):
return ['cuda.h', '<compyte/extension.h>', '<numpy_compat.h>'] return ['cuda.h', '<compyte/extension.h>', '<numpy_compat.h>',
'<compyte/ext_cuda.h>']
def c_compiler(self): def c_compiler(self):
return NVCC_compiler return NVCC_compiler
def c_init_code(self): def c_init_code(self):
return ['cuda_get_ptr = (CUdeviceptr (*)(gpudata *g))compyte_get_extension("cuda_get_ptr");'] return ['setup_ext_cuda();']
def c_code(self, node, nodename, inp, out, sub): def c_code(self, node, nodename, inp, out, sub):
dtype_x = node.inputs[0].dtype dtype_x = node.inputs[0].dtype
...@@ -625,8 +626,7 @@ class GpuSoftmax (Op): ...@@ -625,8 +626,7 @@ class GpuSoftmax (Op):
"__syncthreads()", "__syncthreads()",
"}", "}",
]) ])
ret3 = "CUdeviceptr (*cuda_get_ptr)(gpudata *g);" return (ret1 + "\n" + ret2) % locals()
return (ret1 + "\n" + ret2 + "\n" + ret3) % locals()
gpu_softmax = GpuSoftmax() gpu_softmax = GpuSoftmax()
...@@ -654,19 +654,20 @@ class GpuSoftmaxWithBias (Op): ...@@ -654,19 +654,20 @@ class GpuSoftmaxWithBias (Op):
def infer_shape(self, node, shape): def infer_shape(self, node, shape):
return [shape[0]] return [shape[0]]
def c_code_cache_version(self): def c_code_cache_version(self):
return (9,) + inline_softmax.code_version return (10,) + inline_softmax.code_version
def c_headers(self): def c_headers(self):
return ['cuda.h', '<compyte/extension.h>', '<numpy_compat.h>'] return ['cuda.h', '<compyte/extension.h>', '<numpy_compat.h>',
'<compyte/ext_cuda.h>']
def c_compiler(self): def c_compiler(self):
return NVCC_compiler return NVCC_compiler
def c_init_code(self): def c_init_code(self):
return ['cuda_get_ptr = (CUdeviceptr (*)(gpudata *g))compyte_get_extension("cuda_get_ptr");'] return ['setup_ext_cuda();']
def c_code(self, node, nodename, inp, out, sub): def c_code(self, node, nodename, inp, out, sub):
dtype_x = node.inputs[0].dtype dtype_x = node.inputs[0].dtype
dtype_b = node.inputs[1].dtype dtype_b = node.inputs[1].dtype
...@@ -839,7 +840,6 @@ class GpuSoftmaxWithBias (Op): ...@@ -839,7 +840,6 @@ class GpuSoftmaxWithBias (Op):
"__syncthreads()", "__syncthreads()",
"}", "}",
]) ])
ret3 = "CUdeviceptr (*cuda_get_ptr)(gpudata *g);" return (ret1 + "\n" + ret2) % locals()
return (ret1 + "\n" + ret2 + "\n" + ret3) % locals()
gpu_softmax_with_bias = GpuSoftmaxWithBias() gpu_softmax_with_bias = GpuSoftmaxWithBias()
\ No newline at end of file
Markdown 格式
0%
您即将添加 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论