提交 739bd49f authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Add support for shared inputs in numba_funcify_Scan

上级 9ae884dd
...@@ -554,7 +554,7 @@ def test_DirichletRV(a, size, cm): ...@@ -554,7 +554,7 @@ def test_DirichletRV(a, size, cm):
a_val = a.tag.test_value a_val = a.tag.test_value
# For coverage purposes only... # For coverage purposes only...
eval_python_only([a], FunctionGraph(outputs=[g], clone=False), [a_val]) eval_python_only([a], [g], [a_val])
all_samples = [] all_samples = []
for i in range(1000): for i in range(1000):
......
...@@ -2,15 +2,160 @@ import numpy as np ...@@ -2,15 +2,160 @@ import numpy as np
import pytest import pytest
import aesara.tensor as at import aesara.tensor as at
from aesara import config, grad from aesara import config, function, grad
from aesara.compile.mode import Mode, get_mode from aesara.compile.mode import Mode, get_mode
from aesara.graph.fg import FunctionGraph from aesara.graph.fg import FunctionGraph
from aesara.scan.basic import scan from aesara.scan.basic import scan
from aesara.scan.op import Scan
from aesara.scan.utils import until from aesara.scan.utils import until
from aesara.tensor.random.utils import RandomStream
from tests import unittest_tools as utt from tests import unittest_tools as utt
from tests.link.numba.test_basic import compare_numba_and_py from tests.link.numba.test_basic import compare_numba_and_py
@pytest.mark.parametrize(
    "fn, sequences, outputs_info, non_sequences, n_steps, input_vals, output_vals, op_check",
    [
        # sequences
        (
            lambda a_t: 2 * a_t,
            [at.dvector("a")],
            [{}],
            [],
            None,
            [np.arange(10)],
            None,
            lambda op: op.info.n_seqs > 0,
        ),
        # nit-sot
        (
            lambda: at.as_tensor(2.0),
            [],
            [{}],
            [],
            3,
            [],
            None,
            lambda op: op.info.n_nit_sot > 0,
        ),
        # nit-sot, non_seq
        (
            lambda c: at.as_tensor(2.0) * c,
            [],
            [{}],
            [at.dscalar("c")],
            3,
            [1.0],
            None,
            lambda op: op.info.n_nit_sot > 0 and op.info.n_non_seqs > 0,
        ),
        # sit-sot
        (
            lambda a_tm1: 2 * a_tm1,
            [],
            [{"initial": at.as_tensor(0.0, dtype="floatX"), "taps": [-1]}],
            [],
            3,
            [],
            None,
            lambda op: op.info.n_sit_sot > 0,
        ),
        # sit-sot, while
        (
            lambda a_tm1: (a_tm1 + 1, until(a_tm1 > 2)),
            [],
            [{"initial": at.as_tensor(1, dtype=np.int64), "taps": [-1]}],
            [],
            3,
            [],
            None,
            lambda op: op.info.n_sit_sot > 0,
        ),
        # nit-sot, shared input/output
        (
            lambda: RandomStream(seed=1930, rng_ctor=np.random.RandomState).normal(
                0, 1, name="a"
            ),
            [],
            [{}],
            [],
            3,
            [],
            [np.array([-1.63408257, 0.18046406, 2.43265803])],
            lambda op: op.info.n_shared_outs > 0,
        ),
        # mit-sot (that's also a type of sit-sot)
        (
            lambda a_tm1: 2 * a_tm1,
            [],
            [{"initial": at.as_tensor([0.0, 1.0], dtype="floatX"), "taps": [-2]}],
            [],
            6,
            [],
            None,
            lambda op: op.info.n_mit_sot > 0,
        ),
        # mit-sot
        (
            lambda a_tm1, b_tm1: (2 * a_tm1, 2 * b_tm1),
            [],
            [
                {"initial": at.as_tensor(0.0, dtype="floatX"), "taps": [-1]},
                {"initial": at.as_tensor(0.0, dtype="floatX"), "taps": [-1]},
            ],
            [],
            10,
            [],
            None,
            lambda op: op.info.n_mit_sot > 0,
        ),
    ],
)
def test_xit_xot_types(
    fn,
    sequences,
    outputs_info,
    non_sequences,
    n_steps,
    input_vals,
    output_vals,
    op_check,
):
    """Test basic xit-xot configurations.

    Parameters
    ----------
    fn
        Inner-graph function passed to `scan`.
    sequences, outputs_info, non_sequences, n_steps
        Arguments forwarded verbatim to `scan`.
    input_vals
        Concrete values for ``sequences + non_sequences`` inputs.
    output_vals
        Expected outputs when the result cannot be compared against the
        Python-mode reference (e.g. shared random state); ``None`` means
        the numba and Python implementations are compared directly.
    op_check
        Predicate on the compiled `Scan` `Op` asserting that the expected
        xit-xot configuration (n_seqs, n_nit_sot, ...) was actually built.
    """
    res, updates = scan(
        fn,
        sequences=sequences,
        outputs_info=outputs_info,
        non_sequences=non_sequences,
        n_steps=n_steps,
        strict=True,
        mode=Mode(linker="py", optimizer=None),
    )

    # `scan` returns a bare variable for a single output; normalize to a list.
    if not isinstance(res, list):
        res = [res]

    # Get rid of any `Subtensor` indexing on the `Scan` outputs
    res = [r.owner.inputs[0] if not isinstance(r.owner.op, Scan) else r for r in res]

    scan_op = res[0].owner.op
    assert isinstance(scan_op, Scan)

    # Actually assert the per-case configuration predicate; previously the
    # result was discarded (`_ = op_check(scan_op)`), so these checks could
    # never fail.
    assert op_check(scan_op)

    if output_vals is None:
        # Compare the numba-compiled function against the Python reference.
        compare_numba_and_py(
            (sequences + non_sequences, res), input_vals, updates=updates
        )
    else:
        # No Python reference available (e.g. RNG-backed shared outputs):
        # compare against the explicitly provided expected values.
        numba_mode = get_mode("NUMBA")
        numba_fn = function(
            sequences + non_sequences, res, mode=numba_mode, updates=updates
        )
        res_val = numba_fn(*input_vals)
        assert np.allclose(res_val, output_vals)
def test_scan_multiple_output(): def test_scan_multiple_output():
"""Test a scan implementation of a SEIR model. """Test a scan implementation of a SEIR model.
...@@ -202,34 +347,10 @@ def test_scan_multiple_none_output(): ...@@ -202,34 +347,10 @@ def test_scan_multiple_none_output():
compare_numba_and_py(out_fg, test_input_vals) compare_numba_and_py(out_fg, test_input_vals)
def test_scan_save_mem_basic(): @pytest.mark.parametrize("n_steps_val", [1, 5])
def test_scan_save_mem_basic(n_steps_val):
"""Make sure we can handle storage changes caused by the `scan_save_mem` rewrite.""" """Make sure we can handle storage changes caused by the `scan_save_mem` rewrite."""
k = at.iscalar("k")
A = at.dvector("A")
result, _ = scan(
fn=lambda prior_result, A: prior_result * A,
outputs_info=at.ones_like(A),
non_sequences=A,
n_steps=k,
)
numba_mode = get_mode("NUMBA") # .including("scan_save_mem")
py_mode = Mode("py").including("scan_save_mem")
out_fg = FunctionGraph([A, k], [result])
test_input_vals = (np.arange(10, dtype=np.int32), 2)
compare_numba_and_py(
out_fg, test_input_vals, numba_mode=numba_mode, py_mode=py_mode
)
test_input_vals = (np.arange(10, dtype=np.int32), 4)
compare_numba_and_py(
out_fg, test_input_vals, numba_mode=numba_mode, py_mode=py_mode
)
@pytest.mark.parametrize("n_steps_val", [1, 5])
def test_scan_save_mem_2(n_steps_val):
def f_pow2(x_tm2, x_tm1): def f_pow2(x_tm2, x_tm1):
return 2 * x_tm1 + x_tm2 return 2 * x_tm1 + x_tm2
...@@ -245,7 +366,7 @@ def test_scan_save_mem_2(n_steps_val): ...@@ -245,7 +366,7 @@ def test_scan_save_mem_2(n_steps_val):
state_val = np.array([1.0, 2.0]) state_val = np.array([1.0, 2.0])
numba_mode = get_mode("NUMBA") # .including("scan_save_mem") numba_mode = get_mode("NUMBA").including("scan_save_mem")
py_mode = Mode("py").including("scan_save_mem") py_mode = Mode("py").including("scan_save_mem")
out_fg = FunctionGraph([init_x, n_steps], [output]) out_fg = FunctionGraph([init_x, n_steps], [output])
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论