Commit 739bd49f authored by Brandon T. Willard, committed by Brandon T. Willard

Add support for shared inputs in numba_funcify_Scan

Parent 9ae884dd
from textwrap import dedent, indent
from typing import Dict, List, Optional, Tuple
@@ -14,6 +13,7 @@ from aesara.link.numba.dispatch.basic import (
)
from aesara.link.utils import compile_function_src
from aesara.scan.op import Scan
from aesara.tensor.type import TensorType
def idx_to_str(
@@ -49,8 +49,6 @@ def array0d_range(x):
def numba_funcify_Scan(op, node, **kwargs):
    scan_inner_func = numba_basic.numba_njit(numba_funcify(op.fgraph))

    outer_in_names_to_vars = {
        (f"outer_in_{i}" if i > 0 else "n_steps"): v for i, v in enumerate(node.inputs)
    }
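    # Illustratively, for a `Scan` node whose outer inputs are
    # `[n_steps, seq, x0, A]` (hypothetical variables), this yields
    # `{"n_steps": n_steps, "outer_in_1": seq, "outer_in_2": x0, "outer_in_3": A}`.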
@@ -60,22 +58,63 @@ def numba_funcify_Scan(op, node, **kwargs):
    outer_in_mit_sot_names = op.outer_mitsot(outer_in_names)
    outer_in_sit_sot_names = op.outer_sitsot(outer_in_names)
    outer_in_nit_sot_names = op.outer_nitsot(outer_in_names)
    outer_in_shared_names = op.outer_shared(outer_in_names)
    outer_in_non_seqs_names = op.outer_non_seqs(outer_in_names)

    # These are all the outer-input names that produce outputs/have output
    # taps (i.e. they have inner-outputs and corresponding outer-outputs).
    #
    # Outer-outputs are ordered as follows:
    # mit-mot-outputs + mit-sot-outputs + sit-sot-outputs + nit-sots + shared-outputs
    outer_in_outtap_names = (
        outer_in_mit_mot_names
        + outer_in_mit_sot_names
        + outer_in_sit_sot_names
        + outer_in_nit_sot_names
        + outer_in_shared_names
    )
    # We create distinct variables for/references to the storage arrays for
    # each output.
    outer_in_to_storage_name: Dict[str, str] = {}
    for outer_in_name in outer_in_mit_mot_names:
        outer_in_to_storage_name[outer_in_name] = f"{outer_in_name}_mitmot_storage"

    for outer_in_name in outer_in_mit_sot_names:
        outer_in_to_storage_name[outer_in_name] = f"{outer_in_name}_mitsot_storage"

    # Note that outputs with a single non-`-1` tap (e.g. `taps = [-2]`) are
    # classified as mit-sot, so the sit-sot handling here and below only ever
    # deals with a `-1` tap.
    for outer_in_name in outer_in_sit_sot_names:
        outer_in_to_storage_name[outer_in_name] = f"{outer_in_name}_sitsot_storage"

    for outer_in_name in outer_in_nit_sot_names:
        outer_in_to_storage_name[outer_in_name] = f"{outer_in_name}_nitsot_storage"

    for outer_in_name in outer_in_shared_names:
        outer_in_to_storage_name[outer_in_name] = f"{outer_in_name}_shared_storage"

    outer_output_names = list(outer_in_to_storage_name.values())
    assert len(outer_output_names) == len(node.outputs)
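    # For example, for a hypothetical `Scan` with one sit-sot and one shared
    # output, this would give
    # `outer_output_names == ["outer_in_2_sitsot_storage", "outer_in_3_shared_storage"]`,
    # matching `node.outputs` one-to-one.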
    # Construct the inner-input expressions (e.g. indexed storage expressions)
    # Inner-inputs are ordered as follows:
    # sequences + mit-mot-inputs + mit-sot-inputs + sit-sot-inputs +
    # shared-inputs + non-sequences.
    inner_in_exprs: List[str] = []

    def add_inner_in_expr(
        outer_in_name: str, tap_offset: Optional[int], storage_size_var: Optional[str]
    ):
        """Construct an inner-input expression."""
        storage_name = outer_in_to_storage_name.get(outer_in_name, outer_in_name)
        indexed_inner_in_str = (
            storage_name
            if tap_offset is None
            else idx_to_str(storage_name, tap_offset, size=storage_size_var)
        )
        inner_in_exprs.append(indexed_inner_in_str)
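    # Illustrative results (assuming `idx_to_str` renders a modular index into
    # the storage array):
    #   add_inner_in_expr("outer_in_1", 0, None)  ->  "outer_in_1[i]"
    #   add_inner_in_expr("outer_in_2", 1, "outer_in_2_len")
    #       ->  "outer_in_2_sitsot_storage[(i + 1) % outer_in_2_len]"
    #   add_inner_in_expr("outer_in_3", None, None)  ->  "outer_in_3_shared_storage"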
    for outer_in_name in outer_in_seqs_names:
        # A sequence with multiple taps is provided as multiple modified input
        # sequences, all sliced so that they follow the logic of an ordinary
        # sequence; these outer-inputs are therefore indexed without offsets or
        # storage wrap-around.
        add_inner_in_expr(outer_in_name, 0, None)
    inner_in_names_to_input_taps: Dict[str, Tuple[int]] = dict(
        zip(
@@ -89,201 +128,202 @@ def numba_funcify_Scan(op, node, **kwargs):
    inner_in_names_to_output_taps: Dict[str, Tuple[int]] = dict(
        zip(outer_in_mit_mot_names, op.info.mit_mot_out_slices)
    )
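    # Illustratively (hypothetical names), a mit-sot input `outer_in_2` with
    # `taps=[-2, -1]` contributes `{"outer_in_2": (-2, -1)}` to
    # `inner_in_names_to_input_taps`; only mit-mots have entries in
    # `inner_in_names_to_output_taps`.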
    # Inner-outputs consist of:
    # mit-mot-outputs + mit-sot-outputs + sit-sot-outputs + nit-sots +
    # shared-outputs [+ while-condition]
    inner_output_names = [f"inner_out_{i}" for i in range(len(op.inner_outputs))]

    outer_in_sot_names = set(
        outer_in_mit_mot_names + outer_in_mit_sot_names + outer_in_sit_sot_names
    )

    # inner_out_shared_names = op.inner_shared_outs(inner_output_names)

    # The assignment statements that copy inner-outputs into the outer-outputs
    # storage
    inner_out_to_outer_in_stmts: List[str] = []
    # Special statements that perform storage truncation for `while`-loops and
    # rotation for initially truncated storage.
    output_storage_post_proc_stmts: List[str] = []

    # In truncated storage situations (e.g. those created by
    # `save_mem_new_scan`), the taps and the output storage overlap, instead of
    # the standard situation in which the output storage is large enough to
    # contain both the initial tap values and all of the computed outputs.  In
    # this truncated case, we use the storage array like a circular buffer, and
    # that's why we need to track the storage size along with the tap
    # length/indexing offset.
    def add_output_storage_post_proc_stmt(
        outer_in_name: str, tap_sizes: Tuple[int], storage_size: str
    ):
        tap_size = max(tap_sizes)

        if op.info.as_while:
            # While-loops need to truncate the output storage to a length given
            # by the number of iterations performed.
            output_storage_post_proc_stmts.append(
                dedent(
                    f"""
                    if i + {tap_size} < {storage_size}:
                        {storage_size} = i + {tap_size}
                        {outer_in_name} = {outer_in_name}[:{storage_size}]
                    """
                ).strip()
            )

        # Rotate the storage so that the last computed value is at the end of
        # the storage array.
        # This is needed when the output storage array does not have a length
        # equal to the number of taps plus `n_steps`.
        output_storage_post_proc_stmts.append(
            dedent(
                f"""
                {outer_in_name}_shift = (i + {tap_size}) % ({storage_size})
                if {outer_in_name}_shift > 0:
                    {outer_in_name}_left = {outer_in_name}[:{outer_in_name}_shift]
                    {outer_in_name}_right = {outer_in_name}[{outer_in_name}_shift:]
                    {outer_in_name} = np.concatenate(({outer_in_name}_right, {outer_in_name}_left))
                """
            ).strip()
        )
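    # A worked example of the rotation (illustrative values): for a sit-sot
    # with tap -1 (`tap_size = 1`), a truncated storage of size 4, and
    # `n_steps = 6`, the loop exits with `i = 6` and the buffer holds
    # `[y3, y4, y5, y2]`.  Then `shift = (6 + 1) % 4 = 3`, so
    # `np.concatenate((storage[3:], storage[:3]))` restores the time order
    # `[y2, y3, y4, y5]`.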
    # Special in-loop statements that create (nit-sot) storage arrays after a
    # single iteration is performed.  This is necessary because we don't know
    # the exact shapes of the storage arrays that need to be allocated until
    # after an iteration is performed.
    inner_out_post_processing_stmts: List[str] = []

    # Storage allocation statements.
    # For output storage allocated/provided by the inputs, these statements
    # will either construct aliases between the input names and the entries in
    # `outer_in_to_storage_name` or assign the latter to expressions that
    # create copies of those storage inputs.
    # In the nit-sot case, empty dummy arrays are assigned to the storage
    # variables and updated later by the statements in
    # `inner_out_post_processing_stmts`.
    storage_alloc_stmts: List[str] = []
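    # Illustratively, for a sit-sot outer-input `outer_in_2` (a hypothetical
    # name) that is not destroyed in-place, the generated allocation block
    # would read roughly:
    #
    #   # <outer_in_2's type>
    #   outer_in_2_len = outer_in_2.shape[0]
    #   outer_in_2_sitsot_storage = np.copy(outer_in_2)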
    for outer_in_name in outer_in_outtap_names:
        outer_in_var = outer_in_names_to_vars[outer_in_name]

        if outer_in_name not in outer_in_nit_sot_names:
            storage_name = outer_in_to_storage_name[outer_in_name]

            is_tensor_type = isinstance(outer_in_var.type, TensorType)

            if is_tensor_type:
                storage_size_name = f"{outer_in_name}_len"
                storage_size_stmt = f"{storage_size_name} = {outer_in_name}.shape[0]"

                input_taps = inner_in_names_to_input_taps[outer_in_name]
                tap_storage_size = -min(input_taps)
                assert tap_storage_size >= 0

                for in_tap in input_taps:
                    tap_offset = in_tap + tap_storage_size
                    assert tap_offset >= 0
                    add_inner_in_expr(outer_in_name, tap_offset, storage_size_name)

                output_taps = inner_in_names_to_output_taps.get(
                    outer_in_name, [tap_storage_size]
                )
                for out_tap in output_taps:
                    inner_out_to_outer_in_stmts.append(
                        idx_to_str(storage_name, out_tap, size=storage_size_name)
                    )

                add_output_storage_post_proc_stmt(
                    storage_name, output_taps, storage_size_name
                )
            else:
                # Non-tensor (e.g. shared RNG) inputs are passed through
                # without tap indexing or storage-size tracking.
                storage_size_stmt = ""
                add_inner_in_expr(outer_in_name, None, None)
                inner_out_to_outer_in_stmts.append(storage_name)

            output_idx = outer_output_names.index(storage_name)
            if output_idx in node.op.destroy_map or not is_tensor_type:
                storage_alloc_stmt = f"{storage_name} = {outer_in_name}"
            else:
                storage_alloc_stmt = f"{storage_name} = np.copy({outer_in_name})"

            storage_alloc_stmt = dedent(
                f"""
                # {outer_in_var.type}
                {storage_size_stmt}
                {storage_alloc_stmt}
                """
            ).strip()

            storage_alloc_stmts.append(storage_alloc_stmt)

        elif outer_in_name in outer_in_nit_sot_names:
            # This is a special case in which there are no outer-inputs used
            # for outer-output storage, so we need to create our own storage
            # from scratch.
            storage_name = outer_in_to_storage_name[outer_in_name]
            storage_size_name = f"{outer_in_name}_len"

            inner_out_to_outer_in_stmts.append(
                idx_to_str(storage_name, 0, size=storage_size_name)
            )
            add_output_storage_post_proc_stmt(storage_name, (0,), storage_size_name)

            # In the case of nit-sots, we are provided the length of the array
            # in the iteration dimension instead of actual arrays, so we
            # allocate space for the results accordingly.
            curr_nit_sot_position = outer_in_nit_sot_names.index(outer_in_name)
            curr_nit_sot = op.inner_nitsot_outs(op.inner_outputs)[curr_nit_sot_position]

            storage_shape = create_tuple_string(
                [storage_size_name] + ["0"] * curr_nit_sot.ndim
            )
            storage_dtype = curr_nit_sot.type.numpy_dtype.name

            storage_alloc_stmts.append(
                dedent(
                    f"""
                    # {curr_nit_sot.type}
                    {storage_size_name} = to_numba_scalar({outer_in_name})
                    {storage_name} = np.empty({storage_shape}, dtype=np.{storage_dtype})
                    """
                ).strip()
            )

            if curr_nit_sot.type.ndim > 0:
                storage_alloc_stmts.append(f"{outer_in_name}_ready = False")

                # In this case, we don't know the shape of the output storage
                # array until we get some output from the inner-function, so we
                # add delayed output storage initialization:
                inner_out_name = op.inner_nitsot_outs(inner_output_names)[
                    curr_nit_sot_position
                ]
                inner_out_post_processing_stmts.append(
                    dedent(
                        f"""
                        if not {outer_in_name}_ready:
                            {storage_name} = np.empty(({storage_size_name},) + np.shape({inner_out_name}), dtype=np.{storage_dtype})
                            {outer_in_name}_ready = True
                        """
                    ).strip()
                )
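            # Illustratively, for a hypothetical vector-valued nit-sot
            # `outer_in_4`, the generated statements would read roughly:
            #
            #   outer_in_4_len = to_numba_scalar(outer_in_4)
            #   outer_in_4_nitsot_storage = np.empty((outer_in_4_len, 0), dtype=np.float64)
            #   outer_in_4_ready = False
            #
            # and, inside the generated loop:
            #
            #   if not outer_in_4_ready:
            #       outer_in_4_nitsot_storage = np.empty((outer_in_4_len,) + np.shape(inner_out_2), dtype=np.float64)
            #       outer_in_4_ready = True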
    # The non_seqs are passed to the inner function as-is
    for name in outer_in_non_seqs_names:
        add_inner_in_expr(name, None, None)
    if op.info.as_while:
        # The inner function will return a boolean as the last value
        inner_out_to_outer_in_stmts.append("cond")

    assert len(inner_in_exprs) == len(op.fgraph.inputs)

    inner_in_args = create_arg_string(inner_in_exprs)
    inner_outputs = create_tuple_string(inner_output_names)

    input_storage_block = "\n".join(storage_alloc_stmts)
    output_storage_post_processing_block = "\n".join(output_storage_post_proc_stmts)
    inner_out_post_processing_block = "\n".join(inner_out_post_processing_stmts)

    inner_out_to_outer_out_stmts = "\n".join(
        [f"{s} = {d}" for s, d in zip(inner_out_to_outer_in_stmts, inner_output_names)]
    )

    scan_op_src = f"""
def scan({", ".join(outer_in_names)}):
@@ -292,14 +332,14 @@ def scan({", ".join(outer_in_names)}):
    i = 0
    cond = False

    while i < n_steps and not cond:
        {inner_outputs} = scan_inner_func({inner_in_args})

{indent(inner_out_post_processing_block, " " * 8)}

{indent(inner_out_to_outer_out_stmts, " " * 8)}

        i += 1

{indent(output_storage_post_processing_block, " " * 4)}

    return {create_arg_string(outer_output_names)}
"""
global_env = {
......
@@ -554,7 +554,7 @@ def test_DirichletRV(a, size, cm):
    a_val = a.tag.test_value

    # For coverage purposes only...
    eval_python_only([a], [g], [a_val])

    all_samples = []
    for i in range(1000):
......
@@ -2,15 +2,160 @@ import numpy as np
import pytest
import aesara.tensor as at
from aesara import config, function, grad
from aesara.compile.mode import Mode, get_mode
from aesara.graph.fg import FunctionGraph
from aesara.scan.basic import scan
from aesara.scan.op import Scan
from aesara.scan.utils import until
from aesara.tensor.random.utils import RandomStream
from tests import unittest_tools as utt
from tests.link.numba.test_basic import compare_numba_and_py
@pytest.mark.parametrize(
    "fn, sequences, outputs_info, non_sequences, n_steps, input_vals, output_vals, op_check",
    [
        # sequences
        (
            lambda a_t: 2 * a_t,
            [at.dvector("a")],
            [{}],
            [],
            None,
            [np.arange(10)],
            None,
            lambda op: op.info.n_seqs > 0,
        ),
        # nit-sot
        (
            lambda: at.as_tensor(2.0),
            [],
            [{}],
            [],
            3,
            [],
            None,
            lambda op: op.info.n_nit_sot > 0,
        ),
        # nit-sot, non_seq
        (
            lambda c: at.as_tensor(2.0) * c,
            [],
            [{}],
            [at.dscalar("c")],
            3,
            [1.0],
            None,
            lambda op: op.info.n_nit_sot > 0 and op.info.n_non_seqs > 0,
        ),
        # sit-sot
        (
            lambda a_tm1: 2 * a_tm1,
            [],
            [{"initial": at.as_tensor(0.0, dtype="floatX"), "taps": [-1]}],
            [],
            3,
            [],
            None,
            lambda op: op.info.n_sit_sot > 0,
        ),
        # sit-sot, while
        (
            lambda a_tm1: (a_tm1 + 1, until(a_tm1 > 2)),
            [],
            [{"initial": at.as_tensor(1, dtype=np.int64), "taps": [-1]}],
            [],
            3,
            [],
            None,
            lambda op: op.info.n_sit_sot > 0,
        ),
        # nit-sot, shared input/output
        (
            lambda: RandomStream(seed=1930, rng_ctor=np.random.RandomState).normal(
                0, 1, name="a"
            ),
            [],
            [{}],
            [],
            3,
            [],
            [np.array([-1.63408257, 0.18046406, 2.43265803])],
            lambda op: op.info.n_shared_outs > 0,
        ),
        # mit-sot (that's also a type of sit-sot)
        (
            lambda a_tm1: 2 * a_tm1,
            [],
            [{"initial": at.as_tensor([0.0, 1.0], dtype="floatX"), "taps": [-2]}],
            [],
            6,
            [],
            None,
            lambda op: op.info.n_mit_sot > 0,
        ),
        # mit-sot
        (
            lambda a_tm1, b_tm1: (2 * a_tm1, 2 * b_tm1),
            [],
            [
                {"initial": at.as_tensor(0.0, dtype="floatX"), "taps": [-1]},
                {"initial": at.as_tensor(0.0, dtype="floatX"), "taps": [-1]},
            ],
            [],
            10,
            [],
            None,
            lambda op: op.info.n_mit_sot > 0,
        ),
    ],
)
def test_xit_xot_types(
    fn,
    sequences,
    outputs_info,
    non_sequences,
    n_steps,
    input_vals,
    output_vals,
    op_check,
):
    """Test basic xit-xot configurations."""
    res, updates = scan(
        fn,
        sequences=sequences,
        outputs_info=outputs_info,
        non_sequences=non_sequences,
        n_steps=n_steps,
        strict=True,
        mode=Mode(linker="py", optimizer=None),
    )

    if not isinstance(res, list):
        res = [res]

    # Get rid of any `Subtensor` indexing on the `Scan` outputs
    res = [r.owner.inputs[0] if not isinstance(r.owner.op, Scan) else r for r in res]

    scan_op = res[0].owner.op
    assert isinstance(scan_op, Scan)

    _ = op_check(scan_op)

    if output_vals is None:
        compare_numba_and_py(
            (sequences + non_sequences, res), input_vals, updates=updates
        )
    else:
        numba_mode = get_mode("NUMBA")
        numba_fn = function(
            sequences + non_sequences, res, mode=numba_mode, updates=updates
        )
        res_val = numba_fn(*input_vals)
        assert np.allclose(res_val, output_vals)
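# As a standalone illustration (not part of the test), the first "sequences"
# case above corresponds to something like:
#
#   a = at.dvector("a")
#   res, _ = scan(lambda a_t: 2 * a_t, sequences=[a])
#   fn = function([a], res, mode=get_mode("NUMBA"))
#   fn(np.arange(3))  # -> array([0., 2., 4.])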
def test_scan_multiple_output():
    """Test a scan implementation of a SEIR model.
@@ -202,34 +347,10 @@ def test_scan_multiple_none_output():
    compare_numba_and_py(out_fg, test_input_vals)
@pytest.mark.parametrize("n_steps_val", [1, 5])
def test_scan_save_mem_basic(n_steps_val):
    """Make sure we can handle storage changes caused by the `scan_save_mem` rewrite."""

    def f_pow2(x_tm2, x_tm1):
        return 2 * x_tm1 + x_tm2
@@ -245,7 +366,7 @@ def test_scan_save_mem_2(n_steps_val):
    state_val = np.array([1.0, 2.0])

    numba_mode = get_mode("NUMBA").including("scan_save_mem")
    py_mode = Mode("py").including("scan_save_mem")

    out_fg = FunctionGraph([init_x, n_steps], [output])
......