提交 22d04f87 authored 作者: notoraptor's avatar notoraptor

Extend CLinkerType with new abstract method `c_element_type()`.

Implement this new method into: - Scalar - TensorType - GpuArrayType
上级 6548e5cc
...@@ -35,6 +35,19 @@ class CLinkerType(CLinkerObject): ...@@ -35,6 +35,19 @@ class CLinkerType(CLinkerObject):
""" """
def c_element_type(self):
"""
Optional: Return the name of the primitive C type of items into variables
handled by this type.
e.g:
- For ``TensorType(dtype='int64', ...)``: should return ``"npy_int64"``.
- For ``GpuArrayType(dtype='int32', ...)``: should return ``"ga_int"``.
"""
return ''
def c_is_simple(self): def c_is_simple(self):
""" """
Optional: Return True for small or builtin C types. Optional: Return True for small or builtin C types.
......
...@@ -430,6 +430,9 @@ class GpuArrayType(Type): ...@@ -430,6 +430,9 @@ class GpuArrayType(Type):
else: else:
return np.dtype(self.dtype).itemsize return np.dtype(self.dtype).itemsize
def c_element_type(self):
return pygpu.gpuarray.dtype_to_ctype(self.dtype)
def c_declare(self, name, sub, check_input=True): def c_declare(self, name, sub, check_input=True):
return """ return """
PyGpuArrayObject *%(name)s; PyGpuArrayObject *%(name)s;
......
...@@ -349,6 +349,9 @@ class Scalar(Type): ...@@ -349,6 +349,9 @@ class Scalar(Type):
return True return True
return abs(diff) <= (abs(a) * tolerance) + (abs(b) * tolerance) return abs(diff) <= (abs(a) * tolerance) + (abs(b) * tolerance)
def c_element_type(self):
return self.dtype_specs()[1]
def c_headers(self, c_compiler): def c_headers(self, c_compiler):
l = ['<math.h>'] l = ['<math.h>']
# These includes are needed by Scalar and TensorType, # These includes are needed by Scalar and TensorType,
......
...@@ -375,6 +375,9 @@ class TensorType(Type): ...@@ -375,6 +375,9 @@ class TensorType(Type):
return str(self) return str(self)
# "TensorType{%s, %s}" % (str(self.dtype), str(self.broadcastable)) # "TensorType{%s, %s}" % (str(self.dtype), str(self.broadcastable))
def c_element_type(self):
return self.dtype_specs()[1]
def c_declare(self, name, sub, check_input=True): def c_declare(self, name, sub, check_input=True):
""" """
Override `CLinkerType.c_declare`. Override `CLinkerType.c_declare`.
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论