提交 d58444e1 authored 作者: Pierre Luc Carrier's avatar Pierre Luc Carrier

Replace content of c_init_code by reference to function setup_ext_cuda() in…

Replace content of c_init_code by reference to function setup_ext_cuda() in GpuSoftmax and GpuSoftmaxWithBias
上级 a123a224
......@@ -466,16 +466,17 @@ class GpuSoftmax (Op):
return shape
def c_code_cache_version(self):
return (10,) + inline_softmax.code_version
return (11,) + inline_softmax.code_version
def c_headers(self):
return ['cuda.h', '<compyte/extension.h>', '<numpy_compat.h>']
return ['cuda.h', '<compyte/extension.h>', '<numpy_compat.h>',
'<compyte/ext_cuda.h>']
def c_compiler(self):
return NVCC_compiler
def c_init_code(self):
return ['cuda_get_ptr = (CUdeviceptr (*)(gpudata *g))compyte_get_extension("cuda_get_ptr");']
return ['setup_ext_cuda();']
def c_code(self, node, nodename, inp, out, sub):
dtype_x = node.inputs[0].dtype
......@@ -625,8 +626,7 @@ class GpuSoftmax (Op):
"__syncthreads()",
"}",
])
ret3 = "CUdeviceptr (*cuda_get_ptr)(gpudata *g);"
return (ret1 + "\n" + ret2 + "\n" + ret3) % locals()
return (ret1 + "\n" + ret2) % locals()
gpu_softmax = GpuSoftmax()
......@@ -654,19 +654,20 @@ class GpuSoftmaxWithBias (Op):
def infer_shape(self, node, shape):
return [shape[0]]
def c_code_cache_version(self):
return (9,) + inline_softmax.code_version
return (10,) + inline_softmax.code_version
def c_headers(self):
return ['cuda.h', '<compyte/extension.h>', '<numpy_compat.h>']
return ['cuda.h', '<compyte/extension.h>', '<numpy_compat.h>',
'<compyte/ext_cuda.h>']
def c_compiler(self):
return NVCC_compiler
def c_init_code(self):
return ['cuda_get_ptr = (CUdeviceptr (*)(gpudata *g))compyte_get_extension("cuda_get_ptr");']
return ['setup_ext_cuda();']
def c_code(self, node, nodename, inp, out, sub):
dtype_x = node.inputs[0].dtype
dtype_b = node.inputs[1].dtype
......@@ -839,7 +840,6 @@ class GpuSoftmaxWithBias (Op):
"__syncthreads()",
"}",
])
ret3 = "CUdeviceptr (*cuda_get_ptr)(gpudata *g);"
return (ret1 + "\n" + ret2 + "\n" + ret3) % locals()
return (ret1 + "\n" + ret2) % locals()
gpu_softmax_with_bias = GpuSoftmaxWithBias()
\ No newline at end of file
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论