提交 3d96ee80 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Fix storage handling in numba_funcify_Scan

上级 a2d05adc
from itertools import groupby
from textwrap import dedent, indent from textwrap import dedent, indent
from typing import Dict, List, Optional, Tuple
import numpy as np import numpy as np
from numba import types from numba import types
from numba.extending import overload from numba.extending import overload
from aesara.graph.fg import FunctionGraph
from aesara.link.numba.dispatch import basic as numba_basic from aesara.link.numba.dispatch import basic as numba_basic
from aesara.link.numba.dispatch.basic import ( from aesara.link.numba.dispatch.basic import (
create_arg_string, create_arg_string,
...@@ -15,13 +16,23 @@ from aesara.link.utils import compile_function_src ...@@ -15,13 +16,23 @@ from aesara.link.utils import compile_function_src
from aesara.scan.op import Scan from aesara.scan.op import Scan
def idx_to_str(idx): def idx_to_str(
res = "[i" array_name: str, offset: int, size: Optional[str] = None, idx_symbol: str = "i"
if idx < 0: ) -> str:
res += str(idx) if offset < 0:
elif idx > 0: indices = f"{idx_symbol} + {array_name}.shape[0] - {offset}"
res += "+" + str(idx) elif offset > 0:
return res + "]" indices = f"{idx_symbol} + {offset}"
else:
indices = idx_symbol
if size:
# TODO FIXME: The `Scan` `Op` should tell us which outputs are computed
# in this way. We shouldn't have to waste run-time efforts in order to
# compensate for this poor `Op`/rewrite design and implementation.
indices = f"({indices}) % {size}"
return f"{array_name}[{indices}]"
@overload(range) @overload(range)
...@@ -36,124 +47,267 @@ def array0d_range(x): ...@@ -36,124 +47,267 @@ def array0d_range(x):
@numba_funcify.register(Scan) @numba_funcify.register(Scan)
def numba_funcify_Scan(op, node, **kwargs): def numba_funcify_Scan(op, node, **kwargs):
inner_fg = FunctionGraph(op.inner_inputs, op.inner_outputs) scan_inner_func = numba_basic.numba_njit(numba_funcify(op.fgraph))
numba_at_inner_func = numba_basic.numba_njit(numba_funcify(inner_fg, **kwargs))
n_seqs = op.info.n_seqs n_seqs = op.info.n_seqs
n_mit_mot = op.info.n_mit_mot
n_mit_sot = op.info.n_mit_sot
n_nit_sot = op.info.n_nit_sot
n_sit_sot = op.info.n_sit_sot
tap_array = op.info.tap_array
n_shared_outs = op.info.n_shared_outs
mit_mot_in_taps = tuple(tap_array[:n_mit_mot])
mit_sot_in_taps = tuple(tap_array[n_mit_mot : n_mit_mot + n_mit_sot])
p_in_mit_mot = n_seqs
p_in_mit_sot = p_in_mit_mot + n_mit_mot
p_in_sit_sot = p_in_mit_sot + n_mit_sot
p_outer_in_shared = p_in_sit_sot + n_sit_sot
p_outer_in_nit_sot = p_outer_in_shared + n_shared_outs
p_outer_in_non_seqs = p_outer_in_nit_sot + n_nit_sot
input_names = [f"outer_in_{i}" for i, n in enumerate(node.inputs[1:])]
outer_in_seqs_names = input_names[:n_seqs]
outer_in_mit_mot_names = input_names[p_in_mit_mot : p_in_mit_mot + n_mit_mot]
outer_in_mit_sot_names = input_names[p_in_mit_sot : p_in_mit_sot + n_mit_sot]
outer_in_sit_sot_names = input_names[p_in_sit_sot : p_in_sit_sot + n_sit_sot]
outer_in_shared_names = input_names[
p_outer_in_shared : p_outer_in_shared + n_shared_outs
]
outer_in_nit_sot_names = input_names[
p_outer_in_nit_sot : p_outer_in_nit_sot + n_nit_sot
]
outer_in_feedback_names = input_names[n_seqs:p_outer_in_non_seqs]
outer_in_non_seqs_names = input_names[p_outer_in_non_seqs:]
inner_in_indexed = [] outer_in_names_to_vars = {
allocate_mem_to_nit_sot = "" (f"outer_in_{i}" if i > 0 else "n_steps"): v for i, v in enumerate(node.inputs)
}
outer_in_names = list(outer_in_names_to_vars.keys())
outer_in_seqs_names = op.outer_seqs(outer_in_names)
outer_in_mit_mot_names = op.outer_mitmot(outer_in_names)
outer_in_mit_sot_names = op.outer_mitsot(outer_in_names)
outer_in_sit_sot_names = op.outer_sitsot(outer_in_names)
outer_in_nit_sot_names = op.outer_nitsot(outer_in_names)
outer_in_outtap_names = (
outer_in_mit_mot_names
+ outer_in_mit_sot_names
+ outer_in_sit_sot_names
+ outer_in_nit_sot_names
)
outer_in_non_seqs_names = op.outer_non_seqs(outer_in_names)
inner_in_to_index_offset: List[Tuple[str, Optional[int], Optional[int]]] = []
allocate_taps_storage: List[str] = []
for _name in outer_in_seqs_names: for outer_in_name in outer_in_seqs_names:
# A sequence with multiple taps is provided as multiple modified input # A sequence with multiple taps is provided as multiple modified input
# sequences--all sliced so as to keep following the logic of a normal # sequences--all sliced so as to keep following the logic of a normal
# sequence. # sequence.
inner_in_indexed.append(f"{_name}[i]") inner_in_to_index_offset.append((outer_in_name, 0, None))
name_to_input_map = dict(zip(input_names, node.inputs[1:])) inner_in_names_to_input_taps: Dict[str, Tuple[int]] = dict(
mit_sot_name_to_taps = dict(zip(outer_in_mit_sot_names, mit_sot_in_taps)) zip(
inner_out_name_to_index = {} outer_in_mit_mot_names + outer_in_mit_sot_names + outer_in_sit_sot_names,
for _name in outer_in_feedback_names: op.info.mit_mot_in_slices
if _name in outer_in_mit_sot_names: + op.info.mit_sot_in_slices
curr_taps = mit_sot_name_to_taps[_name] + op.info.sit_sot_in_slices,
min_tap = min(curr_taps) )
)
inner_in_names_to_output_taps: Dict[str, Optional[Tuple[int, ...]]] = dict(
zip(outer_in_mit_mot_names, op.info.mit_mot_out_slices)
)
for _tap in curr_taps: inner_output_names = [f"inner_out_{i}" for i in range(len(op.inner_outputs))]
index = idx_to_str(_tap - min_tap)
inner_in_indexed.append(f"{_name}{index}")
inner_out_name_to_index[_name] = -min_tap # Maps storage array names to their tap values (i.e. maximum absolute tap
# value) and storage sizes
inner_out_name_to_taps_storage: List[Tuple[str, int, Optional[str]]] = []
outer_in_to_storage_name: Dict[str, str] = {}
outer_in_sot_names = set(
outer_in_mit_mot_names + outer_in_mit_sot_names + outer_in_sit_sot_names
)
inner_out_post_processing_stmts: List[str] = []
for outer_in_name in outer_in_outtap_names:
outer_in_var = outer_in_names_to_vars[outer_in_name]
if _name in outer_in_sit_sot_names: if outer_in_name in outer_in_sot_names:
if outer_in_name in outer_in_mit_mot_names:
storage_name = f"{outer_in_name}_mitmot_storage"
elif outer_in_name in outer_in_mit_sot_names:
storage_name = f"{outer_in_name}_mitsot_storage"
else:
# Note that the outputs with single, non-`-1` taps (e.g. `taps # Note that the outputs with single, non-`-1` taps (e.g. `taps
# = [-2]`) are classified as mit-sot, so the code for handling # = [-2]`) are classified as mit-sot, so the code for handling
# sit-sots remains constant as follows # sit-sots remains constant as follows
inner_in_indexed.append(f"{_name}[i]") storage_name = f"{outer_in_name}_sitsot_storage"
inner_out_name_to_index[_name] = 1
output_idx = len(outer_in_to_storage_name)
if _name in outer_in_nit_sot_names: outer_in_to_storage_name[outer_in_name] = storage_name
output_name = f"{_name}_nitsot_storage"
inner_out_name_to_index[output_name] = 0 input_taps = inner_in_names_to_input_taps[outer_in_name]
# In case of nit-sots we are provided the shape of the array tap_storage_size = -min(input_taps)
# instead of actual arrays (like other cases), hence we allocate assert tap_storage_size >= 0
# space for the results accordingly.
curr_nit_sot_position = input_names.index(_name) - n_seqs storage_size_name = f"{outer_in_name}_len"
curr_nit_sot = inner_fg.outputs[curr_nit_sot_position]
mem_shape = ["1"] * curr_nit_sot.ndim for in_tap in input_taps:
curr_dtype = curr_nit_sot.type.numpy_dtype.name tap_offset = in_tap + tap_storage_size
allocate_mem_to_nit_sot += dedent( assert tap_offset >= 0
# In truncated storage situations (i.e. created by
# `save_mem_new_scan`), the taps and output storage overlap,
# instead of the standard situation in which the output storage
# is large enough to contain both the initial taps values and
# the output storage.
inner_in_to_index_offset.append(
(outer_in_name, tap_offset, storage_size_name)
)
output_taps = inner_in_names_to_output_taps.get(
outer_in_name, [tap_storage_size]
)
for out_tap in output_taps:
inner_out_name_to_taps_storage.append(
(storage_name, out_tap, storage_size_name)
)
if output_idx in node.op.destroy_map:
storage_alloc_stmt = f"{storage_name} = {outer_in_name}"
else:
storage_alloc_stmt = f"{storage_name} = np.copy({outer_in_name})"
storage_alloc_stmt = dedent(
f""" f"""
{output_name} = [ # {outer_in_var.type}
np.empty(({create_arg_string(mem_shape)},), dtype=np.{curr_dtype}) for i in range({_name}.item()) {storage_size_name} = {outer_in_name}.shape[0]
]""" {storage_alloc_stmt}
"""
).strip()
allocate_taps_storage.append(storage_alloc_stmt)
elif outer_in_name in outer_in_nit_sot_names:
# This is a special case in which there are no outer-inputs used
# for outer-output storage, so we need to create our own storage
# from scratch.
storage_name = f"{outer_in_name}_nitsot_storage"
outer_in_to_storage_name[outer_in_name] = storage_name
storage_size_name = f"{outer_in_name}_len"
inner_out_name_to_taps_storage.append((storage_name, 0, storage_size_name))
# In case of nit-sots we are provided the length of the array in
# the iteration dimension instead of actual arrays, hence we
# allocate space for the results accordingly.
curr_nit_sot_position = outer_in_names[1:].index(outer_in_name) - n_seqs
curr_nit_sot = op.inner_outputs[curr_nit_sot_position]
needs_alloc = curr_nit_sot.ndim > 0
storage_shape = create_tuple_string(
[storage_size_name] + ["0"] * curr_nit_sot.ndim
)
storage_dtype = curr_nit_sot.type.numpy_dtype.name
allocate_taps_storage.append(
dedent(
f"""
# {curr_nit_sot.type}
{storage_size_name} = to_numba_scalar({outer_in_name})
{storage_name} = np.empty({storage_shape}, dtype=np.{storage_dtype})
"""
).strip()
) )
# The non_seqs are passed to inner function as-is if needs_alloc:
inner_in_indexed += outer_in_non_seqs_names allocate_taps_storage.append(f"{outer_in_name}_ready = False")
inner_out_indexed = [
_name + idx_to_str(idx) for _name, idx in inner_out_name_to_index.items() # In this case, we don't know the shape of the output storage
# array until we get some output from the inner-function.
# With the following we add delayed output storage initialization:
inner_out_name = inner_output_names[curr_nit_sot_position]
inner_out_post_processing_stmts.append(
dedent(
f"""
if not {outer_in_name}_ready:
{storage_name} = np.empty(({storage_size_name},) + {inner_out_name}.shape, dtype=np.{storage_dtype})
{outer_in_name}_ready = True
"""
).strip()
)
# The non_seqs are passed to the inner function as-is
for name in outer_in_non_seqs_names:
inner_in_to_index_offset.append((name, None, None))
inner_out_storage_indexed = [
name if taps is None else idx_to_str(name, taps, size=size)
for (name, taps, size) in inner_out_name_to_taps_storage
] ]
while_logic = "" output_storage_post_processing_stmts: List[str] = []
for outer_in_name, grp_vals in groupby(
inner_out_name_to_taps_storage, lambda x: x[0]
):
_, tap_sizes, storage_sizes = zip(*grp_vals)
tap_size = max(tap_sizes)
storage_size = storage_sizes[0]
if op.info.as_while: if op.info.as_while:
# The inner function will return a boolean as the last value # While loops need to truncate the output storage to a length given
inner_out_indexed.append("while_flag") # by the number of iterations performed.
while_logic += """ output_storage_post_processing_stmts.append(
if while_flag: dedent(
""" f"""
for _name, idx in inner_out_name_to_index.items(): if i + {tap_size} < {storage_size}:
while_logic += f""" {storage_size} = i + {tap_size}
{_name} = {_name}[:i+{idx+1}] {outer_in_name} = {outer_in_name}[:{storage_size}]
""" """
while_logic += """ ).strip()
break )
# Rotate the storage so that the last computed value is at the end of
# the storage array.
# This is needed when the output storage array does not have a length
# equal to the number of taps plus `n_steps`.
output_storage_post_processing_stmts.append(
dedent(
f"""
{outer_in_name}_shift = (i + {tap_size}) % ({storage_size})
if {outer_in_name}_shift > 0:
{outer_in_name}_left = {outer_in_name}[:{outer_in_name}_shift]
{outer_in_name}_right = {outer_in_name}[{outer_in_name}_shift:]
{outer_in_name} = np.concatenate(({outer_in_name}_right, {outer_in_name}_left))
""" """
).strip()
)
global_env = locals() if op.info.as_while:
global_env["np"] = np # The inner function will return a boolean as the last value
inner_out_storage_indexed.append("cond")
output_names = [outer_in_to_storage_name[n] for n in outer_in_outtap_names]
output_names = outer_in_mit_sot_names + outer_in_sit_sot_names # Construct the inner-input expressions
output_names += [f"{n}_nitsot_storage" for n in outer_in_nit_sot_names] inner_inputs: List[str] = []
for outer_in_name, tap_offset, size in inner_in_to_index_offset:
storage_name = outer_in_to_storage_name.get(outer_in_name, outer_in_name)
indexed_inner_in_str = (
idx_to_str(storage_name, tap_offset, size=size)
if tap_offset is not None
else storage_name
)
# if outer_in_names_to_vars[outer_in_name].type.ndim - 1 <= 0:
# # Convert scalar inner-inputs to Numba scalars
# indexed_inner_in_str = f"to_numba_scalar({indexed_inner_in_str})"
inner_inputs.append(indexed_inner_in_str)
inner_inputs = create_arg_string(inner_inputs)
inner_outputs = create_tuple_string(inner_output_names)
input_storage_block = "\n".join(allocate_taps_storage)
output_storage_post_processing_block = "\n".join(
output_storage_post_processing_stmts
)
inner_out_post_processing_block = "\n".join(inner_out_post_processing_stmts)
scan_op_src = f""" scan_op_src = f"""
def scan(n_steps, {", ".join(input_names)}): def scan({", ".join(outer_in_names)}):
{indent(allocate_mem_to_nit_sot, " " * 4)}
{indent(input_storage_block, " " * 4)}
i = 0
cond = False
while i < n_steps and not cond:
{inner_outputs} = scan_inner_func({inner_inputs})
{indent(inner_out_post_processing_block, " " * 8)}
{create_tuple_string(inner_out_storage_indexed)} = {inner_outputs}
i += 1
{indent(output_storage_post_processing_block, " " * 4)}
for i in range(n_steps):
inner_args = {create_tuple_string(inner_in_indexed)}
{create_tuple_string(inner_out_indexed)} = numba_at_inner_func(*inner_args)
{while_logic}
return {create_arg_string(output_names)} return {create_arg_string(output_names)}
""" """
global_env = {
"scan_inner_func": scan_inner_func,
"to_numba_scalar": numba_basic.to_scalar,
}
global_env["np"] = np
scalar_op_fn = compile_function_src( scalar_op_fn = compile_function_src(
scan_op_src, "scan", {**globals(), **global_env} scan_op_src, "scan", {**globals(), **global_env}
) )
......
import numpy as np import numpy as np
import pytest
import aesara.tensor as at import aesara.tensor as at
from aesara import config from aesara import config, grad
from aesara.compile.mode import Mode, get_mode
from aesara.graph.fg import FunctionGraph from aesara.graph.fg import FunctionGraph
from aesara.scan.basic import scan from aesara.scan.basic import scan
from aesara.scan.utils import until from aesara.scan.utils import until
from tests import unittest_tools as utt
from tests.link.numba.test_basic import compare_numba_and_py from tests.link.numba.test_basic import compare_numba_and_py
# Module-level NumPy random generator created with a fixed seed.
rng = np.random.default_rng(42849)
def test_scan_multiple_output(): def test_scan_multiple_output():
"""Test a scan implementation of a SEIR model. """Test a scan implementation of a SEIR model.
SEIR model definition: SEIR model definition:
S[t+1] = S[t] - B[t] S[t+1] = S[t] - B[t]
E[t+1] = E[t] +B[t] - C[t] E[t+1] = E[t] + B[t] - C[t]
I[t+1] = I[t] + C[t] - D[t] I[t+1] = I[t] + C[t] - D[t]
B[t] ~ Binom(S[t], beta) B[t] ~ Binom(S[t], beta)
C[t] ~ Binom(E[t], gamma) C[t] ~ Binom(E[t], gamma)
D[t] ~ Binom(I[t], delta) D[t] ~ Binom(I[t], delta)
""" """
def binomln(n, k): def binomln(n, k):
...@@ -198,3 +200,99 @@ def test_scan_multiple_none_output(): ...@@ -198,3 +200,99 @@ def test_scan_multiple_none_output():
test_input_vals = (np.array([1.0, 2.0]),) test_input_vals = (np.array([1.0, 2.0]),)
compare_numba_and_py(out_fg, test_input_vals) compare_numba_and_py(out_fg, test_input_vals)
def test_scan_save_mem_basic():
    """Check that Numba copes with the storage layout left by `scan_save_mem`.

    A single-output scan repeatedly multiplies its previous result by the
    non-sequence vector.  The graph is evaluated for two different step
    counts and the Numba results are compared against the Python-mode
    results.
    """
    step_count = at.iscalar("k")
    vec = at.dvector("A")

    scan_result, _updates = scan(
        fn=lambda prior_result, A: prior_result * A,
        outputs_info=at.ones_like(vec),
        non_sequences=vec,
        n_steps=step_count,
    )

    # The Numba mode is taken as-is; only the Python mode explicitly
    # includes the `scan_save_mem` rewrite.
    mode_numba = get_mode("NUMBA")
    mode_py = Mode("py").including("scan_save_mem")

    graph = FunctionGraph([vec, step_count], [scan_result])

    # Run the same comparison for each number of steps, building fresh
    # input values every time.
    for steps_val in (2, 4):
        compare_numba_and_py(
            graph,
            (np.arange(10, dtype=np.int32), steps_val),
            numba_mode=mode_numba,
            py_mode=mode_py,
        )
@pytest.mark.parametrize("n_steps_val", [1, 5])
def test_scan_save_mem_2(n_steps_val):
    """Compare Numba and Python results for a scan with taps ``[-2, -1]``.

    The recurrence is ``x[t] = 2 * x[t-1] + x[t-2]``, seeded with a
    two-element initial state and run for ``n_steps_val`` steps.
    """

    def f_pow2(x_tm2, x_tm1):
        return 2 * x_tm1 + x_tm2

    init_x = at.dvector("init_x")
    n_steps = at.iscalar("n_steps")

    # One recurrent output that reads its two most recent values.
    recurrent_spec = [{"initial": init_x, "taps": [-2, -1]}]
    scan_out, _updates = scan(
        f_pow2,
        sequences=[],
        outputs_info=recurrent_spec,
        non_sequences=[],
        n_steps=n_steps,
    )

    initial_state = np.array([1.0, 2.0])

    # The Numba mode is taken as-is; only the Python mode explicitly
    # includes the `scan_save_mem` rewrite.
    mode_numba = get_mode("NUMBA")
    mode_py = Mode("py").including("scan_save_mem")

    graph = FunctionGraph([init_x, n_steps], [scan_out])
    compare_numba_and_py(
        graph,
        (initial_state, n_steps_val),
        numba_mode=mode_numba,
        py_mode=mode_py,
    )
def test_grad_sitsot():
    """Verify the gradient of a five-step doubling scan compiled in NUMBA mode."""

    def get_sum_of_grad(inp):
        # Scan that doubles its single recurrent state at every step.
        doubled, _updates = scan(
            fn=lambda x: x * 2, outputs_info=[inp], n_steps=5, mode="NUMBA"
        )
        total = doubled.sum()
        return grad(total, inp).sum()

    float_dtype = config.floatX
    seeded_rng = np.random.default_rng(utt.fetch_seed())
    test_point = seeded_rng.random(3).astype(float_dtype)

    utt.verify_grad(get_sum_of_grad, [test_point], mode="NUMBA")
def test_mitmots_basic():
    """Compare Numba and Python results for gradients of a two-tap scan.

    The scan consumes one sequence and keeps one recurrent state with taps
    ``[-2, -1]``; the gradients of the summed output with respect to the
    sequence and the initial state are the graph outputs being compared.
    """
    init_x = at.dvector()
    seq = at.dvector()

    def inner_fct(seq, state_old, state_current):
        return state_old * 2 + state_current + seq

    state_spec = {"initial": init_x, "taps": [-2, -1]}
    scan_out, _updates = scan(inner_fct, sequences=seq, outputs_info=state_spec)
    gradients = grad(scan_out.sum(), [seq, init_x])

    # Both modes explicitly include the `scan_save_mem` rewrite here.
    mode_numba = get_mode("NUMBA").including("scan_save_mem")
    mode_py = Mode("py").including("scan_save_mem")

    graph = FunctionGraph([seq, init_x], gradients)
    input_vals = (np.arange(3), np.r_[-2, -1])
    compare_numba_and_py(
        graph, input_vals, numba_mode=mode_numba, py_mode=mode_py
    )
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论