Commit 594f46b4, authored by Ricardo Vieira, committed by Jesse Grabowski

Cast to output, not input in numba dispatch of scalar Softplus

Parent: bfcad6d1
......@@ -31,7 +31,6 @@ from pytensor.link.utils import (
fgraph_to_python,
)
from pytensor.scalar.basic import ScalarType
from pytensor.scalar.math import Softplus
from pytensor.sparse import SparseTensorType
from pytensor.tensor.basic import Nonzero
from pytensor.tensor.blas import BatchedDot
......@@ -607,25 +606,6 @@ def numba_funcify_Dot(op, node, **kwargs):
return dot
@numba_funcify.register(Softplus)
def numba_funcify_Softplus(op, node, **kwargs):
    """Generate a numba implementation of the scalar ``Softplus`` op.

    ``softplus(x) = log1p(exp(x))``, computed piecewise for numerical
    stability (same thresholds as PyTensor's scalar implementation).

    The result is cast to the *output* dtype, not the input dtype: the op
    may be instantiated with ``upgrade_to_float``, in which case an integer
    input produces a float output, and casting to the input dtype would
    truncate the result.
    """
    out_dtype = np.dtype(node.outputs[0].type.dtype)

    @numba_njit
    def softplus(x):
        if x < -37.0:
            # log1p(exp(x)) ~= exp(x) for very negative x
            value = np.exp(x)
        elif x < 18.0:
            value = np.log1p(np.exp(x))
        elif x < 33.3:
            # log1p(exp(x)) = x + log1p(exp(-x)) ~= x + exp(-x)
            value = x + np.exp(-x)
        else:
            # softplus(x) ~= x for large x
            value = x
        return direct_cast(value, out_dtype)

    return softplus
@numba_funcify.register(Solve)
def numba_funcify_Solve(op, node, **kwargs):
assume_a = op.assume_a
......@@ -689,11 +669,6 @@ def numba_funcify_BatchedDot(op, node, **kwargs):
return batched_dot
# NOTE: The remaining `pytensor.tensor.blas` `Op`s appear unnecessary, because
# they're only used to optimize basic `Dot` nodes, and those GEMV and GEMM
# optimizations are apparently already performed by Numba
@numba_funcify.register(IfElse)
def numba_funcify_IfElse(op, **kwargs):
n_outs = op.n_outs
......
......@@ -28,7 +28,7 @@ from pytensor.scalar.basic import (
Second,
Switch,
)
from pytensor.scalar.math import Erf, Erfc, GammaLn, Log1mexp, Sigmoid
from pytensor.scalar.math import Erf, Erfc, GammaLn, Log1mexp, Sigmoid, Softplus
@numba_funcify.register(ScalarOp)
......@@ -312,3 +312,22 @@ def erfc(x):
@numba_funcify.register(Erfc)
def numba_funcify_Erfc(op, **kwargs):
    # Erfc has no per-node configuration, so dispatch straight to the
    # shared global numba implementation of ``erfc``.
    erfc_impl = numba_basic.global_numba_func(erfc)
    return erfc_impl
@numba_funcify.register(Softplus)
def numba_funcify_Softplus(op, node, **kwargs):
    """Generate a numba implementation of the scalar ``Softplus`` op.

    ``softplus(x) = log1p(exp(x))``, evaluated piecewise for numerical
    stability. Every branch casts to the output dtype (which may differ
    from the input dtype, e.g. with ``upgrade_to_float``), so numba sees
    a single consistent return type.
    """
    out_dtype = np.dtype(node.outputs[0].type.dtype)

    @numba_basic.numba_njit
    def softplus(x):
        # log1p(exp(x)) ~= exp(x) for very negative x
        if x < -37.0:
            return numba_basic.direct_cast(np.exp(x), out_dtype)
        # Safe region for the direct formula
        if x < 18.0:
            return numba_basic.direct_cast(np.log1p(np.exp(x)), out_dtype)
        # log1p(exp(x)) = x + log1p(exp(-x)) ~= x + exp(-x)
        if x < 33.3:
            return numba_basic.direct_cast(x + np.exp(-x), out_dtype)
        # softplus(x) ~= x for large x
        return numba_basic.direct_cast(x, out_dtype)

    return softplus
......@@ -14,7 +14,6 @@ from tests.tensor.test_math_scipy import scipy
numba = pytest.importorskip("numba")
import pytensor.scalar as ps
import pytensor.scalar.math as psm
import pytensor.tensor as pt
import pytensor.tensor.math as ptm
from pytensor import config, shared
......@@ -643,48 +642,6 @@ def test_Dot(x, y, exc):
)
@pytest.mark.parametrize(
    "x, exc",
    # All float64 probe values hit a different piecewise region of softplus;
    # the int64 case exercises the upgrade-to-float output dtype.
    [
        ((ps.float64(), np.array(v, dtype="float64")), None)
        for v in (0.0, -32.0, -40.0, 32.0, 40.0)
    ]
    + [((ps.int64(), np.array(32, dtype="int64")), None)],
)
def test_Softplus(x, exc):
    # ``x`` is a (symbolic scalar, test value) pair; ``exc`` is the warning
    # class expected while compiling/running (None means no warning).
    x, x_test_value = x
    g = psm.Softplus(ps.upgrade_to_float)(x)

    cm = contextlib.suppress() if exc is None else pytest.warns(exc)
    with cm:
        compare_numba_and_py(
            [x],
            [g],
            [x_test_value],
        )
@pytest.mark.parametrize(
"x, y, exc",
[
......
......@@ -3,12 +3,13 @@ import pytest
import pytensor.scalar as ps
import pytensor.scalar.basic as psb
import pytensor.scalar.math as psm
import pytensor.tensor as pt
from pytensor import config
from pytensor import config, function
from pytensor.scalar.basic import Composite
from pytensor.tensor import tensor
from pytensor.tensor.elemwise import Elemwise
from tests.link.numba.test_basic import compare_numba_and_py
from tests.link.numba.test_basic import compare_numba_and_py, numba_mode, py_mode
rng = np.random.default_rng(42849)
......@@ -149,3 +150,37 @@ def test_isnan(composite):
[out],
[np.array([1, 0], dtype="float64")],
)
@pytest.mark.parametrize(
    "dtype",
    [
        pytest.param(
            "float32",
            marks=pytest.mark.xfail(reason="Scalar downcasting not supported in numba"),
        ),
        "float64",
        pytest.param(
            "int16",
            marks=pytest.mark.xfail(reason="Scalar downcasting not supported in numba"),
        ),
        "int64",
        "uint32",
    ],
)
def test_Softplus(dtype):
    # Compare the numba backend against the pure-python backend for a scalar
    # softplus graph, probing every piecewise-approximation region.
    x = ps.get_scalar_type(dtype)("x")
    g = psm.softplus(x)
    py_fn = function([x], g, mode=py_mode)
    numba_fn = function([x], g, mode=numba_mode)

    is_unsigned = dtype.startswith("u")
    for value in (-40, -32, 0, 32, 40):
        # Negative probes are not representable in unsigned dtypes.
        if is_unsigned and value < 0:
            continue
        test_x = np.dtype(dtype).type(value)
        np.testing.assert_allclose(
            py_fn(test_x),
            numba_fn(test_x),
            strict=True,
            err_msg=f"Failed for value {value}",
        )
Markdown is supported.
0%
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Register or sign in to comment.