提交 15fba0e3 authored 作者: kc611's avatar kc611 提交者: Brandon T. Willard

Added Numba Scan implementation

上级 e51e8787
......@@ -8,5 +8,6 @@ import aesara.link.numba.dispatch.extra_ops
import aesara.link.numba.dispatch.nlinalg
import aesara.link.numba.dispatch.random
import aesara.link.numba.dispatch.elemwise
import aesara.link.numba.dispatch.scan
# isort: on
import numba
import numpy as np
from aesara.graph.fg import FunctionGraph
from aesara.link.numba.dispatch.basic import create_tuple_string, numba_funcify
from aesara.link.utils import compile_function_src
from aesara.scan.op import Scan
def idx_to_str(idx):
res = "[i"
if idx < 0:
res += str(idx)
elif idx > 0:
res += "+" + str(idx)
return res + "]"
@numba_funcify.register(Scan)
def numba_funcify_Scan(op, node, **kwargs):
inner_fg = FunctionGraph(op.inputs, op.outputs)
numba_aet_inner_func = numba.njit(numba_funcify(inner_fg, **kwargs))
n_seqs = op.info.n_seqs
n_mit_mot = op.info.n_mit_mot
n_mit_sot = op.info.n_mit_sot
n_nit_sot = op.info.n_nit_sot
n_sit_sot = op.info.n_sit_sot
tap_array = op.info.tap_array
n_shared_outs = op.info.n_shared_outs
mit_mot_in_taps = tuple(tap_array[:n_mit_mot])
mit_sot_in_taps = tuple(tap_array[n_mit_mot : n_mit_mot + n_mit_sot])
p_in_mit_mot = n_seqs
p_in_mit_sot = p_in_mit_mot + n_mit_mot
p_in_sit_sot = p_in_mit_sot + n_mit_sot
p_outer_in_shared = p_in_sit_sot + n_sit_sot
p_outer_in_nit_sot = p_outer_in_shared + n_shared_outs
p_outer_in_non_seqs = p_outer_in_nit_sot + n_nit_sot
input_names = [n.auto_name for n in node.inputs[1:]]
outer_in_seqs_names = input_names[:n_seqs]
outer_in_mit_mot_names = input_names[p_in_mit_mot : p_in_mit_mot + n_mit_mot]
outer_in_mit_sot_names = input_names[p_in_mit_sot : p_in_mit_sot + n_mit_sot]
outer_in_sit_sot_names = input_names[p_in_sit_sot : p_in_sit_sot + n_sit_sot]
outer_in_shared_names = input_names[
p_outer_in_shared : p_outer_in_shared + n_shared_outs
]
outer_in_nit_sot_names = input_names[
p_outer_in_nit_sot : p_outer_in_nit_sot + n_nit_sot
]
outer_in_feedback_names = input_names[n_seqs:p_outer_in_non_seqs]
outer_in_non_seqs_names = input_names[p_outer_in_non_seqs:]
inner_in_indexed = []
inner_out_indexed = []
allocate_mem_to_nit_sot = ""
for _name in outer_in_seqs_names:
# TODO:Index sould be updating according to sequence's taps
index = "[i]"
inner_in_indexed.append(_name + index)
name_to_input_map = dict(zip(input_names, node.inputs[1:]))
mit_sot_name_to_taps = dict(zip(outer_in_mit_sot_names, mit_sot_in_taps))
for _name in outer_in_feedback_names:
if _name in outer_in_mit_sot_names:
curr_taps = mit_sot_name_to_taps[_name]
min_tap = min(*curr_taps)
for _tap in curr_taps:
index = idx_to_str(_tap - min_tap)
inner_in_indexed.append(_name + index)
index = idx_to_str(-min_tap)
inner_out_indexed.append(_name + index)
if _name in outer_in_sit_sot_names:
# TODO: Input according to taps
index = "[i]"
inner_in_indexed.append(_name + index)
index = "[i+1]"
inner_out_indexed.append(_name + index)
if _name in outer_in_nit_sot_names:
# TODO: Allocate this properly
index = "[i]"
inner_out_indexed.append(_name + index)
allocate_mem_to_nit_sot += f"""
{_name} = np.zeros(n_steps)
"""
# The non_seqs are passed to inner function as-is
inner_in_indexed += outer_in_non_seqs_names
global_env = locals()
global_env["np"] = np
scan_op_src = f"""
def scan(n_steps, {", ".join(input_names)}):
outer_in_seqs = {create_tuple_string(outer_in_seqs_names)}
outer_in_mit_sot = {create_tuple_string(outer_in_mit_sot_names)}
outer_in_sit_sot = {create_tuple_string(outer_in_sit_sot_names)}
outer_in_shared = {create_tuple_string(outer_in_shared_names)}
outer_in_non_seqs = {create_tuple_string(outer_in_non_seqs_names)}
{allocate_mem_to_nit_sot}
outer_in_nit_sot = {create_tuple_string(outer_in_nit_sot_names)}
for i in range(n_steps):
inner_args = {create_tuple_string(inner_in_indexed)}
{create_tuple_string(inner_out_indexed)} = numba_aet_inner_func(*inner_args)
return (
outer_in_mit_sot +
outer_in_sit_sot +
outer_in_nit_sot
)
"""
scalar_op_fn = compile_function_src(scan_op_src, "scan", global_env)
return numba.njit(scalar_op_fn)
......@@ -28,6 +28,7 @@ from aesara.graph.type import Type
from aesara.link.numba.dispatch import basic as numba_basic
from aesara.link.numba.linker import NumbaLinker
from aesara.scalar.basic import Composite
from aesara.scan.basic import scan
from aesara.tensor import blas
from aesara.tensor import elemwise as aet_elemwise
from aesara.tensor import extra_ops, nlinalg, slinalg
......@@ -2893,3 +2894,144 @@ def test_random_Generator():
if not isinstance(i, (SharedVariable, Constant))
],
)
def test_scan_multiple_output():
"""Test a scan implementation of a SEIR model.
SEIR model definition:
S[t+1] = S[t] - B[t]
E[t+1] = E[t] +B[t] - C[t]
I[t+1] = I[t+1] + C[t] - D[t]
B[t] ~ Binom(S[t], beta)
C[t] ~ Binom(E[t], gamma)
D[t] ~ Binom(I[t], delta)
"""
def binomln(n, k):
return aet.exp(n + 1) - aet.exp(k + 1) - aet.exp(n - k + 1)
def binom_log_prob(n, p, value):
return binomln(n, value) + value * aet.exp(p) + (n - value) * aet.exp(1 - p)
# sequences
aet_C = aet.ivector("C_t")
aet_D = aet.ivector("D_t")
# outputs_info (initial conditions)
st0 = aet.lscalar("s_t0")
et0 = aet.lscalar("e_t0")
it0 = aet.lscalar("i_t0")
logp_c = aet.scalar("logp_c")
logp_d = aet.scalar("logp_d")
# non_sequences
beta = aet.scalar("beta")
gamma = aet.scalar("gamma")
delta = aet.scalar("delta")
def seir_one_step(ct0, dt0, st0, et0, it0, logp_c, logp_d, beta, gamma, delta):
bt0 = st0 * beta
bt0 = bt0.astype(st0.dtype)
logp_c1 = binom_log_prob(et0, gamma, ct0).astype(logp_c.dtype)
logp_d1 = binom_log_prob(it0, delta, dt0).astype(logp_d.dtype)
st1 = st0 - bt0
et1 = et0 + bt0 - ct0
it1 = it0 + ct0 - dt0
return st1, et1, it1, logp_c1, logp_d1
(st, et, it, logp_c_all, logp_d_all), _ = scan(
fn=seir_one_step,
sequences=[aet_C, aet_D],
outputs_info=[st0, et0, it0, logp_c, logp_d],
non_sequences=[beta, gamma, delta],
)
st.name = "S_t"
et.name = "E_t"
it.name = "I_t"
logp_c_all.name = "C_t_logp"
logp_d_all.name = "D_t_logp"
out_fg = FunctionGraph(
[aet_C, aet_D, st0, et0, it0, logp_c, logp_d, beta, gamma, delta],
[st, et, it, logp_c_all, logp_d_all],
)
s0, e0, i0 = 100, 50, 25
logp_c0 = np.array(0.0, dtype=config.floatX)
logp_d0 = np.array(0.0, dtype=config.floatX)
beta_val, gamma_val, delta_val = [
np.array(val, dtype=config.floatX) for val in [0.277792, 0.135330, 0.108753]
]
C = np.array([3, 5, 8, 13, 21, 26, 10, 3], dtype=np.int32)
D = np.array([1, 2, 3, 7, 9, 11, 5, 1], dtype=np.int32)
test_input_vals = [
C,
D,
s0,
e0,
i0,
logp_c0,
logp_d0,
beta_val,
gamma_val,
delta_val,
]
compare_numba_and_py(out_fg, test_input_vals)
@config.change_flags(compute_test_value="raise")
def test_scan_tap_output():
a_aet = aet.scalar("a")
a_aet.tag.test_value = 10.0
b_aet = aet.arange(10).astype(config.floatX)
b_aet.name = "b"
c_aet = aet.arange(20, 30, dtype=config.floatX)
c_aet.name = "c"
def input_step_fn(b, c, x_tm1, y_tm1, y_tm3, a):
x_tm1.name = "x_tm1"
y_tm1.name = "y_tm1"
y_tm3.name = "y_tm3"
y_t = (y_tm1 + y_tm3) * a + b
z_t = y_t * c
x_t = x_tm1 + 1
x_t.name = "x_t"
y_t.name = "y_t"
return x_t, y_t, z_t
scan_res, _ = scan(
fn=input_step_fn,
sequences=[b_aet, c_aet],
outputs_info=[
{
"initial": aet.as_tensor_variable(0.0, dtype=config.floatX),
"taps": [-1],
},
{
"initial": aet.as_tensor_variable(
np.r_[-1.0, 1.3, 0.0].astype(config.floatX)
),
"taps": [-1, -3],
},
None,
],
non_sequences=[a_aet],
# n_steps=10,
name="yz_scan",
strict=True,
)
out_fg = FunctionGraph([a_aet, b_aet, c_aet], scan_res)
test_input_vals = [
np.array(10.0).astype(config.floatX),
np.arange(10, dtype=config.floatX),
np.arange(20, 30, dtype=config.floatX),
]
compare_numba_and_py(out_fg, test_input_vals)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论