提交 594f46b4 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Jesse Grabowski

Cast to output, not input in numba dispatch of scalar Softplus

上级 bfcad6d1
...@@ -31,7 +31,6 @@ from pytensor.link.utils import ( ...@@ -31,7 +31,6 @@ from pytensor.link.utils import (
fgraph_to_python, fgraph_to_python,
) )
from pytensor.scalar.basic import ScalarType from pytensor.scalar.basic import ScalarType
from pytensor.scalar.math import Softplus
from pytensor.sparse import SparseTensorType from pytensor.sparse import SparseTensorType
from pytensor.tensor.basic import Nonzero from pytensor.tensor.basic import Nonzero
from pytensor.tensor.blas import BatchedDot from pytensor.tensor.blas import BatchedDot
...@@ -607,25 +606,6 @@ def numba_funcify_Dot(op, node, **kwargs): ...@@ -607,25 +606,6 @@ def numba_funcify_Dot(op, node, **kwargs):
return dot return dot
@numba_funcify.register(Softplus)
def numba_funcify_Softplus(op, node, **kwargs):
x_dtype = np.dtype(node.inputs[0].dtype)
@numba_njit
def softplus(x):
if x < -37.0:
value = np.exp(x)
elif x < 18.0:
value = np.log1p(np.exp(x))
elif x < 33.3:
value = x + np.exp(-x)
else:
value = x
return direct_cast(value, x_dtype)
return softplus
@numba_funcify.register(Solve) @numba_funcify.register(Solve)
def numba_funcify_Solve(op, node, **kwargs): def numba_funcify_Solve(op, node, **kwargs):
assume_a = op.assume_a assume_a = op.assume_a
...@@ -689,11 +669,6 @@ def numba_funcify_BatchedDot(op, node, **kwargs): ...@@ -689,11 +669,6 @@ def numba_funcify_BatchedDot(op, node, **kwargs):
return batched_dot return batched_dot
# NOTE: The remaining `pytensor.tensor.blas` `Op`s appear unnecessary, because
# they're only used to optimize basic `Dot` nodes, and those GEMV and GEMM
# optimizations are apparently already performed by Numba
@numba_funcify.register(IfElse) @numba_funcify.register(IfElse)
def numba_funcify_IfElse(op, **kwargs): def numba_funcify_IfElse(op, **kwargs):
n_outs = op.n_outs n_outs = op.n_outs
......
...@@ -28,7 +28,7 @@ from pytensor.scalar.basic import ( ...@@ -28,7 +28,7 @@ from pytensor.scalar.basic import (
Second, Second,
Switch, Switch,
) )
from pytensor.scalar.math import Erf, Erfc, GammaLn, Log1mexp, Sigmoid from pytensor.scalar.math import Erf, Erfc, GammaLn, Log1mexp, Sigmoid, Softplus
@numba_funcify.register(ScalarOp) @numba_funcify.register(ScalarOp)
...@@ -312,3 +312,22 @@ def erfc(x): ...@@ -312,3 +312,22 @@ def erfc(x):
@numba_funcify.register(Erfc) @numba_funcify.register(Erfc)
def numba_funcify_Erfc(op, **kwargs): def numba_funcify_Erfc(op, **kwargs):
return numba_basic.global_numba_func(erfc) return numba_basic.global_numba_func(erfc)
@numba_funcify.register(Softplus)
def numba_funcify_Softplus(op, node, **kwargs):
out_dtype = np.dtype(node.outputs[0].type.dtype)
@numba_basic.numba_njit
def softplus(x):
if x < -37.0:
value = np.exp(x)
elif x < 18.0:
value = np.log1p(np.exp(x))
elif x < 33.3:
value = x + np.exp(-x)
else:
value = x
return numba_basic.direct_cast(value, out_dtype)
return softplus
...@@ -14,7 +14,6 @@ from tests.tensor.test_math_scipy import scipy ...@@ -14,7 +14,6 @@ from tests.tensor.test_math_scipy import scipy
numba = pytest.importorskip("numba") numba = pytest.importorskip("numba")
import pytensor.scalar as ps import pytensor.scalar as ps
import pytensor.scalar.math as psm
import pytensor.tensor as pt import pytensor.tensor as pt
import pytensor.tensor.math as ptm import pytensor.tensor.math as ptm
from pytensor import config, shared from pytensor import config, shared
...@@ -643,48 +642,6 @@ def test_Dot(x, y, exc): ...@@ -643,48 +642,6 @@ def test_Dot(x, y, exc):
) )
@pytest.mark.parametrize(
"x, exc",
[
(
(ps.float64(), np.array(0.0, dtype="float64")),
None,
),
(
(ps.float64(), np.array(-32.0, dtype="float64")),
None,
),
(
(ps.float64(), np.array(-40.0, dtype="float64")),
None,
),
(
(ps.float64(), np.array(32.0, dtype="float64")),
None,
),
(
(ps.float64(), np.array(40.0, dtype="float64")),
None,
),
(
(ps.int64(), np.array(32, dtype="int64")),
None,
),
],
)
def test_Softplus(x, exc):
x, x_test_value = x
g = psm.Softplus(ps.upgrade_to_float)(x)
cm = contextlib.suppress() if exc is None else pytest.warns(exc)
with cm:
compare_numba_and_py(
[x],
[g],
[x_test_value],
)
@pytest.mark.parametrize( @pytest.mark.parametrize(
"x, y, exc", "x, y, exc",
[ [
......
...@@ -3,12 +3,13 @@ import pytest ...@@ -3,12 +3,13 @@ import pytest
import pytensor.scalar as ps import pytensor.scalar as ps
import pytensor.scalar.basic as psb import pytensor.scalar.basic as psb
import pytensor.scalar.math as psm
import pytensor.tensor as pt import pytensor.tensor as pt
from pytensor import config from pytensor import config, function
from pytensor.scalar.basic import Composite from pytensor.scalar.basic import Composite
from pytensor.tensor import tensor from pytensor.tensor import tensor
from pytensor.tensor.elemwise import Elemwise from pytensor.tensor.elemwise import Elemwise
from tests.link.numba.test_basic import compare_numba_and_py from tests.link.numba.test_basic import compare_numba_and_py, numba_mode, py_mode
rng = np.random.default_rng(42849) rng = np.random.default_rng(42849)
...@@ -149,3 +150,37 @@ def test_isnan(composite): ...@@ -149,3 +150,37 @@ def test_isnan(composite):
[out], [out],
[np.array([1, 0], dtype="float64")], [np.array([1, 0], dtype="float64")],
) )
@pytest.mark.parametrize(
"dtype",
[
pytest.param(
"float32",
marks=pytest.mark.xfail(reason="Scalar downcasting not supported in numba"),
),
"float64",
pytest.param(
"int16",
marks=pytest.mark.xfail(reason="Scalar downcasting not supported in numba"),
),
"int64",
"uint32",
],
)
def test_Softplus(dtype):
x = ps.get_scalar_type(dtype)("x")
g = psm.softplus(x)
py_fn = function([x], g, mode=py_mode)
numba_fn = function([x], g, mode=numba_mode)
for value in (-40, -32, 0, 32, 40):
if value < 0 and dtype.startswith("u"):
continue
test_x = np.dtype(dtype).type(value)
np.testing.assert_allclose(
py_fn(test_x),
numba_fn(test_x),
strict=True,
err_msg=f"Failed for value {value}",
)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论