Commit 42587563 authored by Ricardo Vieira, committed by Ricardo Vieira

Numba CAReduce: respect acc_dtype

Also fix infinity identities for unsigned integers
Parent 0cc6314b
...@@ -1391,7 +1391,10 @@ class CAReduce(COp): ...@@ -1391,7 +1391,10 @@ class CAReduce(COp):
return f"axes={list(axis)}" return f"axes={list(axis)}"
def __str__(self):
    """Render the op as ``Name{scalar_op, axes}``, appending the
    accumulator dtype only when it differs from the output dtype."""
    if self.acc_dtype == self.dtype:
        return f"{type(self).__name__}{{{self.scalar_op}, {self._axis_str()}}}"
    return f"{type(self).__name__}{{{self.scalar_op}, {self._axis_str()}, acc={self.acc_dtype}}}"
def perform(self, node, inp, out): def perform(self, node, inp, out):
(input,) = inp (input,) = inp
......
...@@ -357,7 +357,10 @@ def max_and_argmax(a, axis=None, keepdims=False): ...@@ -357,7 +357,10 @@ def max_and_argmax(a, axis=None, keepdims=False):
class FixedOpCAReduce(CAReduce): class FixedOpCAReduce(CAReduce):
def __str__(self):
    """Render the op as ``Name{axes}`` (the scalar op is fixed for this
    class), appending the accumulator dtype only when it differs from
    the output dtype."""
    if self.dtype == self.acc_dtype:
        return f"{type(self).__name__}{{{self._axis_str()}}}"
    return f"{type(self).__name__}{{{self._axis_str()}, acc={self.acc_dtype}}}"
class NonZeroDimsCAReduce(FixedOpCAReduce): class NonZeroDimsCAReduce(FixedOpCAReduce):
......
...@@ -13,7 +13,7 @@ from pytensor.compile.ops import deep_copy_op ...@@ -13,7 +13,7 @@ from pytensor.compile.ops import deep_copy_op
from pytensor.gradient import grad from pytensor.gradient import grad
from pytensor.scalar import Composite, float64 from pytensor.scalar import Composite, float64
from pytensor.scalar import add as scalar_add from pytensor.scalar import add as scalar_add
from pytensor.tensor import blas, tensor from pytensor.tensor import blas, matrix, tensor, tensor3
from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise
from pytensor.tensor.math import All, Any, Max, Min, Prod, ProdWithoutZeros, Sum from pytensor.tensor.math import All, Any, Max, Min, Prod, ProdWithoutZeros, Sum
from pytensor.tensor.special import LogSoftmax, Softmax, SoftmaxGrad from pytensor.tensor.special import LogSoftmax, Softmax, SoftmaxGrad
...@@ -366,6 +366,45 @@ def test_CAReduce(careduce_fn, axis, v): ...@@ -366,6 +366,45 @@ def test_CAReduce(careduce_fn, axis, v):
assert isinstance(node.op, CAReduce) assert isinstance(node.op, CAReduce)
@pytest.mark.parametrize("axis", (-1, (0, -1), None))
def test_CAReduce_respects_acc_dtype(axis):
    """Sum int8 values whose running total overflows int8: the result is
    only correct if accumulation really happens in the requested int64."""
    x = tensor3("x", dtype="int8")
    out = x.sum(dtype="int8", acc_dtype="int64", axis=axis)
    # Each row sums to 10, but intermediate partial sums exceed int8 range.
    max_int8 = np.iinfo(np.int8).max
    row = np.array([max_int8, 5, max_int8, -max_int8, 5, -max_int8], dtype=np.int8)
    test_x = np.broadcast_to(row, (6, 2, 6)).copy()
    _, [res] = compare_numba_and_py(
        [x],
        [out],
        [test_x],
    )
    # Row sum is 10; reducing axis 0 as well multiplies by 6; the full
    # reduction multiplies by a further 2.
    expected = {-1: 10, (0, -1): 60, None: 120}[axis]
    assert np.all(res == expected)
@pytest.mark.parametrize("axis", (1, None))
def test_CAReduce_acc_complex_out_float(axis):
    """Accumulate in complex128 but emit float64: the +0.5j/-0.5j pairs in
    each row cancel, so the reduction must cast the complex accumulator
    down to the real output dtype."""
    values = np.array([[1 + 0.5j, 2 - 0.5j], [3 + 0.5j, 4 - 0.5j]], dtype="complex128")
    x = matrix("x", dtype="complex128")
    out = x.sum(dtype="float64", axis=axis)
    compare_numba_and_py([x], [out], [values])
@pytest.mark.parametrize("axis", (-1, (0, -1), None))
@pytest.mark.parametrize("dtype", ("int8", "uint8"))
def test_CAReduce_discrete_infinity_identity(axis, dtype):
    """Check max reductions over integer dtypes against the Python backend.

    The +/-inf reduction identities are not representable in integer
    dtypes; unsigned types are the regression case (no negative values at
    all), so cover both a signed and an unsigned dtype — the original
    test only exercised signed int8 even though the fix targets unsigned
    integers.
    """
    rng = np.random.default_rng(337)
    x = tensor3("x", dtype=dtype)
    out = x.max(axis)
    # Draw values spanning (almost) the full representable range of the dtype.
    info = np.iinfo(dtype)
    test_x = rng.integers(info.min, info.max, size=(6, 6, 6)).astype(dtype)
    compare_numba_and_py([x], [out], [test_x])
def test_scalar_Elemwise_Clip(): def test_scalar_Elemwise_Clip():
a = pt.scalar("a") a = pt.scalar("a")
b = pt.scalar("b") b = pt.scalar("b")
......
Markdown formatting supported
0%
You are adding 0 people to this discussion. Please proceed with caution.
Please finish editing this comment first!
Register or sign in to post a comment