提交 88cc33b6 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Fix Scan JAX dispatcher

上级 7b609047
import re
import numpy as np
import pytest
from packaging.version import parse as version_parse
import pytensor.tensor as at
from pytensor import function, shared
from pytensor.compile import get_mode
from pytensor.configdefaults import config
from pytensor.graph.fg import FunctionGraph
from pytensor.scan import until
from pytensor.scan.basic import scan
from pytensor.scan.op import Scan
from pytensor.tensor import random
from pytensor.tensor.math import gammaln, log
from pytensor.tensor.type import ivector, lscalar, scalar
from pytensor.tensor.type import lscalar, scalar, vector
from tests.link.jax.test_basic import compare_jax_and_py
jax = pytest.importorskip("jax")
@pytest.mark.xfail(
version_parse(jax.__version__) >= version_parse("0.2.12"),
reason="Omnistaging cannot be disabled",
)
def test_jax_scan_multiple_output():
@pytest.mark.parametrize("view", [None, (-1,), slice(-2, None, None)])
def test_scan_sit_sot(view):
x0 = at.scalar("x0", dtype="float64")
xs, _ = scan(
lambda xtm1: xtm1 + 1,
outputs_info=[x0],
n_steps=10,
)
if view:
xs = xs[view]
fg = FunctionGraph([x0], [xs])
test_input_vals = [np.e]
compare_jax_and_py(fg, test_input_vals)
@pytest.mark.parametrize("view", [None, (-1,), slice(-4, -1, None)])
def test_scan_mit_sot(view):
x0 = at.vector("x0", dtype="float64", shape=(3,))
xs, _ = scan(
lambda xtm3, xtm1: xtm3 + xtm1 + 1,
outputs_info=[{"initial": x0, "taps": [-3, -1]}],
n_steps=10,
)
if view:
xs = xs[view]
fg = FunctionGraph([x0], [xs])
test_input_vals = [np.full((3,), np.e)]
compare_jax_and_py(fg, test_input_vals)
@pytest.mark.parametrize("view_x", [None, (-1,), slice(-4, -1, None)])
@pytest.mark.parametrize("view_y", [None, (-1,), slice(-4, -1, None)])
def test_scan_multiple_mit_sot(view_x, view_y):
x0 = at.vector("x0", dtype="float64", shape=(3,))
y0 = at.vector("y0", dtype="float64", shape=(4,))
def step(xtm3, xtm1, ytm4, ytm2):
return xtm3 + ytm4 + 1, xtm1 + ytm2 + 2
[xs, ys], _ = scan(
fn=step,
outputs_info=[
{"initial": x0, "taps": [-3, -1]},
{"initial": y0, "taps": [-4, -2]},
],
n_steps=10,
)
if view_x:
xs = xs[view_x]
if view_y:
ys = ys[view_y]
fg = FunctionGraph([x0, y0], [xs, ys])
test_input_vals = [np.full((3,), np.e), np.full((4,), np.pi)]
compare_jax_and_py(fg, test_input_vals)
@pytest.mark.parametrize("view", [None, (-2,), slice(None, None, 2)])
def test_scan_nit_sot(view):
rng = np.random.default_rng(seed=49)
xs = at.vector("x0", dtype="float64", shape=(10,))
ys, _ = scan(
lambda x: at.exp(x),
outputs_info=[None],
sequences=[xs],
)
if view:
ys = ys[view]
fg = FunctionGraph([xs], [ys])
test_input_vals = [rng.normal(size=10)]
# We need to remove pushout rewrites, or the whole scan would just be
# converted to an Elemwise on xs
jax_fn, _ = compare_jax_and_py(
fg, test_input_vals, jax_mode=get_mode("JAX").excluding("scan_pushout")
)
scan_nodes = [
node for node in jax_fn.maker.fgraph.apply_nodes if isinstance(node.op, Scan)
]
assert len(scan_nodes) == 1
@pytest.mark.xfail(raises=NotImplementedError)
def test_scan_mit_mot():
xs = at.vector("xs", shape=(10,))
ys, _ = scan(
lambda xtm2, xtm1: (xtm2 + xtm1),
outputs_info=[{"initial": xs, "taps": [-2, -1]}],
n_steps=10,
)
grads_wrt_xs = at.grad(ys.sum(), wrt=xs)
fg = FunctionGraph([xs], [grads_wrt_xs])
compare_jax_and_py(fg, [np.arange(10)])
def test_scan_update():
sh_static = shared(np.array(0.0), name="sh_static")
sh_update = shared(np.array(1.0), name="sh_update")
xs, update = scan(
lambda sh_static, sh_update: (
sh_static + sh_update,
{sh_update: sh_update * 2},
),
outputs_info=[None],
non_sequences=[sh_static, sh_update],
strict=True,
n_steps=7,
)
jax_fn = function([], xs, updates=update, mode="JAX")
np.testing.assert_array_equal(jax_fn(), np.array([1, 2, 4, 8, 16, 32, 64]) + 0.0)
sh_static.set_value(1.0)
np.testing.assert_array_equal(
jax_fn(), np.array([128, 256, 512, 1024, 2048, 4096, 8192]) + 1.0
)
sh_static.set_value(2.0)
sh_update.set_value(1.0)
np.testing.assert_array_equal(jax_fn(), np.array([1, 2, 4, 8, 16, 32, 64]) + 2.0)
def test_scan_rng_update():
rng = shared(np.random.default_rng(190), name="rng")
def update_fn(rng):
new_rng, x = random.normal(rng=rng).owner.outputs
return x, {rng: new_rng}
xs, update = scan(
update_fn,
outputs_info=[None],
non_sequences=[rng],
strict=True,
n_steps=10,
)
# Without updates
with pytest.warns(
UserWarning,
match=re.escape("[rng] will not be used in the compiled JAX graph"),
):
jax_fn = function([], [xs], updates=None, mode="JAX")
res1, res2 = jax_fn(), jax_fn()
assert np.unique(res1).size == 10
assert np.unique(res2).size == 10
np.testing.assert_array_equal(res1, res2)
# With updates
with pytest.warns(
UserWarning,
match=re.escape("[rng] will not be used in the compiled JAX graph"),
):
jax_fn = function([], [xs], updates=update, mode="JAX")
res1, res2 = jax_fn(), jax_fn()
assert np.unique(res1).size == 10
assert np.unique(res2).size == 10
assert np.all(np.not_equal(res1, res2))
@pytest.mark.xfail(raises=NotImplementedError)
def test_scan_while():
xs, _ = scan(
lambda x: (x + 1, until(x < 10)),
outputs_info=[at.zeros(())],
n_steps=100,
)
fg = FunctionGraph([], [xs])
compare_jax_and_py(fg, [])
def test_scan_SEIR():
"""Test a scan implementation of a SEIR model.
SEIR model definition:
......@@ -38,8 +216,8 @@ def test_jax_scan_multiple_output():
return binomln(n, value) + value * log(p) + (n - value) * log(1 - p)
# sequences
at_C = ivector("C_t")
at_D = ivector("D_t")
at_C = vector("C_t", dtype="int32", shape=(8,))
at_D = vector("D_t", dtype="int32", shape=(8,))
# outputs_info (initial conditions)
st0 = lscalar("s_t0")
et0 = lscalar("e_t0")
......@@ -108,11 +286,7 @@ def test_jax_scan_multiple_output():
compare_jax_and_py(out_fg, test_input_vals)
@pytest.mark.xfail(
version_parse(jax.__version__) >= version_parse("0.2.12"),
reason="Omnistaging cannot be disabled",
)
def test_jax_scan_tap_output():
def test_scan_mitsot_with_nonseq():
a_at = scalar("a")
def input_step_fn(y_tm1, y_tm3, a):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论