提交 15fba0e3 authored 作者: kc611's avatar kc611 提交者: Brandon T. Willard

Added Numba Scan implementation

上级 e51e8787
...@@ -8,5 +8,6 @@ import aesara.link.numba.dispatch.extra_ops ...@@ -8,5 +8,6 @@ import aesara.link.numba.dispatch.extra_ops
import aesara.link.numba.dispatch.nlinalg import aesara.link.numba.dispatch.nlinalg
import aesara.link.numba.dispatch.random import aesara.link.numba.dispatch.random
import aesara.link.numba.dispatch.elemwise import aesara.link.numba.dispatch.elemwise
import aesara.link.numba.dispatch.scan
# isort: on # isort: on
import numba
import numpy as np
from aesara.graph.fg import FunctionGraph
from aesara.link.numba.dispatch.basic import create_tuple_string, numba_funcify
from aesara.link.utils import compile_function_src
from aesara.scan.op import Scan
def idx_to_str(idx):
    """Format ``idx`` as a subscript expression relative to the loop index ``i``.

    Parameters
    ----------
    idx
        Offset from ``i``.

    Returns
    -------
    str
        ``"[i]"`` when the offset is neither positive nor negative,
        ``"[i+<idx>]"`` for a positive offset and ``"[i<idx>]"`` (the minus
        sign comes from ``str(idx)``) for a negative one.
    """
    if idx > 0:
        # Positive offsets need an explicit "+"; negatives carry their own sign.
        offset = "+" + str(idx)
    elif idx < 0:
        offset = str(idx)
    else:
        offset = ""
    return "[i" + offset + "]"
@numba_funcify.register(Scan)
def numba_funcify_Scan(op, node, **kwargs):
    """Generate and jit a Python loop that implements a ``Scan`` node.

    The inner graph (``op.inputs`` -> ``op.outputs``) is converted with
    ``numba_funcify`` and jitted.  Source code for an outer ``scan`` function
    is then assembled as a string, compiled with ``compile_function_src`` and
    jitted as well.

    Parameters
    ----------
    op
        The ``Scan`` op.  Its ``inputs``, ``outputs`` and ``info`` (``n_seqs``,
        ``n_mit_mot``, ``n_mit_sot``, ``n_sit_sot``, ``n_nit_sot``,
        ``n_shared_outs``, ``tap_array``) are read.
    node
        The node applying `op`.  ``node.inputs[1:]`` supply, through their
        ``auto_name``, the argument names of the generated function; the first
        generated argument is ``n_steps`` (presumably ``node.inputs[0]``).
    **kwargs
        Forwarded to ``numba_funcify`` when converting the inner graph.

    Returns
    -------
    The ``numba.njit``-wrapped generated function.  It returns the
    concatenation of the mit-sot, sit-sot and nit-sot tuples.

    Notes
    -----
    Limitations of the generated code as written here:

    * sequences are always indexed with ``[i]`` (sequence taps are ignored);
    * mit-mot inputs match none of the feedback branches, so they are neither
      passed to the inner function nor returned;
    * shared inputs are gathered into a tuple that is not used afterwards;
    * every nit-sot buffer is allocated as ``np.zeros(n_steps)``.
    """
    inner_fg = FunctionGraph(op.inputs, op.outputs)
    numba_aet_inner_func = numba.njit(numba_funcify(inner_fg, **kwargs))

    n_seqs = op.info.n_seqs
    n_mit_mot = op.info.n_mit_mot
    n_mit_sot = op.info.n_mit_sot
    n_nit_sot = op.info.n_nit_sot
    n_sit_sot = op.info.n_sit_sot
    tap_array = op.info.tap_array
    n_shared_outs = op.info.n_shared_outs

    # Taps of the mit-sot entries; in `tap_array` they follow the mit-mot ones.
    mit_sot_in_taps = tuple(tap_array[n_mit_mot : n_mit_mot + n_mit_sot])

    # Start offsets of each input group inside `node.inputs[1:]`.
    p_in_mit_mot = n_seqs
    p_in_mit_sot = p_in_mit_mot + n_mit_mot
    p_in_sit_sot = p_in_mit_sot + n_mit_sot
    p_outer_in_shared = p_in_sit_sot + n_sit_sot
    p_outer_in_nit_sot = p_outer_in_shared + n_shared_outs
    p_outer_in_non_seqs = p_outer_in_nit_sot + n_nit_sot

    input_names = [n.auto_name for n in node.inputs[1:]]
    outer_in_seqs_names = input_names[:n_seqs]
    outer_in_mit_sot_names = input_names[p_in_mit_sot : p_in_mit_sot + n_mit_sot]
    outer_in_sit_sot_names = input_names[p_in_sit_sot : p_in_sit_sot + n_sit_sot]
    outer_in_shared_names = input_names[
        p_outer_in_shared : p_outer_in_shared + n_shared_outs
    ]
    outer_in_nit_sot_names = input_names[
        p_outer_in_nit_sot : p_outer_in_nit_sot + n_nit_sot
    ]
    outer_in_feedback_names = input_names[n_seqs:p_outer_in_non_seqs]
    outer_in_non_seqs_names = input_names[p_outer_in_non_seqs:]

    # Source-level expressions read from / written to on every iteration.
    inner_in_indexed = []
    inner_out_indexed = []
    allocate_mem_to_nit_sot = ""

    # TODO: Index should be updated according to the sequence's taps
    inner_in_indexed.extend(_name + "[i]" for _name in outer_in_seqs_names)

    mit_sot_name_to_taps = dict(zip(outer_in_mit_sot_names, mit_sot_in_taps))
    for _name in outer_in_feedback_names:
        if _name in outer_in_mit_sot_names:
            curr_taps = mit_sot_name_to_taps[_name]
            # `min(curr_taps)` rather than `min(*curr_taps)`: the unpacked
            # form raises `TypeError` when there is exactly one tap.
            min_tap = min(curr_taps)

            # Shift taps so the oldest one lands on position `i`.
            for _tap in curr_taps:
                inner_in_indexed.append(_name + idx_to_str(_tap - min_tap))

            inner_out_indexed.append(_name + idx_to_str(-min_tap))

        if _name in outer_in_sit_sot_names:
            # TODO: Input according to taps
            inner_in_indexed.append(_name + "[i]")
            inner_out_indexed.append(_name + "[i+1]")

        if _name in outer_in_nit_sot_names:
            # TODO: Allocate this properly
            inner_out_indexed.append(_name + "[i]")
            allocate_mem_to_nit_sot += f"""
    {_name} = np.zeros(n_steps)
"""

    # The non_seqs are passed to inner function as-is
    inner_in_indexed += outer_in_non_seqs_names

    # Only the names the generated source refers to as globals.
    global_env = {
        "np": np,
        "numba_aet_inner_func": numba_aet_inner_func,
    }

    scan_op_src = f"""
def scan(n_steps, {", ".join(input_names)}):

    outer_in_seqs = {create_tuple_string(outer_in_seqs_names)}
    outer_in_mit_sot = {create_tuple_string(outer_in_mit_sot_names)}
    outer_in_sit_sot = {create_tuple_string(outer_in_sit_sot_names)}
    outer_in_shared = {create_tuple_string(outer_in_shared_names)}
    outer_in_non_seqs = {create_tuple_string(outer_in_non_seqs_names)}
    {allocate_mem_to_nit_sot}
    outer_in_nit_sot = {create_tuple_string(outer_in_nit_sot_names)}

    for i in range(n_steps):
        inner_args = {create_tuple_string(inner_in_indexed)}
        {create_tuple_string(inner_out_indexed)} = numba_aet_inner_func(*inner_args)

    return (
        outer_in_mit_sot +
        outer_in_sit_sot +
        outer_in_nit_sot
    )
    """
    scalar_op_fn = compile_function_src(scan_op_src, "scan", global_env)

    return numba.njit(scalar_op_fn)
...@@ -28,6 +28,7 @@ from aesara.graph.type import Type ...@@ -28,6 +28,7 @@ from aesara.graph.type import Type
from aesara.link.numba.dispatch import basic as numba_basic from aesara.link.numba.dispatch import basic as numba_basic
from aesara.link.numba.linker import NumbaLinker from aesara.link.numba.linker import NumbaLinker
from aesara.scalar.basic import Composite from aesara.scalar.basic import Composite
from aesara.scan.basic import scan
from aesara.tensor import blas from aesara.tensor import blas
from aesara.tensor import elemwise as aet_elemwise from aesara.tensor import elemwise as aet_elemwise
from aesara.tensor import extra_ops, nlinalg, slinalg from aesara.tensor import extra_ops, nlinalg, slinalg
...@@ -2893,3 +2894,144 @@ def test_random_Generator(): ...@@ -2893,3 +2894,144 @@ def test_random_Generator():
if not isinstance(i, (SharedVariable, Constant)) if not isinstance(i, (SharedVariable, Constant))
], ],
) )
def test_scan_multiple_output():
    """Test a scan implementation of a SEIR model.

    SEIR model definition:
    S[t+1] = S[t] - B[t]
    E[t+1] = E[t] + B[t] - C[t]
    I[t+1] = I[t] + C[t] - D[t]

    B[t] ~ Binom(S[t], beta)
    C[t] ~ Binom(E[t], gamma)
    D[t] ~ Binom(I[t], delta)

    The step function below does not sample ``B[t]``; it uses
    ``S[t] * beta`` cast to the dtype of ``S[t]``.
    """

    # NOTE(review): built from `exp` rather than a log-gamma function;
    # presumably a stand-in so the graph only uses supported ops -- confirm.
    def binomln(n, k):
        return aet.exp(n + 1) - aet.exp(k + 1) - aet.exp(n - k + 1)

    def binom_log_prob(n, p, value):
        return binomln(n, value) + value * aet.exp(p) + (n - value) * aet.exp(1 - p)

    # sequences
    aet_C = aet.ivector("C_t")
    aet_D = aet.ivector("D_t")
    # outputs_info (initial conditions)
    st0 = aet.lscalar("s_t0")
    et0 = aet.lscalar("e_t0")
    it0 = aet.lscalar("i_t0")
    logp_c = aet.scalar("logp_c")
    logp_d = aet.scalar("logp_d")
    # non_sequences
    beta = aet.scalar("beta")
    gamma = aet.scalar("gamma")
    delta = aet.scalar("delta")

    # One time step: two sequence values, five recurrent states, three
    # non-sequences in; the five updated states out.
    def seir_one_step(ct0, dt0, st0, et0, it0, logp_c, logp_d, beta, gamma, delta):
        bt0 = st0 * beta
        # Cast back so the recurrent state keeps the dtype of `st0`.
        bt0 = bt0.astype(st0.dtype)

        logp_c1 = binom_log_prob(et0, gamma, ct0).astype(logp_c.dtype)
        logp_d1 = binom_log_prob(it0, delta, dt0).astype(logp_d.dtype)

        st1 = st0 - bt0
        et1 = et0 + bt0 - ct0
        it1 = it0 + ct0 - dt0
        return st1, et1, it1, logp_c1, logp_d1

    (st, et, it, logp_c_all, logp_d_all), _ = scan(
        fn=seir_one_step,
        sequences=[aet_C, aet_D],
        outputs_info=[st0, et0, it0, logp_c, logp_d],
        non_sequences=[beta, gamma, delta],
    )
    st.name = "S_t"
    et.name = "E_t"
    it.name = "I_t"
    logp_c_all.name = "C_t_logp"
    logp_d_all.name = "D_t_logp"

    out_fg = FunctionGraph(
        [aet_C, aet_D, st0, et0, it0, logp_c, logp_d, beta, gamma, delta],
        [st, et, it, logp_c_all, logp_d_all],
    )

    # Concrete values, listed in the same order as the graph inputs above.
    s0, e0, i0 = 100, 50, 25
    logp_c0 = np.array(0.0, dtype=config.floatX)
    logp_d0 = np.array(0.0, dtype=config.floatX)
    beta_val, gamma_val, delta_val = [
        np.array(val, dtype=config.floatX) for val in [0.277792, 0.135330, 0.108753]
    ]
    C = np.array([3, 5, 8, 13, 21, 26, 10, 3], dtype=np.int32)
    D = np.array([1, 2, 3, 7, 9, 11, 5, 1], dtype=np.int32)

    test_input_vals = [
        C,
        D,
        s0,
        e0,
        i0,
        logp_c0,
        logp_d0,
        beta_val,
        gamma_val,
        delta_val,
    ]
    compare_numba_and_py(out_fg, test_input_vals)
@config.change_flags(compute_test_value="raise")
def test_scan_tap_output():
    """Test a scan mixing a single-tap output, a two-tap output and an
    output without initial state.

    The scan is built from two sequences, three ``outputs_info`` entries
    (taps ``[-1]``, taps ``[-1, -3]`` and ``None``) and one non-sequence, and
    the resulting graph is passed to ``compare_numba_and_py``.
    """
    a_aet = aet.scalar("a")
    # `compute_test_value="raise"` is active, so give the input a test value.
    a_aet.tag.test_value = 10.0

    b_aet = aet.arange(10).astype(config.floatX)
    b_aet.name = "b"

    c_aet = aet.arange(20, 30, dtype=config.floatX)
    c_aet.name = "c"

    # Step function; parameter names suggest: `b`, `c` are the sequence
    # values, `x_tm1` the tap of the first output, `y_tm1`/`y_tm3` the taps of
    # the second output and `a` the non-sequence.
    def input_step_fn(b, c, x_tm1, y_tm1, y_tm3, a):
        x_tm1.name = "x_tm1"
        y_tm1.name = "y_tm1"
        y_tm3.name = "y_tm3"
        y_t = (y_tm1 + y_tm3) * a + b
        z_t = y_t * c
        x_t = x_tm1 + 1
        x_t.name = "x_t"
        y_t.name = "y_t"
        return x_t, y_t, z_t

    scan_res, _ = scan(
        fn=input_step_fn,
        sequences=[b_aet, c_aet],
        outputs_info=[
            {
                "initial": aet.as_tensor_variable(0.0, dtype=config.floatX),
                "taps": [-1],
            },
            {
                # Three initial values, matching the deepest tap (-3).
                "initial": aet.as_tensor_variable(
                    np.r_[-1.0, 1.3, 0.0].astype(config.floatX)
                ),
                "taps": [-1, -3],
            },
            # No initial state for the third output (`z_t`).
            None,
        ],
        non_sequences=[a_aet],
        # n_steps=10,
        name="yz_scan",
        strict=True,
    )

    out_fg = FunctionGraph([a_aet, b_aet, c_aet], scan_res)

    # Values for `a`, `b` and `c`, in the order of the graph inputs.
    test_input_vals = [
        np.array(10.0).astype(config.floatX),
        np.arange(10, dtype=config.floatX),
        np.arange(20, 30, dtype=config.floatX),
    ]
    compare_numba_and_py(out_fg, test_input_vals)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论