提交 394844a7 authored 作者: Frederic's avatar Frederic

Don't include the Python/NumPy headers; convert the type names in the string instead.

Otherwise, on some platforms the include fails. This is more robust.
上级 f1f3e8fb
...@@ -149,9 +149,19 @@ class GpuElemwise(HideC, Elemwise): ...@@ -149,9 +149,19 @@ class GpuElemwise(HideC, Elemwise):
#define ga_double double #define ga_double double
#define ga_half uint16_t #define ga_half uint16_t
#include <Python.h>
#include <numpy/npy_common.h>
""" """
# Rewrite the NumPy C type names that appear in the kernel source string into
# the corresponding ga_* type names, so the kernel text no longer refers to
# npy_* identifiers.
# NOTE: two entries were previously spelled "npy_uin32" / "npy_uin64" and so
# never matched; they are corrected to "npy_uint32" / "npy_uint64".
for npy, ga in [("npy_uint8", "ga_ubyte"),
                ("npy_uint16", "ga_ushort"),
                ("npy_uint32", "ga_uint"),
                ("npy_uint64", "ga_ulong"),
                ("npy_int8", "ga_byte"),
                ("npy_int16", "ga_short"),
                ("npy_int32", "ga_int"),
                ("npy_int64", "ga_long"),
                ("npy_float32", "ga_float"),
                ("npy_float64", "ga_double"),
                ]:
    kop = kop.replace(npy, ga)
return ElemwiseKernel(None, inps+outs, kop, preamble=support_code) return ElemwiseKernel(None, inps+outs, kop, preamble=support_code)
def c_headers(self): def c_headers(self):
...@@ -338,8 +348,8 @@ class GpuElemwise(HideC, Elemwise): ...@@ -338,8 +348,8 @@ class GpuElemwise(HideC, Elemwise):
node.inputs + node.outputs)): node.inputs + node.outputs)):
if (n - len(inputs)) in self.inplace_pattern: if (n - len(inputs)) in self.inplace_pattern:
continue continue
dtype = var.dtype dtype = dtype_to_ctype(var.dtype)
param.append("(npy_%(dtype)s*)(cuda_get_ptr(%(name)s->ga.data))" % locals()) param.append("(%(dtype)s*)(cuda_get_ptr(%(name)s->ga.data))" % locals())
param.append("%(name)s->ga.offset" % locals()) param.append("%(name)s->ga.offset" % locals())
for i in range(nd): for i in range(nd):
param.append("PyGpuArray_DIMS(%(name)s)[%(i)d] == 1 ? 0 : PyGpuArray_STRIDES(%(name)s)[%(i)d]" % locals()) param.append("PyGpuArray_DIMS(%(name)s)[%(i)d] == 1 ? 0 : PyGpuArray_STRIDES(%(name)s)[%(i)d]" % locals())
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论