提交 3d96ee80 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Fix storage handling in numba_funcify_Scan

上级 a2d05adc
import numpy as np import numpy as np
import pytest
import aesara.tensor as at import aesara.tensor as at
from aesara import config from aesara import config, grad
from aesara.compile.mode import Mode, get_mode
from aesara.graph.fg import FunctionGraph from aesara.graph.fg import FunctionGraph
from aesara.scan.basic import scan from aesara.scan.basic import scan
from aesara.scan.utils import until from aesara.scan.utils import until
from tests import unittest_tools as utt
from tests.link.numba.test_basic import compare_numba_and_py from tests.link.numba.test_basic import compare_numba_and_py
rng = np.random.default_rng(42849)
def test_scan_multiple_output(): def test_scan_multiple_output():
"""Test a scan implementation of a SEIR model. """Test a scan implementation of a SEIR model.
SEIR model definition: SEIR model definition:
S[t+1] = S[t] - B[t]
E[t+1] = E[t] +B[t] - C[t]
I[t+1] = I[t+1] + C[t] - D[t]
B[t] ~ Binom(S[t], beta) S[t+1] = S[t] - B[t]
C[t] ~ Binom(E[t], gamma) E[t+1] = E[t] + B[t] - C[t]
D[t] ~ Binom(I[t], delta) I[t+1] = I[t+1] + C[t] - D[t]
B[t] ~ Binom(S[t], beta)
C[t] ~ Binom(E[t], gamma)
D[t] ~ Binom(I[t], delta)
""" """
def binomln(n, k): def binomln(n, k):
...@@ -198,3 +200,99 @@ def test_scan_multiple_none_output(): ...@@ -198,3 +200,99 @@ def test_scan_multiple_none_output():
test_input_vals = (np.array([1.0, 2.0]),) test_input_vals = (np.array([1.0, 2.0]),)
compare_numba_and_py(out_fg, test_input_vals) compare_numba_and_py(out_fg, test_input_vals)
def test_scan_save_mem_basic():
"""Make sure we can handle storage changes caused by the `scan_save_mem` rewrite."""
k = at.iscalar("k")
A = at.dvector("A")
result, _ = scan(
fn=lambda prior_result, A: prior_result * A,
outputs_info=at.ones_like(A),
non_sequences=A,
n_steps=k,
)
numba_mode = get_mode("NUMBA") # .including("scan_save_mem")
py_mode = Mode("py").including("scan_save_mem")
out_fg = FunctionGraph([A, k], [result])
test_input_vals = (np.arange(10, dtype=np.int32), 2)
compare_numba_and_py(
out_fg, test_input_vals, numba_mode=numba_mode, py_mode=py_mode
)
test_input_vals = (np.arange(10, dtype=np.int32), 4)
compare_numba_and_py(
out_fg, test_input_vals, numba_mode=numba_mode, py_mode=py_mode
)
@pytest.mark.parametrize("n_steps_val", [1, 5])
def test_scan_save_mem_2(n_steps_val):
def f_pow2(x_tm2, x_tm1):
return 2 * x_tm1 + x_tm2
init_x = at.dvector("init_x")
n_steps = at.iscalar("n_steps")
output, _ = scan(
f_pow2,
sequences=[],
outputs_info=[{"initial": init_x, "taps": [-2, -1]}],
non_sequences=[],
n_steps=n_steps,
)
state_val = np.array([1.0, 2.0])
numba_mode = get_mode("NUMBA") # .including("scan_save_mem")
py_mode = Mode("py").including("scan_save_mem")
out_fg = FunctionGraph([init_x, n_steps], [output])
test_input_vals = (state_val, n_steps_val)
compare_numba_and_py(
out_fg, test_input_vals, numba_mode=numba_mode, py_mode=py_mode
)
def test_grad_sitsot():
def get_sum_of_grad(inp):
scan_outputs, updates = scan(
fn=lambda x: x * 2, outputs_info=[inp], n_steps=5, mode="NUMBA"
)
return grad(scan_outputs.sum(), inp).sum()
floatX = config.floatX
inputs_test_values = [
np.random.default_rng(utt.fetch_seed()).random(3).astype(floatX)
]
utt.verify_grad(get_sum_of_grad, inputs_test_values, mode="NUMBA")
def test_mitmots_basic():
init_x = at.dvector()
seq = at.dvector()
def inner_fct(seq, state_old, state_current):
return state_old * 2 + state_current + seq
out, _ = scan(
inner_fct, sequences=seq, outputs_info={"initial": init_x, "taps": [-2, -1]}
)
g_outs = grad(out.sum(), [seq, init_x])
numba_mode = get_mode("NUMBA").including("scan_save_mem")
py_mode = Mode("py").including("scan_save_mem")
out_fg = FunctionGraph([seq, init_x], g_outs)
seq_val = np.arange(3)
init_x_val = np.r_[-2, -1]
test_input_vals = (seq_val, init_x_val)
compare_numba_and_py(
out_fg, test_input_vals, numba_mode=numba_mode, py_mode=py_mode
)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论