提交 1366221b authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Use the correct dtype object in numba_funcify_CAReduce

上级 248ce6d9
...@@ -519,8 +519,6 @@ def numba_funcify_CAReduce(op, node, **kwargs): ...@@ -519,8 +519,6 @@ def numba_funcify_CAReduce(op, node, **kwargs):
scalar_op_identity = np.asarray(op.scalar_op.identity, dtype=np_acc_dtype) scalar_op_identity = np.asarray(op.scalar_op.identity, dtype=np_acc_dtype)
acc_dtype = numba.np.numpy_support.from_dtype(np_acc_dtype)
scalar_nfunc_spec = op.scalar_op.nfunc_spec scalar_nfunc_spec = op.scalar_op.nfunc_spec
# We construct a dummy `Apply` that has the minimum required number of # We construct a dummy `Apply` that has the minimum required number of
...@@ -528,15 +526,15 @@ def numba_funcify_CAReduce(op, node, **kwargs): ...@@ -528,15 +526,15 @@ def numba_funcify_CAReduce(op, node, **kwargs):
# with too few arguments. # with too few arguments.
dummy_node = Apply( dummy_node = Apply(
op, op,
[tensor(acc_dtype, [False]) for i in range(scalar_nfunc_spec[1])], [tensor(np_acc_dtype, [False]) for i in range(scalar_nfunc_spec[1])],
[tensor(acc_dtype, [False]) for o in range(scalar_nfunc_spec[2])], [tensor(np_acc_dtype, [False]) for o in range(scalar_nfunc_spec[2])],
) )
elemwise_fn = numba_funcify_Elemwise(op, dummy_node, use_signature=True, **kwargs) elemwise_fn = numba_funcify_Elemwise(op, dummy_node, use_signature=True, **kwargs)
input_name = get_name_for_object(node.inputs[0]) input_name = get_name_for_object(node.inputs[0])
ndim = node.inputs[0].ndim ndim = node.inputs[0].ndim
careduce_fn = create_multiaxis_reducer( careduce_fn = create_multiaxis_reducer(
elemwise_fn, scalar_op_identity, axes, ndim, acc_dtype, input_name=input_name elemwise_fn, scalar_op_identity, axes, ndim, np_acc_dtype, input_name=input_name
) )
return numba.njit(careduce_fn) return numba.njit(careduce_fn)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论