提交 739bd49f authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Add support for shared inputs in numba_funcify_Scan

上级 9ae884dd
......@@ -554,7 +554,7 @@ def test_DirichletRV(a, size, cm):
a_val = a.tag.test_value
# For coverage purposes only...
eval_python_only([a], FunctionGraph(outputs=[g], clone=False), [a_val])
eval_python_only([a], [g], [a_val])
all_samples = []
for i in range(1000):
......
......@@ -2,15 +2,160 @@ import numpy as np
import pytest
import aesara.tensor as at
from aesara import config, grad
from aesara import config, function, grad
from aesara.compile.mode import Mode, get_mode
from aesara.graph.fg import FunctionGraph
from aesara.scan.basic import scan
from aesara.scan.op import Scan
from aesara.scan.utils import until
from aesara.tensor.random.utils import RandomStream
from tests import unittest_tools as utt
from tests.link.numba.test_basic import compare_numba_and_py
@pytest.mark.parametrize(
"fn, sequences, outputs_info, non_sequences, n_steps, input_vals, output_vals, op_check",
[
# sequences
(
lambda a_t: 2 * a_t,
[at.dvector("a")],
[{}],
[],
None,
[np.arange(10)],
None,
lambda op: op.info.n_seqs > 0,
),
# nit-sot
(
lambda: at.as_tensor(2.0),
[],
[{}],
[],
3,
[],
None,
lambda op: op.info.n_nit_sot > 0,
),
# nit-sot, non_seq
(
lambda c: at.as_tensor(2.0) * c,
[],
[{}],
[at.dscalar("c")],
3,
[1.0],
None,
lambda op: op.info.n_nit_sot > 0 and op.info.n_non_seqs > 0,
),
# sit-sot
(
lambda a_tm1: 2 * a_tm1,
[],
[{"initial": at.as_tensor(0.0, dtype="floatX"), "taps": [-1]}],
[],
3,
[],
None,
lambda op: op.info.n_sit_sot > 0,
),
# sit-sot, while
(
lambda a_tm1: (a_tm1 + 1, until(a_tm1 > 2)),
[],
[{"initial": at.as_tensor(1, dtype=np.int64), "taps": [-1]}],
[],
3,
[],
None,
lambda op: op.info.n_sit_sot > 0,
),
# nit-sot, shared input/output
(
lambda: RandomStream(seed=1930, rng_ctor=np.random.RandomState).normal(
0, 1, name="a"
),
[],
[{}],
[],
3,
[],
[np.array([-1.63408257, 0.18046406, 2.43265803])],
lambda op: op.info.n_shared_outs > 0,
),
# mit-sot (that's also a type of sit-sot)
(
lambda a_tm1: 2 * a_tm1,
[],
[{"initial": at.as_tensor([0.0, 1.0], dtype="floatX"), "taps": [-2]}],
[],
6,
[],
None,
lambda op: op.info.n_mit_sot > 0,
),
# mit-sot
(
lambda a_tm1, b_tm1: (2 * a_tm1, 2 * b_tm1),
[],
[
{"initial": at.as_tensor(0.0, dtype="floatX"), "taps": [-1]},
{"initial": at.as_tensor(0.0, dtype="floatX"), "taps": [-1]},
],
[],
10,
[],
None,
lambda op: op.info.n_mit_sot > 0,
),
],
)
def test_xit_xot_types(
fn,
sequences,
outputs_info,
non_sequences,
n_steps,
input_vals,
output_vals,
op_check,
):
"""Test basic xit-xot configurations."""
res, updates = scan(
fn,
sequences=sequences,
outputs_info=outputs_info,
non_sequences=non_sequences,
n_steps=n_steps,
strict=True,
mode=Mode(linker="py", optimizer=None),
)
if not isinstance(res, list):
res = [res]
# Get rid of any `Subtensor` indexing on the `Scan` outputs
res = [r.owner.inputs[0] if not isinstance(r.owner.op, Scan) else r for r in res]
scan_op = res[0].owner.op
assert isinstance(scan_op, Scan)
_ = op_check(scan_op)
if output_vals is None:
compare_numba_and_py(
(sequences + non_sequences, res), input_vals, updates=updates
)
else:
numba_mode = get_mode("NUMBA")
numba_fn = function(
sequences + non_sequences, res, mode=numba_mode, updates=updates
)
res_val = numba_fn(*input_vals)
assert np.allclose(res_val, output_vals)
def test_scan_multiple_output():
"""Test a scan implementation of a SEIR model.
......@@ -202,34 +347,10 @@ def test_scan_multiple_none_output():
compare_numba_and_py(out_fg, test_input_vals)
def test_scan_save_mem_basic():
@pytest.mark.parametrize("n_steps_val", [1, 5])
def test_scan_save_mem_basic(n_steps_val):
"""Make sure we can handle storage changes caused by the `scan_save_mem` rewrite."""
k = at.iscalar("k")
A = at.dvector("A")
result, _ = scan(
fn=lambda prior_result, A: prior_result * A,
outputs_info=at.ones_like(A),
non_sequences=A,
n_steps=k,
)
numba_mode = get_mode("NUMBA") # .including("scan_save_mem")
py_mode = Mode("py").including("scan_save_mem")
out_fg = FunctionGraph([A, k], [result])
test_input_vals = (np.arange(10, dtype=np.int32), 2)
compare_numba_and_py(
out_fg, test_input_vals, numba_mode=numba_mode, py_mode=py_mode
)
test_input_vals = (np.arange(10, dtype=np.int32), 4)
compare_numba_and_py(
out_fg, test_input_vals, numba_mode=numba_mode, py_mode=py_mode
)
@pytest.mark.parametrize("n_steps_val", [1, 5])
def test_scan_save_mem_2(n_steps_val):
def f_pow2(x_tm2, x_tm1):
return 2 * x_tm1 + x_tm2
......@@ -245,7 +366,7 @@ def test_scan_save_mem_2(n_steps_val):
state_val = np.array([1.0, 2.0])
numba_mode = get_mode("NUMBA") # .including("scan_save_mem")
numba_mode = get_mode("NUMBA").including("scan_save_mem")
py_mode = Mode("py").including("scan_save_mem")
out_fg = FunctionGraph([init_x, n_steps], [output])
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论