提交 42587563 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Numba CAReduce: respect acc_dtype

Also fix infinity identities for unsigned integers
上级 0cc6314b
......@@ -2,7 +2,6 @@ from functools import singledispatch
from hashlib import sha256
from textwrap import dedent, indent
import numba
import numpy as np
from numba.core.extending import overload
from numpy.lib.array_utils import normalize_axis_index, normalize_axis_tuple
......@@ -15,6 +14,7 @@ from pytensor.link.numba.cache import (
)
from pytensor.link.numba.dispatch import basic as numba_basic
from pytensor.link.numba.dispatch.basic import (
create_tuple_string,
numba_funcify_and_cache_key,
register_funcify_and_cache_key,
register_funcify_default_op_cache_key,
......@@ -126,10 +126,12 @@ if {res}[{idx}] > {arr}:
def create_multiaxis_reducer(
scalar_op,
*,
identity,
axes,
ndim,
dtype,
acc_dtype=None,
out_dtype,
keepdims: bool = False,
):
r"""Construct a function that reduces multiple axes.
......@@ -139,17 +141,46 @@ def create_multiaxis_reducer(
.. code-block:: python
def careduce_add(x):
# For x.ndim == 3 and axes == (0, 1) and scalar_op == "Add"
x_shape = x.shape
res_shape = x_shape[2]
res = np.full(res_shape, numba_basic.to_scalar(0.0), dtype=out_dtype)
res_shape = (x_shape[0], x_shape[1])
# identity = 0.0
res = np.full(res_shape, identity, dtype=np.float64)
for i0 in range(x_shape[0]):
for i1 in range(x_shape[1]):
for i2 in range(x_shape[2]):
res[i0, i1] += x[i0, i1, i2]
return res
If accumulation dtype differs from output_dtype
.. code-block:: python
def careduce_add(x):
x_shape = x.shape
res_shape = (x_shape[0], x_shape[1])
# identity = 0.0
res = np.full(res_shape, identity, dtype=np.float64)
for i0 in range(x_shape[0]):
for i1 in range(x_shape[1]):
for i2 in range(x_shape[2]):
res[i2] += x[i0, i1, i2]
res[i0, i1] += x[i0, i1, i2]
return res.astype(np.int32)
Full reductions accumulate on scalars
.. code-block:: python
def careduce_mul(x):
x_shape = x.shape
res_shape = ()
# identity = 1.0
res = identity
for i0 in range(x_shape[0]):
for i1 in range(x_shape[1]):
for i2 in range(x_shape[2]):
res *= x[i0, i1, i2]
return np.array(res, dtype=np.int32)
return res
Parameters
==========
......@@ -161,7 +192,9 @@ def create_multiaxis_reducer(
The axes to reduce.
ndim:
The number of dimensions of the input variable.
dtype:
acc_dtype: dtype, optional
The data type used during accumulation. Defaults to out_dtype if not provided
out_dtype:
The data type of the result.
keepdims: boolean, default False
Whether to keep the reduced dimensions.
......@@ -179,19 +212,23 @@ def create_multiaxis_reducer(
"Cannot keep multiple dimensions when reducing multiple axes"
)
out_dtype = np.dtype(out_dtype)
acc_dtype = out_dtype if acc_dtype is None else np.dtype(acc_dtype)
# Numba doesn't allow converting complex to real with a simple `astype`
complex_to_real = acc_dtype.kind == "c" and out_dtype.kind != "c"
out_dtype_str = f"np.{out_dtype.name}"
acc_dtype_str = f"np.{acc_dtype.name}"
careduce_fn_name = f"careduce_{scalar_op}"
identity = str(identity)
if identity == "inf":
identity = "np.inf"
elif identity == "-inf":
identity = "-np.inf"
global_env = {
"np": np,
"numba_basic": numba_basic,
"out_dtype": dtype,
}
if acc_dtype.kind in "ui" and not np.isfinite(identity):
if np.isposinf(identity):
identity = np.iinfo(acc_dtype).max
else:
identity = np.iinfo(acc_dtype).min
# Make sure it has the correct dtype
identity = getattr(np, acc_dtype.name)(identity)
complete_reduction = len(axes) == ndim
kept_axis = tuple(i for i in range(ndim) if i not in axes)
......@@ -209,17 +246,23 @@ def create_multiaxis_reducer(
scalar_op, res_indices, "res", f"x[{arr_indices}]"
)
res_shape = f"({', '.join(f'x_shape[{i}]' for i in kept_axis)})"
res_shape = create_tuple_string([f"x_shape[{i}]" for i in kept_axis])
if complete_reduction and ndim > 0:
# We accumulate on a scalar, not an array
res_creator = f"np.asarray({identity}).astype(out_dtype).item()"
res_creator = "identity"
inplace_update_stmt = inplace_update_stmt.replace("res[()]", "res")
return_obj = "np.asarray(res)"
if complex_to_real:
return_obj = f"np.array(res).real.astype({out_dtype_str})"
else:
res_creator = (
f"np.full({res_shape}, np.asarray({identity}).item(), dtype=out_dtype)"
return_obj = f"np.array(res, dtype={out_dtype_str})"
else:
res_creator = f"np.full(res_shape, identity, dtype={acc_dtype_str})"
if complex_to_real:
return_obj = f"res.real.astype({out_dtype_str})"
else:
return_obj = (
"res" if out_dtype == acc_dtype else f"res.astype({out_dtype_str})"
)
return_obj = "res"
if keepdims:
[axis] = axes
......@@ -230,6 +273,7 @@ def create_multiaxis_reducer(
def {careduce_fn_name}(x):
x_shape = x.shape
res_shape = {res_shape}
# identity = {identity}
res = {res_creator}
"""
)
......@@ -239,13 +283,12 @@ def create_multiaxis_reducer(
" " * (4 + 4 * axis),
)
careduce_def_src += indent(inplace_update_stmt, " " * (4 + 4 * ndim))
careduce_def_src += "\n\n"
careduce_def_src += "\n"
careduce_def_src += indent(f"return {return_obj}", " " * 4)
careduce_fn = compile_numba_function_src(
careduce_def_src, careduce_fn_name, {**globals(), **global_env}
careduce_def_src, careduce_fn_name, globals() | {"np": np, "identity": identity}
)
return careduce_fn
......@@ -356,24 +399,18 @@ def numba_funcify_CAReduce(op, node, **kwargs):
acc_dtype = op.acc_dtype
else:
acc_dtype = node.outputs[0].type.dtype
np_acc_dtype = np.dtype(acc_dtype)
scalar_op_identity = op.scalar_op.identity
if np_acc_dtype.kind == "i" and not np.isfinite(scalar_op_identity):
if np.isposinf(scalar_op_identity):
scalar_op_identity = np.iinfo(np_acc_dtype).max
else:
scalar_op_identity = np.iinfo(np_acc_dtype).min
# Make sure it has the correct dtype
scalar_op_identity = np.array(scalar_op_identity, dtype=np_acc_dtype)
out_dtype = np.dtype(node.outputs[0].type.dtype)
if isinstance(op, Sum) and node.inputs[0].ndim == len(axes):
if (
isinstance(op, Sum)
and node.inputs[0].ndim == len(axes)
and out_dtype == acc_dtype
):
# Slightly faster for this case
@numba_basic.numba_njit
def impl_sum(array):
return np.asarray(array.sum(), dtype=np_acc_dtype).astype(out_dtype)
return np.array(array.sum())
careduce_fn = impl_sum # Some tests look for this name
......@@ -381,16 +418,26 @@ def numba_funcify_CAReduce(op, node, **kwargs):
ndim = node.inputs[0].ndim
careduce_py_fn = create_multiaxis_reducer(
op.scalar_op,
scalar_op_identity,
axes,
ndim,
out_dtype,
identity=op.scalar_op.identity,
axes=axes,
ndim=ndim,
acc_dtype=acc_dtype,
out_dtype=out_dtype,
)
careduce_fn = numba_basic.numba_njit(careduce_py_fn, boundscheck=False)
cache_version = 1
careduce_key = sha256(
str(
(type(op), type(op.scalar_op), axes, acc_dtype, scalar_op_identity.item())
(
type(op),
type(op.scalar_op),
axes,
out_dtype,
acc_dtype,
op.scalar_op.identity,
cache_version,
)
).encode()
).hexdigest()
return careduce_fn, careduce_key
......@@ -449,18 +496,26 @@ def numba_funcify_DimShuffle(op: DimShuffle, node, **kwargs):
@register_funcify_default_op_cache_key(Softmax)
def numba_funcify_Softmax(op, node, **kwargs):
x_at = node.inputs[0]
x_dtype = x_at.type.numpy_dtype
x_dtype = numba.np.numpy_support.from_dtype(x_dtype)
ndim = node.inputs[0].type.ndim
inp_dtype = node.inputs[0].type.numpy_dtype
axis = op.axis
if axis is not None:
axis = normalize_axis_index(axis, x_at.ndim)
if ndim > 1 and axis is not None:
reduce_max_py = create_multiaxis_reducer(
maximum, -np.inf, axis, x_at.ndim, x_dtype, keepdims=True
maximum,
identity=-np.inf,
axes=(axis,),
ndim=ndim,
out_dtype=inp_dtype,
keepdims=True,
)
reduce_sum_py = create_multiaxis_reducer(
add_as, 0.0, (axis,), x_at.ndim, x_dtype, keepdims=True
add_as,
identity=0.0,
axes=(axis,),
ndim=ndim,
out_dtype=inp_dtype,
keepdims=True,
)
jit_fn = numba_basic.numba_njit(boundscheck=False)
......@@ -470,29 +525,32 @@ def numba_funcify_Softmax(op, node, **kwargs):
reduce_max = np.max
reduce_sum = np.sum
def softmax_py_fn(x):
@numba_basic.numba_njit(boundscheck=False)
def softmax(x):
z = reduce_max(x)
e_x = np.exp(x - z)
w = reduce_sum(e_x)
sm = e_x / w
return sm
softmax = numba_basic.numba_njit(softmax_py_fn, boundscheck=False)
return softmax
cache_version = 1
return softmax, cache_version
@register_funcify_default_op_cache_key(SoftmaxGrad)
def numba_funcify_SoftmaxGrad(op, node, **kwargs):
sm_at = node.inputs[1]
sm_dtype = sm_at.type.numpy_dtype
sm_dtype = numba.np.numpy_support.from_dtype(sm_dtype)
ndim = node.inputs[0].type.ndim
inp_dtype = node.inputs[0].type.numpy_dtype
axis = op.axis
if axis is not None:
axis = normalize_axis_index(axis, sm_at.ndim)
if ndim > 1 and axis is not None:
reduce_sum_py = create_multiaxis_reducer(
add_as, 0.0, (axis,), sm_at.ndim, sm_dtype, keepdims=True
add_as,
identity=0.0,
axes=(axis,),
ndim=ndim,
out_dtype=inp_dtype,
keepdims=True,
)
jit_fn = numba_basic.numba_njit(boundscheck=False)
......@@ -500,36 +558,39 @@ def numba_funcify_SoftmaxGrad(op, node, **kwargs):
else:
reduce_sum = np.sum
def softmax_grad_py_fn(dy, sm):
@numba_basic.numba_njit(boundscheck=False)
def softmax_grad(dy, sm):
dy_times_sm = dy * sm
sum_dy_times_sm = reduce_sum(dy_times_sm)
dx = dy_times_sm - sum_dy_times_sm * sm
return dx
softmax_grad = numba_basic.numba_njit(softmax_grad_py_fn, boundscheck=False)
return softmax_grad
cache_version = 1
return softmax_grad, cache_version
@register_funcify_default_op_cache_key(LogSoftmax)
def numba_funcify_LogSoftmax(op, node, **kwargs):
x_at = node.inputs[0]
x_dtype = x_at.type.numpy_dtype
x_dtype = numba.np.numpy_support.from_dtype(x_dtype)
ndim = node.inputs[0].type.ndim
inp_dtype = node.inputs[0].type.numpy_dtype
axis = op.axis
if axis is not None:
axis = normalize_axis_index(axis, x_at.ndim)
if ndim > 1 and axis is not None:
reduce_max_py = create_multiaxis_reducer(
maximum,
-np.inf,
(axis,),
x_at.ndim,
x_dtype,
identity=-np.inf,
axes=(axis,),
ndim=ndim,
out_dtype=inp_dtype,
keepdims=True,
)
reduce_sum_py = create_multiaxis_reducer(
add_as, 0.0, (axis,), x_at.ndim, x_dtype, keepdims=True
add_as,
identity=0.0,
axes=(axis,),
ndim=ndim,
out_dtype=inp_dtype,
keepdims=True,
)
jit_fn = numba_basic.numba_njit(boundscheck=False)
......@@ -539,13 +600,14 @@ def numba_funcify_LogSoftmax(op, node, **kwargs):
reduce_max = np.max
reduce_sum = np.sum
def log_softmax_py_fn(x):
@numba_basic.numba_njit(boundscheck=False)
def log_softmax(x):
xdev = x - reduce_max(x)
lsm = xdev - np.log(reduce_sum(np.exp(xdev)))
return lsm
log_softmax = numba_basic.numba_njit(log_softmax_py_fn, boundscheck=False)
return log_softmax
cache_version = 1
return log_softmax, cache_version
@register_funcify_default_op_cache_key(Argmax)
......
......@@ -1391,6 +1391,9 @@ class CAReduce(COp):
return f"axes={list(axis)}"
def __str__(self):
if self.acc_dtype != self.dtype:
return f"{type(self).__name__}{{{self.scalar_op}, {self._axis_str()}, acc={self.acc_dtype}}}"
else:
return f"{type(self).__name__}{{{self.scalar_op}, {self._axis_str()}}}"
def perform(self, node, inp, out):
......
......@@ -357,6 +357,9 @@ def max_and_argmax(a, axis=None, keepdims=False):
class FixedOpCAReduce(CAReduce):
def __str__(self):
if self.dtype != self.acc_dtype:
return f"{type(self).__name__}{{{self._axis_str()}, acc={self.acc_dtype}}}"
else:
return f"{type(self).__name__}{{{self._axis_str()}}}"
......
......@@ -13,7 +13,7 @@ from pytensor.compile.ops import deep_copy_op
from pytensor.gradient import grad
from pytensor.scalar import Composite, float64
from pytensor.scalar import add as scalar_add
from pytensor.tensor import blas, tensor
from pytensor.tensor import blas, matrix, tensor, tensor3
from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise
from pytensor.tensor.math import All, Any, Max, Min, Prod, ProdWithoutZeros, Sum
from pytensor.tensor.special import LogSoftmax, Softmax, SoftmaxGrad
......@@ -366,6 +366,45 @@ def test_CAReduce(careduce_fn, axis, v):
assert isinstance(node.op, CAReduce)
@pytest.mark.parametrize("axis", (-1, (0, -1), None))
def test_CAReduce_respects_acc_dtype(axis):
x = tensor3("x", dtype="int8")
out = x.sum(dtype="int8", acc_dtype="int64", axis=axis)
# Choose values that would overflow if accumulated internally in int8
max_int8 = np.iinfo(np.int8).max
test_x = np.array([max_int8, 5, max_int8, -max_int8, 5, -max_int8], dtype=np.int8)
test_x = np.broadcast_to(test_x, (6, 2, 6)).copy()
_, [res] = compare_numba_and_py(
[x],
[out],
[test_x],
)
if axis == -1:
assert np.all(res == 10)
elif axis == (0, -1):
assert np.all(res == 60)
elif axis is None:
assert res == 120
@pytest.mark.parametrize("axis", (1, None))
def test_CAReduce_acc_complex_out_float(axis):
x = matrix("x", dtype="complex128")
out = x.sum(dtype="float64", axis=axis)
test_x = np.array([[1 + 0.5j, 2 - 0.5j], [3 + 0.5j, 4 - 0.5j]], dtype="complex128")
compare_numba_and_py([x], [out], [test_x])
@pytest.mark.parametrize("axis", (-1, (0, -1), None))
def test_CAReduce_discrete_infinity_identity(axis):
rng = np.random.default_rng(337)
x = tensor3("x", dtype="int8")
out = x.max(axis)
compare_numba_and_py(
[x], [out], [rng.integers(-127, 127, size=(6, 6, 6)).astype("int8")]
)
def test_scalar_Elemwise_Clip():
a = pt.scalar("a")
b = pt.scalar("b")
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论