提交 ea267946 authored 作者: ricardoV94's avatar ricardoV94 提交者: Ricardo Vieira

Numba uint: fix Sigmoid and Softplus with uint inputs

上级 f6ebb5ca
......@@ -31,10 +31,10 @@ from pytensor.scalar.basic import (
from pytensor.scalar.math import Erf, Erfc, GammaLn, Log1mexp, Sigmoid, Softplus
def scalar_op_cache_key(op):
def scalar_op_cache_key(op, **extra_fields):
# Scalar Ops don't have _props, because of their weird outputs_types_preference function
# So we create hash differently
return sha256(str(type(op)).encode()).hexdigest()
return sha256(str((type(op), tuple(extra_fields.items()))).encode()).hexdigest()
@register_funcify_and_cache_key(ScalarOp)
......@@ -267,11 +267,28 @@ def numba_funcify_Reciprocal(op, node, **kwargs):
@register_funcify_and_cache_key(Sigmoid)
def numba_funcify_Sigmoid(op, node, **kwargs):
@numba_basic.numba_njit
def sigmoid(x):
return 1 / (1 + np.exp(-x))
inp_dtype = node.inputs[0].type.dtype
if inp_dtype.startswith("uint"):
upcast_uint_dtype = {
"uint8": np.float32, # numpy uses float16, but not Numba
"uint16": np.float32,
"uint32": np.float64,
"uint64": np.float64,
}[inp_dtype]
@numba_basic.numba_njit
def sigmoid(x):
# Can't negate uint
float_x = numba_basic.direct_cast(x, upcast_uint_dtype)
return 1 / (1 + np.exp(-float_x))
else:
return sigmoid, scalar_op_cache_key(op)
@numba_basic.numba_njit
def sigmoid(x):
return 1 / (1 + np.exp(-x))
return sigmoid, scalar_op_cache_key(op, cache_version=1)
@register_funcify_and_cache_key(GammaLn)
......@@ -319,6 +336,16 @@ def numba_funcify_Erfc(op, **kwargs):
@register_funcify_and_cache_key(Softplus)
def numba_funcify_Softplus(op, node, **kwargs):
inp_dtype = node.inputs[0].type.dtype
if inp_dtype.startswith("uint"):
upcast_uint_dtype = {
"uint8": np.float32, # numpy uses float16, but not Numba
"uint16": np.float32,
"uint32": np.float64,
"uint64": np.float64,
}[inp_dtype]
else:
upcast_uint_dtype = None
out_dtype = np.dtype(node.outputs[0].type.dtype)
@numba_basic.numba_njit
......@@ -328,9 +355,12 @@ def numba_funcify_Softplus(op, node, **kwargs):
elif x < 18.0:
value = np.log1p(np.exp(x))
elif x < 33.3:
if upcast_uint_dtype is not None:
# Can't negate uint
x = numba_basic.direct_cast(x, upcast_uint_dtype)
value = x + np.exp(-x)
else:
value = x
return numba_basic.direct_cast(value, out_dtype)
return softplus, scalar_op_cache_key(op)
return softplus, scalar_op_cache_key(op, cache_version=1)
......@@ -1259,7 +1259,8 @@ class Softplus(UnaryScalarOp):
def impl(self, x):
# If x is an int8 or uint8, numpy.exp will compute the result in
# half-precision (float16), where we want float32.
not_int8 = str(getattr(x, "dtype", "")) not in ("int8", "uint8")
x_dtype = getattr(x, "dtype", None)
not_int8 = x_dtype is None or x_dtype.itemsize > 1
if x < -37.0:
return np.exp(x) if not_int8 else np.exp(x, signature="f")
elif x < 18.0:
......@@ -1267,6 +1268,9 @@ class Softplus(UnaryScalarOp):
np.log1p(np.exp(x)) if not_int8 else np.log1p(np.exp(x, signature="f"))
)
elif x < 33.3:
if x_dtype is not None and x_dtype.kind == "u":
# Negate uint will not do what we want
x = x.astype("float32" if x_dtype.itemsize <= 2 else "float64")
return x + np.exp(-x) if not_int8 else x + np.exp(-x, signature="f")
else:
return x
......
......@@ -158,15 +158,9 @@ def test_isnan(composite):
@pytest.mark.parametrize(
"dtype",
[
pytest.param(
"float32",
marks=pytest.mark.xfail(reason="Scalar downcasting not supported in numba"),
),
"float32",
"float64",
pytest.param(
"int16",
marks=pytest.mark.xfail(reason="Scalar downcasting not supported in numba"),
),
"int16",
"int64",
"uint32",
],
......@@ -183,7 +177,7 @@ def test_Softplus(dtype):
test_x = np.dtype(dtype).type(value)
np.testing.assert_allclose(
py_fn(test_x),
numba_fn(test_x),
getattr(np, g.dtype)(numba_fn(test_x)),
strict=True,
err_msg=f"Failed for value {value}",
)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论