提交 776518b7 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Move dtype conversion dictionary to global scope in aesara.tensor.type

上级 6ca2eecf
...@@ -24,6 +24,25 @@ all_dtypes = list(map(str, aes.all_types)) ...@@ -24,6 +24,25 @@ all_dtypes = list(map(str, aes.all_types))
int_dtypes = list(map(str, aes.int_types)) int_dtypes = list(map(str, aes.int_types))
uint_dtypes = list(map(str, aes.uint_types)) uint_dtypes = list(map(str, aes.uint_types))
# TODO: add more type correspondances for e.g. int32, int64, float32,
# complex64, etc.
dtype_specs_map = {
"float16": (float, "npy_float16", "NPY_FLOAT16"),
"float32": (float, "npy_float32", "NPY_FLOAT32"),
"float64": (float, "npy_float64", "NPY_FLOAT64"),
"bool": (bool, "npy_bool", "NPY_BOOL"),
"uint8": (int, "npy_uint8", "NPY_UINT8"),
"int8": (int, "npy_int8", "NPY_INT8"),
"uint16": (int, "npy_uint16", "NPY_UINT16"),
"int16": (int, "npy_int16", "NPY_INT16"),
"uint32": (int, "npy_uint32", "NPY_UINT32"),
"int32": (int, "npy_int32", "NPY_INT32"),
"uint64": (int, "npy_uint64", "NPY_UINT64"),
"int64": (int, "npy_int64", "NPY_INT64"),
"complex128": (complex, "aesara_complex128", "NPY_COMPLEX128"),
"complex64": (complex, "aesara_complex64", "NPY_COMPLEX64"),
}
class TensorType(CType): class TensorType(CType):
""" """
...@@ -260,25 +279,8 @@ class TensorType(CType): ...@@ -260,25 +279,8 @@ class TensorType(CType):
This function is used internally as part of C code generation. This function is used internally as part of C code generation.
""" """
# TODO: add more type correspondances for e.g. int32, int64, float32,
# complex64, etc.
try: try:
return { return dtype_specs_map[self.dtype]
"float16": (float, "npy_float16", "NPY_FLOAT16"),
"float32": (float, "npy_float32", "NPY_FLOAT32"),
"float64": (float, "npy_float64", "NPY_FLOAT64"),
"bool": (bool, "npy_bool", "NPY_BOOL"),
"uint8": (int, "npy_uint8", "NPY_UINT8"),
"int8": (int, "npy_int8", "NPY_INT8"),
"uint16": (int, "npy_uint16", "NPY_UINT16"),
"int16": (int, "npy_int16", "NPY_INT16"),
"uint32": (int, "npy_uint32", "NPY_UINT32"),
"int32": (int, "npy_int32", "NPY_INT32"),
"uint64": (int, "npy_uint64", "NPY_UINT64"),
"int64": (int, "npy_int64", "NPY_INT64"),
"complex128": (complex, "aesara_complex128", "NPY_COMPLEX128"),
"complex64": (complex, "aesara_complex64", "NPY_COMPLEX64"),
}[self.dtype]
except KeyError: except KeyError:
raise TypeError( raise TypeError(
f"Unsupported dtype for {self.__class__.__name__}: {self.dtype}" f"Unsupported dtype for {self.__class__.__name__}: {self.dtype}"
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论