提交 9ada9455 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Provide JAX Ops from Optional tfp dependency

上级 8ac8342d
......@@ -145,7 +145,7 @@ jobs:
# PyTensor next, pip installs a lower version of numpy via the PyPI.
if [[ $INSTALL_NUMBA == "1" ]] && [[ $PYTHON_VERSION == "3.9" ]]; then mamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}=*_cpython" "numpy<1.23" "numba>=0.57" numba-scipy; fi
if [[ $INSTALL_NUMBA == "1" ]] && [[ $PYTHON_VERSION != "3.9" ]]; then mamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}=*_cpython" "numba>=0.57" numba-scipy; fi
if [[ $INSTALL_JAX == "1" ]]; then mamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}=*_cpython" jax jaxlib numpyro; fi
if [[ $INSTALL_JAX == "1" ]]; then mamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}=*_cpython" jax jaxlib numpyro && pip install tensorflow-probability; fi
pip install -e ./
mamba list && pip freeze
python -c 'import pytensor; print(pytensor.config.__str__(print_doc=False))'
......
import functools
import typing
from typing import Callable, Optional
import jax
import jax.numpy as jnp
......@@ -18,7 +20,21 @@ from pytensor.scalar.basic import (
Second,
Sub,
)
from pytensor.scalar.math import Erf, Erfc, Erfinv, Log1mexp, Psi
from pytensor.scalar.math import Erf, Erfc, Erfcinv, Erfcx, Erfinv, Iv, Log1mexp, Psi
def try_import_tfp_jax_op(op: ScalarOp, jax_op_name: Optional[str] = None) -> Callable:
try:
import tensorflow_probability.substrates.jax.math as tfp_jax_math
except ModuleNotFoundError:
raise NotImplementedError(
f"No JAX implementation for Op {op.name}. "
"Implementation is available if TensorFlow Probability is installed"
)
if jax_op_name is None:
jax_op_name = op.name
return typing.cast(Callable, getattr(tfp_jax_math, jax_op_name))
def check_if_inputs_scalars(node):
......@@ -211,6 +227,24 @@ def jax_funcify_Erfinv(op, **kwargs):
return erfinv
@jax_funcify.register(Erfcx)
@jax_funcify.register(Erfcinv)
def jax_funcify_from_tfp(op, **kwargs):
tfp_jax_op = try_import_tfp_jax_op(op)
return tfp_jax_op
@jax_funcify.register(Iv)
def jax_funcify_Iv(op, **kwargs):
ive = try_import_tfp_jax_op(op, jax_op_name="bessel_ive")
def iv(v, x):
return ive(v, x) / jnp.exp(-jnp.abs(jnp.real(x)))
return iv
@jax_funcify.register(Log1mexp)
def jax_funcify_Log1mexp(op, node, **kwargs):
def log1mexp(x):
......
......@@ -7,13 +7,17 @@ from pytensor.configdefaults import config
from pytensor.graph.fg import FunctionGraph
from pytensor.graph.op import get_test_value
from pytensor.scalar.basic import Composite
from pytensor.tensor import as_tensor
from pytensor.tensor.elemwise import Elemwise
from pytensor.tensor.math import all as at_all
from pytensor.tensor.math import (
cosh,
erf,
erfc,
erfcinv,
erfcx,
erfinv,
iv,
log,
log1mexp,
psi,
......@@ -28,6 +32,14 @@ jax = pytest.importorskip("jax")
from pytensor.link.jax.dispatch import jax_funcify
try:
pass
TFP_INSTALLED = True
except ModuleNotFoundError:
TFP_INSTALLED = False
def test_second():
a0 = scalar("a0")
b = scalar("b")
......@@ -134,6 +146,23 @@ def test_erfinv():
compare_jax_and_py(fg, [0.95])
@pytest.mark.parametrize(
"op, test_values",
[
(erfcx, (0.7,)),
(erfcinv, (0.7,)),
(iv, (0.3, 0.7)),
],
)
@pytest.mark.skipif(not TFP_INSTALLED, reason="Test requires tensorflow-probability")
def test_tfp_ops(op, test_values):
inputs = [as_tensor(test_value).type() for test_value in test_values]
output = op(*inputs)
fg = FunctionGraph(inputs, [output])
compare_jax_and_py(fg, test_values)
def test_psi():
x = scalar("x")
out = psi(x)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论