Commit 42587563, authored and committed by Ricardo Vieira

Numba CAReduce: respect acc_dtype

Also fix infinity identities for unsigned integers
Parent commit: 0cc6314b
...@@ -2,7 +2,6 @@ from functools import singledispatch ...@@ -2,7 +2,6 @@ from functools import singledispatch
from hashlib import sha256 from hashlib import sha256
from textwrap import dedent, indent from textwrap import dedent, indent
import numba
import numpy as np import numpy as np
from numba.core.extending import overload from numba.core.extending import overload
from numpy.lib.array_utils import normalize_axis_index, normalize_axis_tuple from numpy.lib.array_utils import normalize_axis_index, normalize_axis_tuple
...@@ -15,6 +14,7 @@ from pytensor.link.numba.cache import ( ...@@ -15,6 +14,7 @@ from pytensor.link.numba.cache import (
) )
from pytensor.link.numba.dispatch import basic as numba_basic from pytensor.link.numba.dispatch import basic as numba_basic
from pytensor.link.numba.dispatch.basic import ( from pytensor.link.numba.dispatch.basic import (
create_tuple_string,
numba_funcify_and_cache_key, numba_funcify_and_cache_key,
register_funcify_and_cache_key, register_funcify_and_cache_key,
register_funcify_default_op_cache_key, register_funcify_default_op_cache_key,
...@@ -126,10 +126,12 @@ if {res}[{idx}] > {arr}: ...@@ -126,10 +126,12 @@ if {res}[{idx}] > {arr}:
def create_multiaxis_reducer( def create_multiaxis_reducer(
scalar_op, scalar_op,
*,
identity, identity,
axes, axes,
ndim, ndim,
dtype, acc_dtype=None,
out_dtype,
keepdims: bool = False, keepdims: bool = False,
): ):
r"""Construct a function that reduces multiple axes. r"""Construct a function that reduces multiple axes.
...@@ -139,17 +141,46 @@ def create_multiaxis_reducer( ...@@ -139,17 +141,46 @@ def create_multiaxis_reducer(
.. code-block:: python .. code-block:: python
def careduce_add(x): def careduce_add(x):
# For x.ndim == 3 and axes == (0, 1) and scalar_op == "Add"
x_shape = x.shape x_shape = x.shape
res_shape = x_shape[2] res_shape = (x_shape[0], x_shape[1])
res = np.full(res_shape, numba_basic.to_scalar(0.0), dtype=out_dtype) # identity = 0.0
res = np.full(res_shape, identity, dtype=np.float64)
for i0 in range(x_shape[0]):
for i1 in range(x_shape[1]):
for i2 in range(x_shape[2]):
res[i0, i1] += x[i0, i1, i2]
return res
If accumulation dtype differs from output_dtype
.. code-block:: python
def careduce_add(x):
x_shape = x.shape
res_shape = (x_shape[0], x_shape[1])
# identity = 0.0
res = np.full(res_shape, identity, dtype=np.float64)
for i0 in range(x_shape[0]): for i0 in range(x_shape[0]):
for i1 in range(x_shape[1]): for i1 in range(x_shape[1]):
for i2 in range(x_shape[2]): for i2 in range(x_shape[2]):
res[i2] += x[i0, i1, i2] res[i0, i1] += x[i0, i1, i2]
return res.astype(np.int32)
Full reductions accumulate on scalars
.. code-block:: python
def careduce_mul(x):
x_shape = x.shape
res_shape = ()
# identity = 1.0
res = identity
for i0 in range(x_shape[0]):
for i1 in range(x_shape[1]):
for i2 in range(x_shape[2]):
res *= x[i0, i1, i2]
return np.array(res, dtype=np.int32)
return res
Parameters Parameters
========== ==========
...@@ -161,7 +192,9 @@ def create_multiaxis_reducer( ...@@ -161,7 +192,9 @@ def create_multiaxis_reducer(
The axes to reduce. The axes to reduce.
ndim: ndim:
The number of dimensions of the input variable. The number of dimensions of the input variable.
dtype: acc_dtype: dtype, optional
The data type used during accumulation. Defaults to out_dtype if not provided
out_dtype:
The data type of the result. The data type of the result.
keepdims: boolean, default False keepdims: boolean, default False
Whether to keep the reduced dimensions. Whether to keep the reduced dimensions.
...@@ -179,19 +212,23 @@ def create_multiaxis_reducer( ...@@ -179,19 +212,23 @@ def create_multiaxis_reducer(
"Cannot keep multiple dimensions when reducing multiple axes" "Cannot keep multiple dimensions when reducing multiple axes"
) )
out_dtype = np.dtype(out_dtype)
acc_dtype = out_dtype if acc_dtype is None else np.dtype(acc_dtype)
# Numba doesn't allow converting complex to real with a simple `astype`
complex_to_real = acc_dtype.kind == "c" and out_dtype.kind != "c"
out_dtype_str = f"np.{out_dtype.name}"
acc_dtype_str = f"np.{acc_dtype.name}"
careduce_fn_name = f"careduce_{scalar_op}" careduce_fn_name = f"careduce_{scalar_op}"
identity = str(identity) if acc_dtype.kind in "ui" and not np.isfinite(identity):
if identity == "inf": if np.isposinf(identity):
identity = "np.inf" identity = np.iinfo(acc_dtype).max
elif identity == "-inf": else:
identity = "-np.inf" identity = np.iinfo(acc_dtype).min
global_env = { # Make sure it has the correct dtype
"np": np, identity = getattr(np, acc_dtype.name)(identity)
"numba_basic": numba_basic,
"out_dtype": dtype,
}
complete_reduction = len(axes) == ndim complete_reduction = len(axes) == ndim
kept_axis = tuple(i for i in range(ndim) if i not in axes) kept_axis = tuple(i for i in range(ndim) if i not in axes)
...@@ -209,17 +246,23 @@ def create_multiaxis_reducer( ...@@ -209,17 +246,23 @@ def create_multiaxis_reducer(
scalar_op, res_indices, "res", f"x[{arr_indices}]" scalar_op, res_indices, "res", f"x[{arr_indices}]"
) )
res_shape = f"({', '.join(f'x_shape[{i}]' for i in kept_axis)})" res_shape = create_tuple_string([f"x_shape[{i}]" for i in kept_axis])
if complete_reduction and ndim > 0: if complete_reduction and ndim > 0:
# We accumulate on a scalar, not an array # We accumulate on a scalar, not an array
res_creator = f"np.asarray({identity}).astype(out_dtype).item()" res_creator = "identity"
inplace_update_stmt = inplace_update_stmt.replace("res[()]", "res") inplace_update_stmt = inplace_update_stmt.replace("res[()]", "res")
return_obj = "np.asarray(res)" if complex_to_real:
return_obj = f"np.array(res).real.astype({out_dtype_str})"
else:
return_obj = f"np.array(res, dtype={out_dtype_str})"
else: else:
res_creator = ( res_creator = f"np.full(res_shape, identity, dtype={acc_dtype_str})"
f"np.full({res_shape}, np.asarray({identity}).item(), dtype=out_dtype)" if complex_to_real:
) return_obj = f"res.real.astype({out_dtype_str})"
return_obj = "res" else:
return_obj = (
"res" if out_dtype == acc_dtype else f"res.astype({out_dtype_str})"
)
if keepdims: if keepdims:
[axis] = axes [axis] = axes
...@@ -230,6 +273,7 @@ def create_multiaxis_reducer( ...@@ -230,6 +273,7 @@ def create_multiaxis_reducer(
def {careduce_fn_name}(x): def {careduce_fn_name}(x):
x_shape = x.shape x_shape = x.shape
res_shape = {res_shape} res_shape = {res_shape}
# identity = {identity}
res = {res_creator} res = {res_creator}
""" """
) )
...@@ -239,13 +283,12 @@ def create_multiaxis_reducer( ...@@ -239,13 +283,12 @@ def create_multiaxis_reducer(
" " * (4 + 4 * axis), " " * (4 + 4 * axis),
) )
careduce_def_src += indent(inplace_update_stmt, " " * (4 + 4 * ndim)) careduce_def_src += indent(inplace_update_stmt, " " * (4 + 4 * ndim))
careduce_def_src += "\n\n" careduce_def_src += "\n"
careduce_def_src += indent(f"return {return_obj}", " " * 4) careduce_def_src += indent(f"return {return_obj}", " " * 4)
careduce_fn = compile_numba_function_src( careduce_fn = compile_numba_function_src(
careduce_def_src, careduce_fn_name, {**globals(), **global_env} careduce_def_src, careduce_fn_name, globals() | {"np": np, "identity": identity}
) )
return careduce_fn return careduce_fn
...@@ -356,24 +399,18 @@ def numba_funcify_CAReduce(op, node, **kwargs): ...@@ -356,24 +399,18 @@ def numba_funcify_CAReduce(op, node, **kwargs):
acc_dtype = op.acc_dtype acc_dtype = op.acc_dtype
else: else:
acc_dtype = node.outputs[0].type.dtype acc_dtype = node.outputs[0].type.dtype
np_acc_dtype = np.dtype(acc_dtype)
scalar_op_identity = op.scalar_op.identity
if np_acc_dtype.kind == "i" and not np.isfinite(scalar_op_identity):
if np.isposinf(scalar_op_identity):
scalar_op_identity = np.iinfo(np_acc_dtype).max
else:
scalar_op_identity = np.iinfo(np_acc_dtype).min
# Make sure it has the correct dtype
scalar_op_identity = np.array(scalar_op_identity, dtype=np_acc_dtype)
out_dtype = np.dtype(node.outputs[0].type.dtype) out_dtype = np.dtype(node.outputs[0].type.dtype)
if isinstance(op, Sum) and node.inputs[0].ndim == len(axes): if (
isinstance(op, Sum)
and node.inputs[0].ndim == len(axes)
and out_dtype == acc_dtype
):
# Slightly faster for this case # Slightly faster for this case
@numba_basic.numba_njit @numba_basic.numba_njit
def impl_sum(array): def impl_sum(array):
return np.asarray(array.sum(), dtype=np_acc_dtype).astype(out_dtype) return np.array(array.sum())
careduce_fn = impl_sum # Some tests look for this name careduce_fn = impl_sum # Some tests look for this name
...@@ -381,16 +418,26 @@ def numba_funcify_CAReduce(op, node, **kwargs): ...@@ -381,16 +418,26 @@ def numba_funcify_CAReduce(op, node, **kwargs):
ndim = node.inputs[0].ndim ndim = node.inputs[0].ndim
careduce_py_fn = create_multiaxis_reducer( careduce_py_fn = create_multiaxis_reducer(
op.scalar_op, op.scalar_op,
scalar_op_identity, identity=op.scalar_op.identity,
axes, axes=axes,
ndim, ndim=ndim,
out_dtype, acc_dtype=acc_dtype,
out_dtype=out_dtype,
) )
careduce_fn = numba_basic.numba_njit(careduce_py_fn, boundscheck=False) careduce_fn = numba_basic.numba_njit(careduce_py_fn, boundscheck=False)
cache_version = 1
careduce_key = sha256( careduce_key = sha256(
str( str(
(type(op), type(op.scalar_op), axes, acc_dtype, scalar_op_identity.item()) (
type(op),
type(op.scalar_op),
axes,
out_dtype,
acc_dtype,
op.scalar_op.identity,
cache_version,
)
).encode() ).encode()
).hexdigest() ).hexdigest()
return careduce_fn, careduce_key return careduce_fn, careduce_key
...@@ -449,18 +496,26 @@ def numba_funcify_DimShuffle(op: DimShuffle, node, **kwargs): ...@@ -449,18 +496,26 @@ def numba_funcify_DimShuffle(op: DimShuffle, node, **kwargs):
@register_funcify_default_op_cache_key(Softmax) @register_funcify_default_op_cache_key(Softmax)
def numba_funcify_Softmax(op, node, **kwargs): def numba_funcify_Softmax(op, node, **kwargs):
x_at = node.inputs[0] ndim = node.inputs[0].type.ndim
x_dtype = x_at.type.numpy_dtype inp_dtype = node.inputs[0].type.numpy_dtype
x_dtype = numba.np.numpy_support.from_dtype(x_dtype)
axis = op.axis axis = op.axis
if axis is not None: if ndim > 1 and axis is not None:
axis = normalize_axis_index(axis, x_at.ndim)
reduce_max_py = create_multiaxis_reducer( reduce_max_py = create_multiaxis_reducer(
maximum, -np.inf, axis, x_at.ndim, x_dtype, keepdims=True maximum,
identity=-np.inf,
axes=(axis,),
ndim=ndim,
out_dtype=inp_dtype,
keepdims=True,
) )
reduce_sum_py = create_multiaxis_reducer( reduce_sum_py = create_multiaxis_reducer(
add_as, 0.0, (axis,), x_at.ndim, x_dtype, keepdims=True add_as,
identity=0.0,
axes=(axis,),
ndim=ndim,
out_dtype=inp_dtype,
keepdims=True,
) )
jit_fn = numba_basic.numba_njit(boundscheck=False) jit_fn = numba_basic.numba_njit(boundscheck=False)
...@@ -470,29 +525,32 @@ def numba_funcify_Softmax(op, node, **kwargs): ...@@ -470,29 +525,32 @@ def numba_funcify_Softmax(op, node, **kwargs):
reduce_max = np.max reduce_max = np.max
reduce_sum = np.sum reduce_sum = np.sum
def softmax_py_fn(x): @numba_basic.numba_njit(boundscheck=False)
def softmax(x):
z = reduce_max(x) z = reduce_max(x)
e_x = np.exp(x - z) e_x = np.exp(x - z)
w = reduce_sum(e_x) w = reduce_sum(e_x)
sm = e_x / w sm = e_x / w
return sm return sm
softmax = numba_basic.numba_njit(softmax_py_fn, boundscheck=False) cache_version = 1
return softmax, cache_version
return softmax
@register_funcify_default_op_cache_key(SoftmaxGrad) @register_funcify_default_op_cache_key(SoftmaxGrad)
def numba_funcify_SoftmaxGrad(op, node, **kwargs): def numba_funcify_SoftmaxGrad(op, node, **kwargs):
sm_at = node.inputs[1] ndim = node.inputs[0].type.ndim
sm_dtype = sm_at.type.numpy_dtype inp_dtype = node.inputs[0].type.numpy_dtype
sm_dtype = numba.np.numpy_support.from_dtype(sm_dtype)
axis = op.axis axis = op.axis
if axis is not None: if ndim > 1 and axis is not None:
axis = normalize_axis_index(axis, sm_at.ndim)
reduce_sum_py = create_multiaxis_reducer( reduce_sum_py = create_multiaxis_reducer(
add_as, 0.0, (axis,), sm_at.ndim, sm_dtype, keepdims=True add_as,
identity=0.0,
axes=(axis,),
ndim=ndim,
out_dtype=inp_dtype,
keepdims=True,
) )
jit_fn = numba_basic.numba_njit(boundscheck=False) jit_fn = numba_basic.numba_njit(boundscheck=False)
...@@ -500,36 +558,39 @@ def numba_funcify_SoftmaxGrad(op, node, **kwargs): ...@@ -500,36 +558,39 @@ def numba_funcify_SoftmaxGrad(op, node, **kwargs):
else: else:
reduce_sum = np.sum reduce_sum = np.sum
def softmax_grad_py_fn(dy, sm): @numba_basic.numba_njit(boundscheck=False)
def softmax_grad(dy, sm):
dy_times_sm = dy * sm dy_times_sm = dy * sm
sum_dy_times_sm = reduce_sum(dy_times_sm) sum_dy_times_sm = reduce_sum(dy_times_sm)
dx = dy_times_sm - sum_dy_times_sm * sm dx = dy_times_sm - sum_dy_times_sm * sm
return dx return dx
softmax_grad = numba_basic.numba_njit(softmax_grad_py_fn, boundscheck=False) cache_version = 1
return softmax_grad, cache_version
return softmax_grad
@register_funcify_default_op_cache_key(LogSoftmax) @register_funcify_default_op_cache_key(LogSoftmax)
def numba_funcify_LogSoftmax(op, node, **kwargs): def numba_funcify_LogSoftmax(op, node, **kwargs):
x_at = node.inputs[0] ndim = node.inputs[0].type.ndim
x_dtype = x_at.type.numpy_dtype inp_dtype = node.inputs[0].type.numpy_dtype
x_dtype = numba.np.numpy_support.from_dtype(x_dtype)
axis = op.axis axis = op.axis
if axis is not None: if ndim > 1 and axis is not None:
axis = normalize_axis_index(axis, x_at.ndim)
reduce_max_py = create_multiaxis_reducer( reduce_max_py = create_multiaxis_reducer(
maximum, maximum,
-np.inf, identity=-np.inf,
(axis,), axes=(axis,),
x_at.ndim, ndim=ndim,
x_dtype, out_dtype=inp_dtype,
keepdims=True, keepdims=True,
) )
reduce_sum_py = create_multiaxis_reducer( reduce_sum_py = create_multiaxis_reducer(
add_as, 0.0, (axis,), x_at.ndim, x_dtype, keepdims=True add_as,
identity=0.0,
axes=(axis,),
ndim=ndim,
out_dtype=inp_dtype,
keepdims=True,
) )
jit_fn = numba_basic.numba_njit(boundscheck=False) jit_fn = numba_basic.numba_njit(boundscheck=False)
...@@ -539,13 +600,14 @@ def numba_funcify_LogSoftmax(op, node, **kwargs): ...@@ -539,13 +600,14 @@ def numba_funcify_LogSoftmax(op, node, **kwargs):
reduce_max = np.max reduce_max = np.max
reduce_sum = np.sum reduce_sum = np.sum
def log_softmax_py_fn(x): @numba_basic.numba_njit(boundscheck=False)
def log_softmax(x):
xdev = x - reduce_max(x) xdev = x - reduce_max(x)
lsm = xdev - np.log(reduce_sum(np.exp(xdev))) lsm = xdev - np.log(reduce_sum(np.exp(xdev)))
return lsm return lsm
log_softmax = numba_basic.numba_njit(log_softmax_py_fn, boundscheck=False) cache_version = 1
return log_softmax return log_softmax, cache_version
@register_funcify_default_op_cache_key(Argmax) @register_funcify_default_op_cache_key(Argmax)
......
...@@ -1391,7 +1391,10 @@ class CAReduce(COp): ...@@ -1391,7 +1391,10 @@ class CAReduce(COp):
return f"axes={list(axis)}" return f"axes={list(axis)}"
def __str__(self): def __str__(self):
return f"{type(self).__name__}{{{self.scalar_op}, {self._axis_str()}}}" if self.acc_dtype != self.dtype:
return f"{type(self).__name__}{{{self.scalar_op}, {self._axis_str()}, acc={self.acc_dtype}}}"
else:
return f"{type(self).__name__}{{{self.scalar_op}, {self._axis_str()}}}"
def perform(self, node, inp, out): def perform(self, node, inp, out):
(input,) = inp (input,) = inp
......
...@@ -357,7 +357,10 @@ def max_and_argmax(a, axis=None, keepdims=False): ...@@ -357,7 +357,10 @@ def max_and_argmax(a, axis=None, keepdims=False):
class FixedOpCAReduce(CAReduce): class FixedOpCAReduce(CAReduce):
def __str__(self): def __str__(self):
return f"{type(self).__name__}{{{self._axis_str()}}}" if self.dtype != self.acc_dtype:
return f"{type(self).__name__}{{{self._axis_str()}, acc={self.acc_dtype}}}"
else:
return f"{type(self).__name__}{{{self._axis_str()}}}"
class NonZeroDimsCAReduce(FixedOpCAReduce): class NonZeroDimsCAReduce(FixedOpCAReduce):
......
...@@ -13,7 +13,7 @@ from pytensor.compile.ops import deep_copy_op ...@@ -13,7 +13,7 @@ from pytensor.compile.ops import deep_copy_op
from pytensor.gradient import grad from pytensor.gradient import grad
from pytensor.scalar import Composite, float64 from pytensor.scalar import Composite, float64
from pytensor.scalar import add as scalar_add from pytensor.scalar import add as scalar_add
from pytensor.tensor import blas, tensor from pytensor.tensor import blas, matrix, tensor, tensor3
from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise
from pytensor.tensor.math import All, Any, Max, Min, Prod, ProdWithoutZeros, Sum from pytensor.tensor.math import All, Any, Max, Min, Prod, ProdWithoutZeros, Sum
from pytensor.tensor.special import LogSoftmax, Softmax, SoftmaxGrad from pytensor.tensor.special import LogSoftmax, Softmax, SoftmaxGrad
...@@ -366,6 +366,45 @@ def test_CAReduce(careduce_fn, axis, v): ...@@ -366,6 +366,45 @@ def test_CAReduce(careduce_fn, axis, v):
assert isinstance(node.op, CAReduce) assert isinstance(node.op, CAReduce)
@pytest.mark.parametrize("axis", (-1, (0, -1), None))
def test_CAReduce_respects_acc_dtype(axis):
x = tensor3("x", dtype="int8")
out = x.sum(dtype="int8", acc_dtype="int64", axis=axis)
# Choose values that would overflow if accumulated internally in int8
max_int8 = np.iinfo(np.int8).max
test_x = np.array([max_int8, 5, max_int8, -max_int8, 5, -max_int8], dtype=np.int8)
test_x = np.broadcast_to(test_x, (6, 2, 6)).copy()
_, [res] = compare_numba_and_py(
[x],
[out],
[test_x],
)
if axis == -1:
assert np.all(res == 10)
elif axis == (0, -1):
assert np.all(res == 60)
elif axis is None:
assert res == 120
@pytest.mark.parametrize("axis", (1, None))
def test_CAReduce_acc_complex_out_float(axis):
x = matrix("x", dtype="complex128")
out = x.sum(dtype="float64", axis=axis)
test_x = np.array([[1 + 0.5j, 2 - 0.5j], [3 + 0.5j, 4 - 0.5j]], dtype="complex128")
compare_numba_and_py([x], [out], [test_x])
@pytest.mark.parametrize("axis", (-1, (0, -1), None))
def test_CAReduce_discrete_infinity_identity(axis):
rng = np.random.default_rng(337)
x = tensor3("x", dtype="int8")
out = x.max(axis)
compare_numba_and_py(
[x], [out], [rng.integers(-127, 127, size=(6, 6, 6)).astype("int8")]
)
def test_scalar_Elemwise_Clip(): def test_scalar_Elemwise_Clip():
a = pt.scalar("a") a = pt.scalar("a")
b = pt.scalar("b") b = pt.scalar("b")
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论