提交 ea267946，作者：ricardoV94，提交者：Ricardo Vieira

Numba uint: fix Sigmoid and Softplus with uint inputs

上级 f6ebb5ca
...@@ -31,10 +31,10 @@ from pytensor.scalar.basic import ( ...@@ -31,10 +31,10 @@ from pytensor.scalar.basic import (
from pytensor.scalar.math import Erf, Erfc, GammaLn, Log1mexp, Sigmoid, Softplus from pytensor.scalar.math import Erf, Erfc, GammaLn, Log1mexp, Sigmoid, Softplus
def scalar_op_cache_key(op, **extra_fields):
    """Return a stable cache key (hex digest) for a scalar ``Op``.

    Scalar Ops don't have ``_props`` (because of their weird
    ``outputs_types_preference`` function), so instead of the usual
    props-based hash we hash the Op's type together with any extra
    keyword fields (e.g. ``cache_version``) supplied by the caller.
    """
    key_source = str((type(op), tuple(extra_fields.items())))
    return sha256(key_source.encode()).hexdigest()
@register_funcify_and_cache_key(ScalarOp) @register_funcify_and_cache_key(ScalarOp)
...@@ -267,11 +267,28 @@ def numba_funcify_Reciprocal(op, node, **kwargs): ...@@ -267,11 +267,28 @@ def numba_funcify_Reciprocal(op, node, **kwargs):
@register_funcify_and_cache_key(Sigmoid)
def numba_funcify_Sigmoid(op, node, **kwargs):
    """Return a Numba implementation of ``Sigmoid`` and its cache key.

    Unsigned integer inputs cannot be negated, so they are first cast to a
    floating dtype wide enough to hold them before computing
    ``1 / (1 + exp(-x))``.
    """
    input_dtype = node.inputs[0].type.dtype

    if input_dtype.startswith("uint"):
        # numpy would pick float16 for uint8, but Numba has no float16,
        # so uint8/uint16 go to float32 and uint32/uint64 to float64.
        if input_dtype in ("uint8", "uint16"):
            upcast_uint_dtype = np.float32
        else:
            upcast_uint_dtype = np.float64

        @numba_basic.numba_njit
        def sigmoid(x):
            # Can't negate uint
            float_x = numba_basic.direct_cast(x, upcast_uint_dtype)
            return 1 / (1 + np.exp(-float_x))

    else:

        @numba_basic.numba_njit
        def sigmoid(x):
            return 1 / (1 + np.exp(-x))

    return sigmoid, scalar_op_cache_key(op, cache_version=1)
@register_funcify_and_cache_key(GammaLn) @register_funcify_and_cache_key(GammaLn)
...@@ -319,6 +336,16 @@ def numba_funcify_Erfc(op, **kwargs): ...@@ -319,6 +336,16 @@ def numba_funcify_Erfc(op, **kwargs):
@register_funcify_and_cache_key(Softplus) @register_funcify_and_cache_key(Softplus)
def numba_funcify_Softplus(op, node, **kwargs): def numba_funcify_Softplus(op, node, **kwargs):
inp_dtype = node.inputs[0].type.dtype
if inp_dtype.startswith("uint"):
upcast_uint_dtype = {
"uint8": np.float32, # numpy uses float16, but not Numba
"uint16": np.float32,
"uint32": np.float64,
"uint64": np.float64,
}[inp_dtype]
else:
upcast_uint_dtype = None
out_dtype = np.dtype(node.outputs[0].type.dtype) out_dtype = np.dtype(node.outputs[0].type.dtype)
@numba_basic.numba_njit @numba_basic.numba_njit
...@@ -328,9 +355,12 @@ def numba_funcify_Softplus(op, node, **kwargs): ...@@ -328,9 +355,12 @@ def numba_funcify_Softplus(op, node, **kwargs):
elif x < 18.0: elif x < 18.0:
value = np.log1p(np.exp(x)) value = np.log1p(np.exp(x))
elif x < 33.3: elif x < 33.3:
if upcast_uint_dtype is not None:
# Can't negate uint
x = numba_basic.direct_cast(x, upcast_uint_dtype)
value = x + np.exp(-x) value = x + np.exp(-x)
else: else:
value = x value = x
return numba_basic.direct_cast(value, out_dtype) return numba_basic.direct_cast(value, out_dtype)
return softplus, scalar_op_cache_key(op) return softplus, scalar_op_cache_key(op, cache_version=1)
...@@ -1259,7 +1259,8 @@ class Softplus(UnaryScalarOp): ...@@ -1259,7 +1259,8 @@ class Softplus(UnaryScalarOp):
def impl(self, x): def impl(self, x):
# If x is an int8 or uint8, numpy.exp will compute the result in # If x is an int8 or uint8, numpy.exp will compute the result in
# half-precision (float16), where we want float32. # half-precision (float16), where we want float32.
not_int8 = str(getattr(x, "dtype", "")) not in ("int8", "uint8") x_dtype = getattr(x, "dtype", None)
not_int8 = x_dtype is None or x_dtype.itemsize > 1
if x < -37.0: if x < -37.0:
return np.exp(x) if not_int8 else np.exp(x, signature="f") return np.exp(x) if not_int8 else np.exp(x, signature="f")
elif x < 18.0: elif x < 18.0:
...@@ -1267,6 +1268,9 @@ class Softplus(UnaryScalarOp): ...@@ -1267,6 +1268,9 @@ class Softplus(UnaryScalarOp):
np.log1p(np.exp(x)) if not_int8 else np.log1p(np.exp(x, signature="f")) np.log1p(np.exp(x)) if not_int8 else np.log1p(np.exp(x, signature="f"))
) )
elif x < 33.3: elif x < 33.3:
if x_dtype is not None and x_dtype.kind == "u":
# Negate uint will not do what we want
x = x.astype("float32" if x_dtype.itemsize <= 2 else "float64")
return x + np.exp(-x) if not_int8 else x + np.exp(-x, signature="f") return x + np.exp(-x) if not_int8 else x + np.exp(-x, signature="f")
else: else:
return x return x
......
...@@ -158,15 +158,9 @@ def test_isnan(composite): ...@@ -158,15 +158,9 @@ def test_isnan(composite):
@pytest.mark.parametrize( @pytest.mark.parametrize(
"dtype", "dtype",
[ [
pytest.param(
"float32", "float32",
marks=pytest.mark.xfail(reason="Scalar downcasting not supported in numba"),
),
"float64", "float64",
pytest.param(
"int16", "int16",
marks=pytest.mark.xfail(reason="Scalar downcasting not supported in numba"),
),
"int64", "int64",
"uint32", "uint32",
], ],
...@@ -183,7 +177,7 @@ def test_Softplus(dtype): ...@@ -183,7 +177,7 @@ def test_Softplus(dtype):
test_x = np.dtype(dtype).type(value) test_x = np.dtype(dtype).type(value)
np.testing.assert_allclose( np.testing.assert_allclose(
py_fn(test_x), py_fn(test_x),
numba_fn(test_x), getattr(np, g.dtype)(numba_fn(test_x)),
strict=True, strict=True,
err_msg=f"Failed for value {value}", err_msg=f"Failed for value {value}",
) )
......
Markdown 格式
0%
您将添加 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论