提交 0427130d authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Fix Numba dtype inconsistencies

上级 ee758618
...@@ -921,7 +921,7 @@ def numba_funcify_ScalarFromTensor(op, **kwargs): ...@@ -921,7 +921,7 @@ def numba_funcify_ScalarFromTensor(op, **kwargs):
@numba_funcify.register(AllocEmpty) @numba_funcify.register(AllocEmpty)
def numba_funcify_AllocEmpty(op, node, **kwargs): def numba_funcify_AllocEmpty(op, node, **kwargs):
global_env = {"np": np, "to_scalar": to_scalar, "dtype": op.dtype} global_env = {"np": np, "to_scalar": to_scalar, "dtype": np.dtype(op.dtype)}
unique_names = unique_name_generator( unique_names = unique_name_generator(
["np", "to_scalar", "dtype", "allocempty", "scalar_shape"], suffix_sep="_" ["np", "to_scalar", "dtype", "allocempty", "scalar_shape"], suffix_sep="_"
...@@ -1114,7 +1114,6 @@ def direct_cast(typingctx, val, typ): ...@@ -1114,7 +1114,6 @@ def direct_cast(typingctx, val, typ):
def numba_funcify_Cast(op, node, **kwargs): def numba_funcify_Cast(op, node, **kwargs):
dtype = np.dtype(op.o_type.dtype) dtype = np.dtype(op.o_type.dtype)
dtype = numba.np.numpy_support.from_dtype(dtype)
@numba.njit(inline="always") @numba.njit(inline="always")
def cast(x): def cast(x):
...@@ -1169,7 +1168,6 @@ def numba_funcify_Clip(op, **kwargs): ...@@ -1169,7 +1168,6 @@ def numba_funcify_Clip(op, **kwargs):
@numba_funcify.register(ARange) @numba_funcify.register(ARange)
def numba_funcify_ARange(op, **kwargs): def numba_funcify_ARange(op, **kwargs):
dtype = np.dtype(op.dtype) dtype = np.dtype(op.dtype)
dtype = numba.np.numpy_support.from_dtype(dtype)
@numba.njit(inline="always") @numba.njit(inline="always")
def arange(start, stop, step): def arange(start, stop, step):
...@@ -1213,7 +1211,6 @@ def numba_funcify_ExtractDiag(op, **kwargs): ...@@ -1213,7 +1211,6 @@ def numba_funcify_ExtractDiag(op, **kwargs):
@numba_funcify.register(Eye) @numba_funcify.register(Eye)
def numba_funcify_Eye(op, **kwargs): def numba_funcify_Eye(op, **kwargs):
dtype = np.dtype(op.dtype) dtype = np.dtype(op.dtype)
dtype = numba.np.numpy_support.from_dtype(dtype)
@numba.njit(inline="always") @numba.njit(inline="always")
def eye(N, M, k): def eye(N, M, k):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论