提交 49acbc5e authored 作者: Adrian Seyboldt's avatar Adrian Seyboldt 提交者: Ricardo Vieira

fix(numba): Cast arguments to dot to float

Numba doesn't support dot with non-floating point arguments.
上级 2ba26473
...@@ -760,7 +760,9 @@ def numba_funcify_SpecifyShape(op, node, **kwargs): ...@@ -760,7 +760,9 @@ def numba_funcify_SpecifyShape(op, node, **kwargs):
def int_to_float_fn(inputs, out_dtype): def int_to_float_fn(inputs, out_dtype):
"""Create a Numba function that converts integer and boolean ``ndarray``s to floats.""" """Create a Numba function that converts integer and boolean ``ndarray``s to floats."""
if all(input.type.numpy_dtype == np.dtype(out_dtype) for input in inputs): if all(
input.type.numpy_dtype == np.dtype(out_dtype) for input in inputs
) and isinstance(np.dtype(out_dtype), np.floating):
@numba_njit @numba_njit
def inputs_cast(x): def inputs_cast(x):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论