提交 b3f294a1 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Use numba.from_dtype and simplify force_scalar in get_numba_type

上级 8335bbfe
...@@ -55,13 +55,15 @@ def get_numba_type( ...@@ -55,13 +55,15 @@ def get_numba_type(
) -> numba.types.Type: ) -> numba.types.Type:
"""Create a Numba type object for a ``Type``.""" """Create a Numba type object for a ``Type``."""
if isinstance(aesara_type, TensorType) and not force_scalar: if isinstance(aesara_type, TensorType):
dtype = aesara_type.numpy_dtype dtype = aesara_type.numpy_dtype
numba_dtype = numba.np.numpy_support.from_dtype(dtype) numba_dtype = numba.from_dtype(dtype)
if force_scalar:
return numba_dtype
return numba.types.Array(numba_dtype, aesara_type.ndim, layout) return numba.types.Array(numba_dtype, aesara_type.ndim, layout)
elif isinstance(aesara_type, Scalar) or force_scalar: elif isinstance(aesara_type, Scalar):
dtype = np.dtype(aesara_type.dtype) dtype = np.dtype(aesara_type.dtype)
numba_dtype = numba.np.numpy_support.from_dtype(dtype) numba_dtype = numba.from_dtype(dtype)
return numba_dtype return numba_dtype
else: else:
raise NotImplementedError(f"Numba type not implemented for {aesara_type}") raise NotImplementedError(f"Numba type not implemented for {aesara_type}")
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论