提交 739bd49f authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Add support for shared inputs in numba_funcify_Scan

上级 9ae884dd
from itertools import groupby
from textwrap import dedent, indent from textwrap import dedent, indent
from typing import Dict, List, Optional, Tuple from typing import Dict, List, Optional, Tuple
...@@ -14,6 +13,7 @@ from aesara.link.numba.dispatch.basic import ( ...@@ -14,6 +13,7 @@ from aesara.link.numba.dispatch.basic import (
) )
from aesara.link.utils import compile_function_src from aesara.link.utils import compile_function_src
from aesara.scan.op import Scan from aesara.scan.op import Scan
from aesara.tensor.type import TensorType
def idx_to_str( def idx_to_str(
...@@ -49,8 +49,6 @@ def array0d_range(x): ...@@ -49,8 +49,6 @@ def array0d_range(x):
def numba_funcify_Scan(op, node, **kwargs): def numba_funcify_Scan(op, node, **kwargs):
scan_inner_func = numba_basic.numba_njit(numba_funcify(op.fgraph)) scan_inner_func = numba_basic.numba_njit(numba_funcify(op.fgraph))
n_seqs = op.info.n_seqs
outer_in_names_to_vars = { outer_in_names_to_vars = {
(f"outer_in_{i}" if i > 0 else "n_steps"): v for i, v in enumerate(node.inputs) (f"outer_in_{i}" if i > 0 else "n_steps"): v for i, v in enumerate(node.inputs)
} }
...@@ -60,22 +58,63 @@ def numba_funcify_Scan(op, node, **kwargs): ...@@ -60,22 +58,63 @@ def numba_funcify_Scan(op, node, **kwargs):
outer_in_mit_sot_names = op.outer_mitsot(outer_in_names) outer_in_mit_sot_names = op.outer_mitsot(outer_in_names)
outer_in_sit_sot_names = op.outer_sitsot(outer_in_names) outer_in_sit_sot_names = op.outer_sitsot(outer_in_names)
outer_in_nit_sot_names = op.outer_nitsot(outer_in_names) outer_in_nit_sot_names = op.outer_nitsot(outer_in_names)
outer_in_shared_names = op.outer_shared(outer_in_names)
outer_in_non_seqs_names = op.outer_non_seqs(outer_in_names)
# These are all the outer-input names that produce outputs/have output
# taps (i.e. they have inner-outputs and corresponding outer-outputs).
# Outer-outputs are ordered as follows:
# mit-mot-outputs + mit-sot-outputs + sit-sot-outputs + nit-sots + shared-outputs
outer_in_outtap_names = ( outer_in_outtap_names = (
outer_in_mit_mot_names outer_in_mit_mot_names
+ outer_in_mit_sot_names + outer_in_mit_sot_names
+ outer_in_sit_sot_names + outer_in_sit_sot_names
+ outer_in_nit_sot_names + outer_in_nit_sot_names
+ outer_in_shared_names
) )
outer_in_non_seqs_names = op.outer_non_seqs(outer_in_names)
inner_in_to_index_offset: List[Tuple[str, Optional[int], Optional[int]]] = [] # We create distinct variables for/references to the storage arrays for
allocate_taps_storage: List[str] = [] # each output.
outer_in_to_storage_name: Dict[str, str] = {}
for outer_in_name in outer_in_mit_mot_names:
outer_in_to_storage_name[outer_in_name] = f"{outer_in_name}_mitmot_storage"
for outer_in_name in outer_in_mit_sot_names:
outer_in_to_storage_name[outer_in_name] = f"{outer_in_name}_mitsot_storage"
for outer_in_name in outer_in_sit_sot_names:
outer_in_to_storage_name[outer_in_name] = f"{outer_in_name}_sitsot_storage"
for outer_in_name in outer_in_nit_sot_names:
outer_in_to_storage_name[outer_in_name] = f"{outer_in_name}_nitsot_storage"
for outer_in_name in outer_in_shared_names:
outer_in_to_storage_name[outer_in_name] = f"{outer_in_name}_shared_storage"
outer_output_names = list(outer_in_to_storage_name.values())
assert len(outer_output_names) == len(node.outputs)
# Construct the inner-input expressions (e.g. indexed storage expressions)
# Inner-inputs are ordered as follows:
# sequences + mit-mot-inputs + mit-sot-inputs + sit-sot-inputs +
# shared-inputs + non-sequences.
inner_in_exprs: List[str] = []
def add_inner_in_expr(
outer_in_name: str, tap_offset: Optional[int], storage_size_var: Optional[str]
):
"""Construct an inner-input expression."""
storage_name = outer_in_to_storage_name.get(outer_in_name, outer_in_name)
indexed_inner_in_str = (
storage_name
if tap_offset is None
else idx_to_str(storage_name, tap_offset, size=storage_size_var)
)
inner_in_exprs.append(indexed_inner_in_str)
for outer_in_name in outer_in_seqs_names: for outer_in_name in outer_in_seqs_names:
# A sequence with multiple taps is provided as multiple modified input # These outer-inputs are indexed without offsets or storage wrap-around
# sequences--all sliced so as to keep following the logic of a normal add_inner_in_expr(outer_in_name, 0, None)
# sequence.
inner_in_to_index_offset.append((outer_in_name, 0, None))
inner_in_names_to_input_taps: Dict[str, Tuple[int]] = dict( inner_in_names_to_input_taps: Dict[str, Tuple[int]] = dict(
zip( zip(
...@@ -89,201 +128,202 @@ def numba_funcify_Scan(op, node, **kwargs): ...@@ -89,201 +128,202 @@ def numba_funcify_Scan(op, node, **kwargs):
zip(outer_in_mit_mot_names, op.info.mit_mot_out_slices) zip(outer_in_mit_mot_names, op.info.mit_mot_out_slices)
) )
# Inner-outputs consist of:
# mit-mot-outputs + mit-sot-outputs + sit-sot-outputs + nit-sots +
# shared-outputs [+ while-condition]
inner_output_names = [f"inner_out_{i}" for i in range(len(op.inner_outputs))] inner_output_names = [f"inner_out_{i}" for i in range(len(op.inner_outputs))]
# Maps storage array names to their tap values (i.e. maximum absolute tap # inner_out_shared_names = op.inner_shared_outs(inner_output_names)
# value) and storage sizes
inner_out_name_to_taps_storage: List[Tuple[str, int, Optional[str]]] = [] # The assignment statements that copy inner-outputs into the outer-outputs
outer_in_to_storage_name: Dict[str, str] = {} # storage
outer_in_sot_names = set( inner_out_to_outer_in_stmts: List[str] = []
outer_in_mit_mot_names + outer_in_mit_sot_names + outer_in_sit_sot_names
) # Special statements that perform storage truncation for `while`-loops and
# rotation for initially truncated storage.
output_storage_post_proc_stmts: List[str] = []
# In truncated storage situations (e.g. created by `save_mem_new_scan`),
# the taps and output storage overlap, instead of the standard situation in
# which the output storage is large enough to contain both the initial taps
# values and the output storage. In this truncated case, we use the
# storage array like a circular buffer, and that's why we need to track the
# storage size along with the taps length/indexing offset.
def add_output_storage_post_proc_stmt(
outer_in_name: str, tap_sizes: Tuple[int], storage_size: str
):
tap_size = max(tap_sizes)
if op.info.as_while:
# While loops need to truncate the output storage to a length given
# by the number of iterations performed.
output_storage_post_proc_stmts.append(
dedent(
f"""
if i + {tap_size} < {storage_size}:
{storage_size} = i + {tap_size}
{outer_in_name} = {outer_in_name}[:{storage_size}]
"""
).strip()
)
# Rotate the storage so that the last computed value is at the end of
# the storage array.
# This is needed when the output storage array does not have a length
# equal to the number of taps plus `n_steps`.
output_storage_post_proc_stmts.append(
dedent(
f"""
{outer_in_name}_shift = (i + {tap_size}) % ({storage_size})
if {outer_in_name}_shift > 0:
{outer_in_name}_left = {outer_in_name}[:{outer_in_name}_shift]
{outer_in_name}_right = {outer_in_name}[{outer_in_name}_shift:]
{outer_in_name} = np.concatenate(({outer_in_name}_right, {outer_in_name}_left))
"""
).strip()
)
# Special in-loop statements that create (nit-sot) storage arrays after a
# single iteration is performed. This is necessary because we don't know
# the exact shapes of the storage arrays that need to be allocated until
# after an iteration is performed.
inner_out_post_processing_stmts: List[str] = [] inner_out_post_processing_stmts: List[str] = []
# Storage allocation statements
# For output storage allocated/provided by the inputs, these statements
# will either construct aliases between the input names and the entries in
# `outer_in_to_storage_name` or assign the latter to expressions that
# create copies of those storage inputs.
# In the nit-sot case, empty dummy arrays are assigned to the storage
# variables and updated later by the statements in
# `inner_out_post_processing_stmts`.
storage_alloc_stmts: List[str] = []
for outer_in_name in outer_in_outtap_names: for outer_in_name in outer_in_outtap_names:
outer_in_var = outer_in_names_to_vars[outer_in_name] outer_in_var = outer_in_names_to_vars[outer_in_name]
if outer_in_name in outer_in_sot_names: if outer_in_name not in outer_in_nit_sot_names:
if outer_in_name in outer_in_mit_mot_names:
storage_name = f"{outer_in_name}_mitmot_storage"
elif outer_in_name in outer_in_mit_sot_names:
storage_name = f"{outer_in_name}_mitsot_storage"
else:
# Note that the outputs with single, non-`-1` taps (e.g. `taps
# = [-2]`) are classified as mit-sot, so the code for handling
# sit-sots remains constant as follows
storage_name = f"{outer_in_name}_sitsot_storage"
output_idx = len(outer_in_to_storage_name) storage_name = outer_in_to_storage_name[outer_in_name]
outer_in_to_storage_name[outer_in_name] = storage_name
input_taps = inner_in_names_to_input_taps[outer_in_name] is_tensor_type = isinstance(outer_in_var.type, TensorType)
tap_storage_size = -min(input_taps) if is_tensor_type:
assert tap_storage_size >= 0 storage_size_name = f"{outer_in_name}_len"
storage_size_stmt = f"{storage_size_name} = {outer_in_name}.shape[0]"
input_taps = inner_in_names_to_input_taps[outer_in_name]
tap_storage_size = -min(input_taps)
assert tap_storage_size >= 0
storage_size_name = f"{outer_in_name}_len" for in_tap in input_taps:
tap_offset = in_tap + tap_storage_size
assert tap_offset >= 0
add_inner_in_expr(outer_in_name, tap_offset, storage_size_name)
for in_tap in input_taps: output_taps = inner_in_names_to_output_taps.get(
tap_offset = in_tap + tap_storage_size outer_in_name, [tap_storage_size]
assert tap_offset >= 0
# In truncated storage situations (i.e. created by
# `save_mem_new_scan`), the taps and output storage overlap,
# instead of the standard situation in which the output storage
# is large enough to contain both the initial taps values and
# the output storage.
inner_in_to_index_offset.append(
(outer_in_name, tap_offset, storage_size_name)
) )
for out_tap in output_taps:
inner_out_to_outer_in_stmts.append(
idx_to_str(storage_name, out_tap, size=storage_size_name)
)
output_taps = inner_in_names_to_output_taps.get( add_output_storage_post_proc_stmt(
outer_in_name, [tap_storage_size] storage_name, output_taps, storage_size_name
)
for out_tap in output_taps:
inner_out_name_to_taps_storage.append(
(storage_name, out_tap, storage_size_name)
) )
if output_idx in node.op.destroy_map: else:
storage_size_stmt = ""
add_inner_in_expr(outer_in_name, None, None)
inner_out_to_outer_in_stmts.append(storage_name)
output_idx = outer_output_names.index(storage_name)
if output_idx in node.op.destroy_map or not is_tensor_type:
storage_alloc_stmt = f"{storage_name} = {outer_in_name}" storage_alloc_stmt = f"{storage_name} = {outer_in_name}"
else: else:
storage_alloc_stmt = f"{storage_name} = np.copy({outer_in_name})" storage_alloc_stmt = f"{storage_name} = np.copy({outer_in_name})"
storage_alloc_stmt = dedent( storage_alloc_stmt = dedent(
f""" f"""
# {outer_in_var.type} {storage_size_stmt}
{storage_size_name} = {outer_in_name}.shape[0]
{storage_alloc_stmt} {storage_alloc_stmt}
""" """
).strip() ).strip()
allocate_taps_storage.append(storage_alloc_stmt) storage_alloc_stmts.append(storage_alloc_stmt)
else:
assert outer_in_name in outer_in_nit_sot_names
elif outer_in_name in outer_in_nit_sot_names:
# This is a special case in which there are no outer-inputs used # This is a special case in which there are no outer-inputs used
# for outer-output storage, so we need to create our own storage # for outer-output storage, so we need to create our own storage
# from scratch. # from scratch.
storage_name = outer_in_to_storage_name[outer_in_name]
storage_name = f"{outer_in_name}_nitsot_storage"
outer_in_to_storage_name[outer_in_name] = storage_name
storage_size_name = f"{outer_in_name}_len" storage_size_name = f"{outer_in_name}_len"
inner_out_name_to_taps_storage.append((storage_name, 0, storage_size_name))
inner_out_to_outer_in_stmts.append(
idx_to_str(storage_name, 0, size=storage_size_name)
)
add_output_storage_post_proc_stmt(storage_name, (0,), storage_size_name)
# In case of nit-sots we are provided the length of the array in # In case of nit-sots we are provided the length of the array in
# the iteration dimension instead of actual arrays, hence we # the iteration dimension instead of actual arrays, hence we
# allocate space for the results accordingly. # allocate space for the results accordingly.
curr_nit_sot_position = outer_in_names[1:].index(outer_in_name) - n_seqs curr_nit_sot_position = outer_in_nit_sot_names.index(outer_in_name)
curr_nit_sot = op.inner_outputs[curr_nit_sot_position] curr_nit_sot = op.inner_nitsot_outs(op.inner_outputs)[curr_nit_sot_position]
needs_alloc = curr_nit_sot.ndim > 0
storage_shape = create_tuple_string( storage_shape = create_tuple_string(
[storage_size_name] + ["0"] * curr_nit_sot.ndim [storage_size_name] + ["0"] * curr_nit_sot.ndim
) )
storage_dtype = curr_nit_sot.type.numpy_dtype.name storage_dtype = curr_nit_sot.type.numpy_dtype.name
allocate_taps_storage.append( storage_alloc_stmts.append(
dedent( dedent(
f""" f"""
# {curr_nit_sot.type}
{storage_size_name} = to_numba_scalar({outer_in_name}) {storage_size_name} = to_numba_scalar({outer_in_name})
{storage_name} = np.empty({storage_shape}, dtype=np.{storage_dtype}) {storage_name} = np.empty({storage_shape}, dtype=np.{storage_dtype})
""" """
).strip() ).strip()
) )
if needs_alloc: if curr_nit_sot.type.ndim > 0:
allocate_taps_storage.append(f"{outer_in_name}_ready = False") storage_alloc_stmts.append(f"{outer_in_name}_ready = False")
# In this case, we don't know the shape of the output storage # In this case, we don't know the shape of the output storage
# array until we get some output from the inner-function. # array until we get some output from the inner-function.
# With the following we add delayed output storage initialization: # With the following we add delayed output storage initialization:
inner_out_name = inner_output_names[curr_nit_sot_position] inner_out_name = op.inner_nitsot_outs(inner_output_names)[
curr_nit_sot_position
]
inner_out_post_processing_stmts.append( inner_out_post_processing_stmts.append(
dedent( dedent(
f""" f"""
if not {outer_in_name}_ready: if not {outer_in_name}_ready:
{storage_name} = np.empty(({storage_size_name},) + {inner_out_name}.shape, dtype=np.{storage_dtype}) {storage_name} = np.empty(({storage_size_name},) + np.shape({inner_out_name}), dtype=np.{storage_dtype})
{outer_in_name}_ready = True {outer_in_name}_ready = True
""" """
).strip() ).strip()
) )
# The non_seqs are passed to the inner function as-is
for name in outer_in_non_seqs_names: for name in outer_in_non_seqs_names:
inner_in_to_index_offset.append((name, None, None)) add_inner_in_expr(name, None, None)
inner_out_storage_indexed = [
name if taps is None else idx_to_str(name, taps, size=size)
for (name, taps, size) in inner_out_name_to_taps_storage
]
output_storage_post_processing_stmts: List[str] = []
for outer_in_name, grp_vals in groupby(
inner_out_name_to_taps_storage, lambda x: x[0]
):
_, tap_sizes, storage_sizes = zip(*grp_vals)
tap_size = max(tap_sizes)
storage_size = storage_sizes[0]
if op.info.as_while:
# While loops need to truncate the output storage to a length given
# by the number of iterations performed.
output_storage_post_processing_stmts.append(
dedent(
f"""
if i + {tap_size} < {storage_size}:
{storage_size} = i + {tap_size}
{outer_in_name} = {outer_in_name}[:{storage_size}]
"""
).strip()
)
# Rotate the storage so that the last computed value is at the end of
# the storage array.
# This is needed when the output storage array does not have a length
# equal to the number of taps plus `n_steps`.
output_storage_post_processing_stmts.append(
dedent(
f"""
{outer_in_name}_shift = (i + {tap_size}) % ({storage_size})
if {outer_in_name}_shift > 0:
{outer_in_name}_left = {outer_in_name}[:{outer_in_name}_shift]
{outer_in_name}_right = {outer_in_name}[{outer_in_name}_shift:]
{outer_in_name} = np.concatenate(({outer_in_name}_right, {outer_in_name}_left))
"""
).strip()
)
if op.info.as_while: if op.info.as_while:
# The inner function will return a boolean as the last value # The inner function will return a boolean as the last value
inner_out_storage_indexed.append("cond") inner_out_to_outer_in_stmts.append("cond")
output_names = [outer_in_to_storage_name[n] for n in outer_in_outtap_names] assert len(inner_in_exprs) == len(op.fgraph.inputs)
# Construct the inner-input expressions inner_in_args = create_arg_string(inner_in_exprs)
inner_inputs: List[str] = []
for outer_in_name, tap_offset, size in inner_in_to_index_offset:
storage_name = outer_in_to_storage_name.get(outer_in_name, outer_in_name)
indexed_inner_in_str = (
idx_to_str(storage_name, tap_offset, size=size)
if tap_offset is not None
else storage_name
)
# if outer_in_names_to_vars[outer_in_name].type.ndim - 1 <= 0:
# # Convert scalar inner-inputs to Numba scalars
# indexed_inner_in_str = f"to_numba_scalar({indexed_inner_in_str})"
inner_inputs.append(indexed_inner_in_str)
inner_inputs = create_arg_string(inner_inputs)
inner_outputs = create_tuple_string(inner_output_names) inner_outputs = create_tuple_string(inner_output_names)
input_storage_block = "\n".join(allocate_taps_storage) input_storage_block = "\n".join(storage_alloc_stmts)
output_storage_post_processing_block = "\n".join( output_storage_post_processing_block = "\n".join(output_storage_post_proc_stmts)
output_storage_post_processing_stmts
)
inner_out_post_processing_block = "\n".join(inner_out_post_processing_stmts) inner_out_post_processing_block = "\n".join(inner_out_post_processing_stmts)
inner_out_to_outer_out_stmts = "\n".join(
[f"{s} = {d}" for s, d in zip(inner_out_to_outer_in_stmts, inner_output_names)]
)
scan_op_src = f""" scan_op_src = f"""
def scan({", ".join(outer_in_names)}): def scan({", ".join(outer_in_names)}):
...@@ -292,14 +332,14 @@ def scan({", ".join(outer_in_names)}): ...@@ -292,14 +332,14 @@ def scan({", ".join(outer_in_names)}):
i = 0 i = 0
cond = False cond = False
while i < n_steps and not cond: while i < n_steps and not cond:
{inner_outputs} = scan_inner_func({inner_inputs}) {inner_outputs} = scan_inner_func({inner_in_args})
{indent(inner_out_post_processing_block, " " * 8)} {indent(inner_out_post_processing_block, " " * 8)}
{create_tuple_string(inner_out_storage_indexed)} = {inner_outputs} {indent(inner_out_to_outer_out_stmts, " " * 8)}
i += 1 i += 1
{indent(output_storage_post_processing_block, " " * 4)} {indent(output_storage_post_processing_block, " " * 4)}
return {create_arg_string(output_names)} return {create_arg_string(outer_output_names)}
""" """
global_env = { global_env = {
......
...@@ -554,7 +554,7 @@ def test_DirichletRV(a, size, cm): ...@@ -554,7 +554,7 @@ def test_DirichletRV(a, size, cm):
a_val = a.tag.test_value a_val = a.tag.test_value
# For coverage purposes only... # For coverage purposes only...
eval_python_only([a], FunctionGraph(outputs=[g], clone=False), [a_val]) eval_python_only([a], [g], [a_val])
all_samples = [] all_samples = []
for i in range(1000): for i in range(1000):
......
...@@ -2,15 +2,160 @@ import numpy as np ...@@ -2,15 +2,160 @@ import numpy as np
import pytest import pytest
import aesara.tensor as at import aesara.tensor as at
from aesara import config, grad from aesara import config, function, grad
from aesara.compile.mode import Mode, get_mode from aesara.compile.mode import Mode, get_mode
from aesara.graph.fg import FunctionGraph from aesara.graph.fg import FunctionGraph
from aesara.scan.basic import scan from aesara.scan.basic import scan
from aesara.scan.op import Scan
from aesara.scan.utils import until from aesara.scan.utils import until
from aesara.tensor.random.utils import RandomStream
from tests import unittest_tools as utt from tests import unittest_tools as utt
from tests.link.numba.test_basic import compare_numba_and_py from tests.link.numba.test_basic import compare_numba_and_py
@pytest.mark.parametrize(
    "fn, sequences, outputs_info, non_sequences, n_steps, input_vals, output_vals, op_check",
    [
        # sequences
        (
            lambda a_t: 2 * a_t,
            [at.dvector("a")],
            [{}],
            [],
            None,
            [np.arange(10)],
            None,
            lambda op: op.info.n_seqs > 0,
        ),
        # nit-sot
        (
            lambda: at.as_tensor(2.0),
            [],
            [{}],
            [],
            3,
            [],
            None,
            lambda op: op.info.n_nit_sot > 0,
        ),
        # nit-sot, non_seq
        (
            lambda c: at.as_tensor(2.0) * c,
            [],
            [{}],
            [at.dscalar("c")],
            3,
            [1.0],
            None,
            lambda op: op.info.n_nit_sot > 0 and op.info.n_non_seqs > 0,
        ),
        # sit-sot
        (
            lambda a_tm1: 2 * a_tm1,
            [],
            [{"initial": at.as_tensor(0.0, dtype="floatX"), "taps": [-1]}],
            [],
            3,
            [],
            None,
            lambda op: op.info.n_sit_sot > 0,
        ),
        # sit-sot, while
        (
            lambda a_tm1: (a_tm1 + 1, until(a_tm1 > 2)),
            [],
            [{"initial": at.as_tensor(1, dtype=np.int64), "taps": [-1]}],
            [],
            3,
            [],
            None,
            lambda op: op.info.n_sit_sot > 0,
        ),
        # nit-sot, shared input/output
        (
            lambda: RandomStream(seed=1930, rng_ctor=np.random.RandomState).normal(
                0, 1, name="a"
            ),
            [],
            [{}],
            [],
            3,
            [],
            [np.array([-1.63408257, 0.18046406, 2.43265803])],
            lambda op: op.info.n_shared_outs > 0,
        ),
        # mit-sot (a single tap that isn't `-1` is classified as a mit-sot)
        (
            lambda a_tm1: 2 * a_tm1,
            [],
            [{"initial": at.as_tensor([0.0, 1.0], dtype="floatX"), "taps": [-2]}],
            [],
            6,
            [],
            None,
            lambda op: op.info.n_mit_sot > 0,
        ),
        # multiple sit-sots (each output has the single tap `-1`, so neither
        # of them is a mit-sot)
        (
            lambda a_tm1, b_tm1: (2 * a_tm1, 2 * b_tm1),
            [],
            [
                {"initial": at.as_tensor(0.0, dtype="floatX"), "taps": [-1]},
                {"initial": at.as_tensor(0.0, dtype="floatX"), "taps": [-1]},
            ],
            [],
            10,
            [],
            None,
            lambda op: op.info.n_sit_sot > 1,
        ),
    ],
)
def test_xit_xot_types(
    fn,
    sequences,
    outputs_info,
    non_sequences,
    n_steps,
    input_vals,
    output_vals,
    op_check,
):
    """Test basic xit-xot configurations.

    Parameters
    ----------
    fn
        The inner (step) function handed to `scan`.
    sequences, outputs_info, non_sequences, n_steps
        Forwarded unchanged to the corresponding `scan` arguments.
    input_vals
        Values for the graph inputs ``sequences + non_sequences``.
    output_vals
        Expected outputs.  When ``None``, the Numba results are compared
        against the Python backend through `compare_numba_and_py`; otherwise
        the graph is compiled in ``"NUMBA"`` mode only and compared with these
        values.
    op_check
        Predicate on the constructed `Scan` `Op` confirming that the case
        produced the kind of input/output it is meant to exercise.

    """
    res, updates = scan(
        fn,
        sequences=sequences,
        outputs_info=outputs_info,
        non_sequences=non_sequences,
        n_steps=n_steps,
        strict=True,
        mode=Mode(linker="py", optimizer=None),
    )

    if not isinstance(res, list):
        res = [res]

    # Get rid of any `Subtensor` indexing on the `Scan` outputs
    res = [r.owner.inputs[0] if not isinstance(r.owner.op, Scan) else r for r in res]

    scan_op = res[0].owner.op
    assert isinstance(scan_op, Scan)

    # The predicate's result has to be asserted; merely calling it would let a
    # case that silently exercises the wrong `Scan` configuration pass.
    assert op_check(scan_op)

    if output_vals is None:
        compare_numba_and_py(
            (sequences + non_sequences, res), input_vals, updates=updates
        )
    else:
        numba_mode = get_mode("NUMBA")
        numba_fn = function(
            sequences + non_sequences, res, mode=numba_mode, updates=updates
        )
        res_val = numba_fn(*input_vals)
        assert np.allclose(res_val, output_vals)
def test_scan_multiple_output(): def test_scan_multiple_output():
"""Test a scan implementation of a SEIR model. """Test a scan implementation of a SEIR model.
...@@ -202,34 +347,10 @@ def test_scan_multiple_none_output(): ...@@ -202,34 +347,10 @@ def test_scan_multiple_none_output():
compare_numba_and_py(out_fg, test_input_vals) compare_numba_and_py(out_fg, test_input_vals)
def test_scan_save_mem_basic(): @pytest.mark.parametrize("n_steps_val", [1, 5])
def test_scan_save_mem_basic(n_steps_val):
"""Make sure we can handle storage changes caused by the `scan_save_mem` rewrite.""" """Make sure we can handle storage changes caused by the `scan_save_mem` rewrite."""
k = at.iscalar("k")
A = at.dvector("A")
result, _ = scan(
fn=lambda prior_result, A: prior_result * A,
outputs_info=at.ones_like(A),
non_sequences=A,
n_steps=k,
)
numba_mode = get_mode("NUMBA") # .including("scan_save_mem")
py_mode = Mode("py").including("scan_save_mem")
out_fg = FunctionGraph([A, k], [result])
test_input_vals = (np.arange(10, dtype=np.int32), 2)
compare_numba_and_py(
out_fg, test_input_vals, numba_mode=numba_mode, py_mode=py_mode
)
test_input_vals = (np.arange(10, dtype=np.int32), 4)
compare_numba_and_py(
out_fg, test_input_vals, numba_mode=numba_mode, py_mode=py_mode
)
@pytest.mark.parametrize("n_steps_val", [1, 5])
def test_scan_save_mem_2(n_steps_val):
def f_pow2(x_tm2, x_tm1): def f_pow2(x_tm2, x_tm1):
return 2 * x_tm1 + x_tm2 return 2 * x_tm1 + x_tm2
...@@ -245,7 +366,7 @@ def test_scan_save_mem_2(n_steps_val): ...@@ -245,7 +366,7 @@ def test_scan_save_mem_2(n_steps_val):
state_val = np.array([1.0, 2.0]) state_val = np.array([1.0, 2.0])
numba_mode = get_mode("NUMBA") # .including("scan_save_mem") numba_mode = get_mode("NUMBA").including("scan_save_mem")
py_mode = Mode("py").including("scan_save_mem") py_mode = Mode("py").including("scan_save_mem")
out_fg = FunctionGraph([init_x, n_steps], [output]) out_fg = FunctionGraph([init_x, n_steps], [output])
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论