提交 3d96ee80 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Fix storage handling in numba_funcify_Scan

上级 a2d05adc
import numpy as np
import pytest
import aesara.tensor as at
from aesara import config
from aesara import config, grad
from aesara.compile.mode import Mode, get_mode
from aesara.graph.fg import FunctionGraph
from aesara.scan.basic import scan
from aesara.scan.utils import until
from tests import unittest_tools as utt
from tests.link.numba.test_basic import compare_numba_and_py
rng = np.random.default_rng(42849)
def test_scan_multiple_output():
"""Test a scan implementation of a SEIR model.
SEIR model definition:
S[t+1] = S[t] - B[t]
E[t+1] = E[t] +B[t] - C[t]
E[t+1] = E[t] + B[t] - C[t]
I[t+1] = I[t+1] + C[t] - D[t]
B[t] ~ Binom(S[t], beta)
C[t] ~ Binom(E[t], gamma)
D[t] ~ Binom(I[t], delta)
"""
def binomln(n, k):
......@@ -198,3 +200,99 @@ def test_scan_multiple_none_output():
test_input_vals = (np.array([1.0, 2.0]),)
compare_numba_and_py(out_fg, test_input_vals)
def test_scan_save_mem_basic():
"""Make sure we can handle storage changes caused by the `scan_save_mem` rewrite."""
k = at.iscalar("k")
A = at.dvector("A")
result, _ = scan(
fn=lambda prior_result, A: prior_result * A,
outputs_info=at.ones_like(A),
non_sequences=A,
n_steps=k,
)
numba_mode = get_mode("NUMBA") # .including("scan_save_mem")
py_mode = Mode("py").including("scan_save_mem")
out_fg = FunctionGraph([A, k], [result])
test_input_vals = (np.arange(10, dtype=np.int32), 2)
compare_numba_and_py(
out_fg, test_input_vals, numba_mode=numba_mode, py_mode=py_mode
)
test_input_vals = (np.arange(10, dtype=np.int32), 4)
compare_numba_and_py(
out_fg, test_input_vals, numba_mode=numba_mode, py_mode=py_mode
)
@pytest.mark.parametrize("n_steps_val", [1, 5])
def test_scan_save_mem_2(n_steps_val):
def f_pow2(x_tm2, x_tm1):
return 2 * x_tm1 + x_tm2
init_x = at.dvector("init_x")
n_steps = at.iscalar("n_steps")
output, _ = scan(
f_pow2,
sequences=[],
outputs_info=[{"initial": init_x, "taps": [-2, -1]}],
non_sequences=[],
n_steps=n_steps,
)
state_val = np.array([1.0, 2.0])
numba_mode = get_mode("NUMBA") # .including("scan_save_mem")
py_mode = Mode("py").including("scan_save_mem")
out_fg = FunctionGraph([init_x, n_steps], [output])
test_input_vals = (state_val, n_steps_val)
compare_numba_and_py(
out_fg, test_input_vals, numba_mode=numba_mode, py_mode=py_mode
)
def test_grad_sitsot():
def get_sum_of_grad(inp):
scan_outputs, updates = scan(
fn=lambda x: x * 2, outputs_info=[inp], n_steps=5, mode="NUMBA"
)
return grad(scan_outputs.sum(), inp).sum()
floatX = config.floatX
inputs_test_values = [
np.random.default_rng(utt.fetch_seed()).random(3).astype(floatX)
]
utt.verify_grad(get_sum_of_grad, inputs_test_values, mode="NUMBA")
def test_mitmots_basic():
init_x = at.dvector()
seq = at.dvector()
def inner_fct(seq, state_old, state_current):
return state_old * 2 + state_current + seq
out, _ = scan(
inner_fct, sequences=seq, outputs_info={"initial": init_x, "taps": [-2, -1]}
)
g_outs = grad(out.sum(), [seq, init_x])
numba_mode = get_mode("NUMBA").including("scan_save_mem")
py_mode = Mode("py").including("scan_save_mem")
out_fg = FunctionGraph([seq, init_x], g_outs)
seq_val = np.arange(3)
init_x_val = np.r_[-2, -1]
test_input_vals = (seq_val, init_x_val)
compare_numba_and_py(
out_fg, test_input_vals, numba_mode=numba_mode, py_mode=py_mode
)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论