提交 1feb53a7 authored 作者: Frederic's avatar Frederic

Add some support for GpuElemwise scalar_op c_support_code.

This is needed for the new gpu op tests to pass.
上级 d336a8ef
...@@ -154,6 +154,12 @@ class GpuElemwise(HideC, Elemwise): ...@@ -154,6 +154,12 @@ class GpuElemwise(HideC, Elemwise):
#define ga_half uint16_t #define ga_half uint16_t
""" """
try:
#We accept only some c_support_code().
#This filter is done in the make_node()
support_code += self.scalar_op.c_support_code()
except MethodNotDefined:
pass
for npy, ga in [("npy_uint8", "ga_ubyte"), for npy, ga in [("npy_uint8", "ga_ubyte"),
("npy_uint16", "ga_ushort"), ("npy_uint16", "ga_ushort"),
("npy_uin32", "ga_uint"), ("npy_uin32", "ga_uint"),
...@@ -179,6 +185,9 @@ class GpuElemwise(HideC, Elemwise): ...@@ -179,6 +185,9 @@ class GpuElemwise(HideC, Elemwise):
raise MethodNotDefined('cuda only') raise MethodNotDefined('cuda only')
return NVCC_compiler return NVCC_compiler
def c_support_code(self):
return self.scalar_op.c_support_code()
def c_support_code_apply(self, node, nodename): def c_support_code_apply(self, node, nodename):
if pygpu.get_default_context().kind == 'opencl': if pygpu.get_default_context().kind == 'opencl':
raise MethodNotDefined('cuda only') raise MethodNotDefined('cuda only')
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论