提交 033bf332 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Fix Numba CAReduce infinite identity values for non-floats

上级 d0cb1374
......@@ -487,7 +487,15 @@ def numba_funcify_CAReduce(op, node, **kwargs):
np_acc_dtype = np.dtype(acc_dtype)
scalar_op_identity = np.asarray(op.scalar_op.identity, dtype=np_acc_dtype)
scalar_op_identity = op.scalar_op.identity
if np_acc_dtype.kind == "i" and not np.isfinite(scalar_op_identity):
if np.isposinf(scalar_op_identity):
scalar_op_identity = np.iinfo(np_acc_dtype).max
else:
scalar_op_identity = np.iinfo(np_acc_dtype).min
# Make sure it has the correct dtype
scalar_op_identity = np.array(scalar_op_identity, dtype=np_acc_dtype)
input_name = get_name_for_object(node.inputs[0])
ndim = node.inputs[0].ndim
......
......@@ -1164,6 +1164,13 @@ def test_ARange(start, stop, step, dtype):
at.matrix(), np.arange(3 * 2, dtype=config.floatX).reshape((3, 2))
),
),
(
lambda x, axis=None, dtype=None, acc_dtype=None: Max(axis)(x),
None,
set_test_value(
at.lmatrix(), np.arange(3 * 2, dtype=np.int64).reshape((3, 2))
),
),
(
lambda x, axis=None, dtype=None, acc_dtype=None: Min(axis)(x),
None,
......@@ -1171,6 +1178,13 @@ def test_ARange(start, stop, step, dtype):
at.matrix(), np.arange(3 * 2, dtype=config.floatX).reshape((3, 2))
),
),
(
lambda x, axis=None, dtype=None, acc_dtype=None: Min(axis)(x),
None,
set_test_value(
at.lmatrix(), np.arange(3 * 2, dtype=np.int64).reshape((3, 2))
),
),
],
)
def test_CAReduce(careduce_fn, axis, v):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论