提交 8e4cf99c authored 作者: Frédéric Bastien's avatar Frédéric Bastien

Merge pull request #1732 from abergeron/fallback_perform

Don't try to use the c_code for GpuElemwise on opencl since it's cuda-only.
...@@ -169,13 +169,19 @@ class GpuElemwise(HideC, Elemwise): ...@@ -169,13 +169,19 @@ class GpuElemwise(HideC, Elemwise):
return ElemwiseKernel(None, inps+outs, kop, preamble=support_code) return ElemwiseKernel(None, inps+outs, kop, preamble=support_code)
def c_headers(self): def c_headers(self):
if pygpu.get_default_context().kind == 'opencl':
raise MethodNotDefined('cuda only')
return ['cuda.h', '<compyte/extension.h>', '<numpy_compat.h>', return ['cuda.h', '<compyte/extension.h>', '<numpy_compat.h>',
'<compyte/ext_cuda.h>'] '<compyte/ext_cuda.h>']
def c_compiler(self): def c_compiler(self):
if pygpu.get_default_context().kind == 'opencl':
raise MethodNotDefined('cuda only')
return NVCC_compiler return NVCC_compiler
def c_support_code_apply(self, node, nodename): def c_support_code_apply(self, node, nodename):
if pygpu.get_default_context().kind == 'opencl':
raise MethodNotDefined('cuda only')
# This is useless by itself, but will serve an eventual c_code # This is useless by itself, but will serve an eventual c_code
# implementation # implementation
k = self.generate_kernel(node, nodename) k = self.generate_kernel(node, nodename)
...@@ -214,9 +220,13 @@ class GpuElemwise(HideC, Elemwise): ...@@ -214,9 +220,13 @@ class GpuElemwise(HideC, Elemwise):
return '\n'.join(res) return '\n'.join(res)
def c_init_code(self): def c_init_code(self):
if pygpu.get_default_context().kind == 'opencl':
raise MethodNotDefined('cuda only')
return ['setup_ext_cuda();'] return ['setup_ext_cuda();']
def c_code(self, node, name, inputs, outputs, sub): def c_code(self, node, name, inputs, outputs, sub):
if pygpu.get_default_context().kind == 'opencl':
raise MethodNotDefined('cuda only')
nd = node.outputs[0].ndim nd = node.outputs[0].ndim
fail = sub["fail"] fail = sub["fail"]
initial_dims = ','.join('1' for i in xrange(nd)) initial_dims = ','.join('1' for i in xrange(nd))
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论