Commit 42587563 authored by Ricardo Vieira, committed by Ricardo Vieira

Numba CAReduce: respect acc_dtype

Also fix infinity identities for unsigned integers
Parent 0cc6314b
......@@ -1391,7 +1391,10 @@ class CAReduce(COp):
return f"axes={list(axis)}"
def __str__(self):
    """Return a compact op description.

    The accumulator dtype is shown only when it differs from the output
    dtype, so the common case (acc_dtype == dtype) stays short.
    """
    # NOTE: the pre-change unconditional return was left above the new
    # branches by the diff rendering, making them unreachable; only the
    # acc_dtype-aware version is kept here.
    if self.acc_dtype != self.dtype:
        return f"{type(self).__name__}{{{self.scalar_op}, {self._axis_str()}, acc={self.acc_dtype}}}"
    else:
        return f"{type(self).__name__}{{{self.scalar_op}, {self._axis_str()}}}"
def perform(self, node, inp, out):
(input,) = inp
......
......@@ -357,7 +357,10 @@ def max_and_argmax(a, axis=None, keepdims=False):
class FixedOpCAReduce(CAReduce):
    """CAReduce specialization whose scalar op is fixed by the subclass.

    Because the scalar op is implied by the class name, ``__str__`` omits
    it and only reports the reduction axes (plus the accumulator dtype
    when it differs from the output dtype).
    """

    def __str__(self):
        # The stale pre-change return (which ignored acc_dtype) is removed;
        # mirror CAReduce.__str__: mention acc only when it differs.
        if self.dtype != self.acc_dtype:
            return f"{type(self).__name__}{{{self._axis_str()}, acc={self.acc_dtype}}}"
        else:
            return f"{type(self).__name__}{{{self._axis_str()}}}"
class NonZeroDimsCAReduce(FixedOpCAReduce):
......
......@@ -13,7 +13,7 @@ from pytensor.compile.ops import deep_copy_op
from pytensor.gradient import grad
from pytensor.scalar import Composite, float64
from pytensor.scalar import add as scalar_add
from pytensor.tensor import blas, tensor
from pytensor.tensor import blas, matrix, tensor, tensor3
from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise
from pytensor.tensor.math import All, Any, Max, Min, Prod, ProdWithoutZeros, Sum
from pytensor.tensor.special import LogSoftmax, Softmax, SoftmaxGrad
......@@ -366,6 +366,45 @@ def test_CAReduce(careduce_fn, axis, v):
assert isinstance(node.op, CAReduce)
@pytest.mark.parametrize("axis", (-1, (0, -1), None))
@pytest.mark.parametrize("axis", (-1, (0, -1), None))
def test_CAReduce_respects_acc_dtype(axis):
    """A sum with acc_dtype="int64" but int8 output must accumulate in
    int64 internally — accumulating in int8 would overflow and give a
    wrong (wrapped) result."""
    x = tensor3("x", dtype="int8")
    out = x.sum(dtype="int8", acc_dtype="int64", axis=axis)

    # A row whose partial sums exceed int8 range but whose total (10)
    # fits: only a wide accumulator produces the right answer.
    max_int8 = np.iinfo(np.int8).max
    row = np.array([max_int8, 5, max_int8, -max_int8, 5, -max_int8], dtype=np.int8)
    test_x = np.broadcast_to(row, (6, 2, 6)).copy()

    _, [res] = compare_numba_and_py(
        [x],
        [out],
        [test_x],
    )

    # Per-row sum is 10; scale by how many rows each axis choice folds in.
    expected = {-1: 10, (0, -1): 60, None: 120}[axis]
    assert np.all(res == expected)
@pytest.mark.parametrize("axis", (1, None))
def test_CAReduce_acc_complex_out_float(axis):
    """Sum of a complex128 input with a float64 output dtype: the
    accumulation happens in complex while the result is cast to float."""
    x = matrix("x", dtype="complex128")
    out = x.sum(dtype="float64", axis=axis)

    # Imaginary parts cancel pairwise, so the float64 result is well defined.
    data = np.array(
        [[1 + 0.5j, 2 - 0.5j], [3 + 0.5j, 4 - 0.5j]], dtype="complex128"
    )
    compare_numba_and_py([x], [out], [data])
@pytest.mark.parametrize("axis", (-1, (0, -1), None))
def test_CAReduce_discrete_infinity_identity(axis):
    """Max over an integer dtype needs an integer identity element
    (the dtype's minimum), since +/-inf does not exist for int8."""
    rng = np.random.default_rng(337)
    x = tensor3("x", dtype="int8")
    out = x.max(axis)

    values = rng.integers(-127, 127, size=(6, 6, 6)).astype("int8")
    compare_numba_and_py([x], [out], [values])
def test_scalar_Elemwise_Clip():
a = pt.scalar("a")
b = pt.scalar("b")
......
Markdown is supported
0%
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Register or sign in to comment