Commit 3d96ee80, authored and committed by Brandon T. Willard

Fix storage handling in numba_funcify_Scan

Parent: a2d05adc
from itertools import groupby
from textwrap import dedent, indent
from typing import Dict, List, Optional, Tuple

import numpy as np

from numba import types
from numba.extending import overload

from aesara.link.numba.dispatch import basic as numba_basic
from aesara.link.numba.dispatch.basic import (
    create_arg_string,

@@ -15,13 +16,23 @@ from aesara.link.utils import compile_function_src

from aesara.scan.op import Scan

def idx_to_str(
    array_name: str, offset: int, size: Optional[str] = None, idx_symbol: str = "i"
) -> str:
    if offset < 0:
        indices = f"{idx_symbol} + {array_name}.shape[0] - {offset}"
    elif offset > 0:
        indices = f"{idx_symbol} + {offset}"
    else:
        indices = idx_symbol

    if size:
        # TODO FIXME: The `Scan` `Op` should tell us which outputs are computed
        # in this way. We shouldn't have to waste run-time efforts in order to
        # compensate for this poor `Op`/rewrite design and implementation.
        indices = f"({indices}) % {size}"

    return f"{array_name}[{indices}]"
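
# Illustrative examples (not part of this commit) of the index expressions
# the new helper generates:
#
#     idx_to_str("x", 0)                # -> "x[i]"
#     idx_to_str("x", 2)                # -> "x[i + 2]"
#     idx_to_str("x", 1, size="x_len")  # -> "x[(i + 1) % x_len]"
#
# The modulo form is what lets reads and writes wrap around the truncated
# storage arrays produced by `save_mem_new_scan`.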

@overload(range)

@@ -36,124 +47,267 @@ def array0d_range(x):

@numba_funcify.register(Scan)
def numba_funcify_Scan(op, node, **kwargs):
    scan_inner_func = numba_basic.numba_njit(numba_funcify(op.fgraph))

    n_seqs = op.info.n_seqs

    outer_in_names_to_vars = {
        (f"outer_in_{i}" if i > 0 else "n_steps"): v for i, v in enumerate(node.inputs)
    }
    outer_in_names = list(outer_in_names_to_vars.keys())
    outer_in_seqs_names = op.outer_seqs(outer_in_names)
    outer_in_mit_mot_names = op.outer_mitmot(outer_in_names)
    outer_in_mit_sot_names = op.outer_mitsot(outer_in_names)
    outer_in_sit_sot_names = op.outer_sitsot(outer_in_names)
    outer_in_nit_sot_names = op.outer_nitsot(outer_in_names)
    outer_in_outtap_names = (
        outer_in_mit_mot_names
        + outer_in_mit_sot_names
        + outer_in_sit_sot_names
        + outer_in_nit_sot_names
    )
    outer_in_non_seqs_names = op.outer_non_seqs(outer_in_names)

    inner_in_to_index_offset: List[Tuple[str, Optional[int], Optional[int]]] = []
    allocate_taps_storage: List[str] = []
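
    # Illustrative example (not part of this commit): for a `Scan` node whose
    # outer inputs are `[n_steps, seq, sitsot_init, non_seq]`, the mapping
    # above yields the names `["n_steps", "outer_in_1", "outer_in_2",
    # "outer_in_3"]`, with `outer_in_1` classified as a sequence,
    # `outer_in_2` as a sit-sot, and `outer_in_3` as a non-sequence by the
    # `op.outer_*` helpers.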
    for outer_in_name in outer_in_seqs_names:
        # A sequence with multiple taps is provided as multiple modified input
        # sequences--all sliced so as to keep following the logic of a normal
        # sequence.
        inner_in_to_index_offset.append((outer_in_name, 0, None))

    inner_in_names_to_input_taps: Dict[str, Tuple[int, ...]] = dict(
        zip(
            outer_in_mit_mot_names + outer_in_mit_sot_names + outer_in_sit_sot_names,
            op.info.mit_mot_in_slices
            + op.info.mit_sot_in_slices
            + op.info.sit_sot_in_slices,
        )
    )
    inner_in_names_to_output_taps: Dict[str, Optional[Tuple[int, ...]]] = dict(
        zip(outer_in_mit_mot_names, op.info.mit_mot_out_slices)
    )

    inner_output_names = [f"inner_out_{i}" for i in range(len(op.inner_outputs))]

    # Maps storage array names to their tap values (i.e. maximum absolute tap
    # value) and storage sizes
    inner_out_name_to_taps_storage: List[Tuple[str, int, Optional[str]]] = []

    outer_in_to_storage_name: Dict[str, str] = {}

    outer_in_sot_names = set(
        outer_in_mit_mot_names + outer_in_mit_sot_names + outer_in_sit_sot_names
    )
    inner_out_post_processing_stmts: List[str] = []

    for outer_in_name in outer_in_outtap_names:
        outer_in_var = outer_in_names_to_vars[outer_in_name]

        if outer_in_name in outer_in_sot_names:
            if outer_in_name in outer_in_mit_mot_names:
                storage_name = f"{outer_in_name}_mitmot_storage"
            elif outer_in_name in outer_in_mit_sot_names:
                storage_name = f"{outer_in_name}_mitsot_storage"
            else:
                # Note that the outputs with single, non-`-1` taps (e.g. `taps
                # = [-2]`) are classified as mit-sot, so the code for handling
                # sit-sots remains constant as follows
                storage_name = f"{outer_in_name}_sitsot_storage"

            output_idx = len(outer_in_to_storage_name)
            outer_in_to_storage_name[outer_in_name] = storage_name

            input_taps = inner_in_names_to_input_taps[outer_in_name]
            tap_storage_size = -min(input_taps)
            assert tap_storage_size >= 0

            storage_size_name = f"{outer_in_name}_len"

            for in_tap in input_taps:
                tap_offset = in_tap + tap_storage_size
                assert tap_offset >= 0
                # In truncated storage situations (i.e. created by
                # `save_mem_new_scan`), the taps and output storage overlap,
                # instead of the standard situation in which the output
                # storage is large enough to contain both the initial taps
                # values and the output storage.
                inner_in_to_index_offset.append(
                    (outer_in_name, tap_offset, storage_size_name)
                )

            output_taps = inner_in_names_to_output_taps.get(
                outer_in_name, [tap_storage_size]
            )
            for out_tap in output_taps:
                inner_out_name_to_taps_storage.append(
                    (storage_name, out_tap, storage_size_name)
                )

            if output_idx in node.op.destroy_map:
                storage_alloc_stmt = f"{storage_name} = {outer_in_name}"
            else:
                storage_alloc_stmt = f"{storage_name} = np.copy({outer_in_name})"

            storage_alloc_stmt = dedent(
                f"""
                # {outer_in_var.type}
                {storage_size_name} = {outer_in_name}.shape[0]
                {storage_alloc_stmt}
                """
            ).strip()

            allocate_taps_storage.append(storage_alloc_stmt)
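
            # Illustrative example (not part of this commit): for a sit-sot
            # outer input named `outer_in_2` that is not destroyed in-place,
            # the statements generated above render roughly as
            #
            #     # TensorType(float64, (None,))
            #     outer_in_2_len = outer_in_2.shape[0]
            #     outer_in_2_sitsot_storage = np.copy(outer_in_2)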
        elif outer_in_name in outer_in_nit_sot_names:
            # This is a special case in which there are no outer-inputs used
            # for outer-output storage, so we need to create our own storage
            # from scratch.
            storage_name = f"{outer_in_name}_nitsot_storage"
            outer_in_to_storage_name[outer_in_name] = storage_name

            storage_size_name = f"{outer_in_name}_len"
            inner_out_name_to_taps_storage.append((storage_name, 0, storage_size_name))

            # In case of nit-sots we are provided the length of the array in
            # the iteration dimension instead of actual arrays, hence we
            # allocate space for the results accordingly.
            curr_nit_sot_position = outer_in_names[1:].index(outer_in_name) - n_seqs
            curr_nit_sot = op.inner_outputs[curr_nit_sot_position]

            needs_alloc = curr_nit_sot.ndim > 0

            storage_shape = create_tuple_string(
                [storage_size_name] + ["0"] * curr_nit_sot.ndim
            )
            storage_dtype = curr_nit_sot.type.numpy_dtype.name

            allocate_taps_storage.append(
                dedent(
                    f"""
                    # {curr_nit_sot.type}
                    {storage_size_name} = to_numba_scalar({outer_in_name})
                    {storage_name} = np.empty({storage_shape}, dtype=np.{storage_dtype})
                    """
                ).strip()
            )

            if needs_alloc:
                allocate_taps_storage.append(f"{outer_in_name}_ready = False")

                # In this case, we don't know the shape of the output storage
                # array until we get some output from the inner-function.
                # With the following we add delayed output storage
                # initialization:
                inner_out_name = inner_output_names[curr_nit_sot_position]
                inner_out_post_processing_stmts.append(
                    dedent(
                        f"""
                        if not {outer_in_name}_ready:
                            {storage_name} = np.empty(({storage_size_name},) + {inner_out_name}.shape, dtype=np.{storage_dtype})
                            {outer_in_name}_ready = True
                        """
                    ).strip()
                )
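
            # Illustrative example (not part of this commit): for a vector
            # nit-sot appearing as the third outer input and first inner
            # output, the generated setup renders roughly as
            #
            #     # TensorType(float64, (None,))
            #     outer_in_3_len = to_numba_scalar(outer_in_3)
            #     outer_in_3_nitsot_storage = np.empty((outer_in_3_len, 0), dtype=np.float64)
            #     outer_in_3_ready = False
            #
            # and the delayed (re)allocation that runs inside the loop as
            #
            #     if not outer_in_3_ready:
            #         outer_in_3_nitsot_storage = np.empty((outer_in_3_len,) + inner_out_0.shape, dtype=np.float64)
            #         outer_in_3_ready = True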

    # The non_seqs are passed to the inner function as-is
    for name in outer_in_non_seqs_names:
        inner_in_to_index_offset.append((name, None, None))

    inner_out_storage_indexed = [
        name if taps is None else idx_to_str(name, taps, size=size)
        for (name, taps, size) in inner_out_name_to_taps_storage
    ]
    output_storage_post_processing_stmts: List[str] = []

    for outer_in_name, grp_vals in groupby(
        inner_out_name_to_taps_storage, lambda x: x[0]
    ):
        _, tap_sizes, storage_sizes = zip(*grp_vals)

        tap_size = max(tap_sizes)
        storage_size = storage_sizes[0]

        if op.info.as_while:
            # While-loops need to truncate the output storage to a length
            # given by the number of iterations performed.
            output_storage_post_processing_stmts.append(
                dedent(
                    f"""
                    if i + {tap_size} < {storage_size}:
                        {storage_size} = i + {tap_size}
                        {outer_in_name} = {outer_in_name}[:{storage_size}]
                    """
                ).strip()
            )

        # Rotate the storage so that the last computed value is at the end of
        # the storage array.
        # This is needed when the output storage array does not have a length
        # equal to the number of taps plus `n_steps`.
        output_storage_post_processing_stmts.append(
            dedent(
                f"""
                {outer_in_name}_shift = (i + {tap_size}) % ({storage_size})
                if {outer_in_name}_shift > 0:
                    {outer_in_name}_left = {outer_in_name}[:{outer_in_name}_shift]
                    {outer_in_name}_right = {outer_in_name}[{outer_in_name}_shift:]
                    {outer_in_name} = np.concatenate(({outer_in_name}_right, {outer_in_name}_left))
                """
            ).strip()
        )
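
    # Illustrative example (not part of this commit): the generated rotation
    # is a left-rotation at run time. With a length-3 storage, `tap_size = 1`,
    # and `i = 4` after the loop:
    #
    #     storage = np.array([3.0, 4.0, 2.0])  # newest value (4.0) at index (4 % 3) == 1
    #     shift = (4 + 1) % 3                  # == 2
    #     if shift > 0:
    #         storage = np.concatenate((storage[shift:], storage[:shift]))
    #     # -> array([2., 3., 4.]); oldest-to-newest order restored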

    if op.info.as_while:
        # The inner function will return a boolean as the last value
        inner_out_storage_indexed.append("cond")
    output_names = [outer_in_to_storage_name[n] for n in outer_in_outtap_names]

    # Construct the inner-input expressions
    inner_inputs: List[str] = []
    for outer_in_name, tap_offset, size in inner_in_to_index_offset:
        storage_name = outer_in_to_storage_name.get(outer_in_name, outer_in_name)
        indexed_inner_in_str = (
            idx_to_str(storage_name, tap_offset, size=size)
            if tap_offset is not None
            else storage_name
        )
        # if outer_in_names_to_vars[outer_in_name].type.ndim - 1 <= 0:
        #     # Convert scalar inner-inputs to Numba scalars
        #     indexed_inner_in_str = f"to_numba_scalar({indexed_inner_in_str})"
        inner_inputs.append(indexed_inner_in_str)

    inner_inputs = create_arg_string(inner_inputs)
    inner_outputs = create_tuple_string(inner_output_names)

    input_storage_block = "\n".join(allocate_taps_storage)
    output_storage_post_processing_block = "\n".join(
        output_storage_post_processing_stmts
    )
    inner_out_post_processing_block = "\n".join(inner_out_post_processing_stmts)
    scan_op_src = f"""
def scan({", ".join(outer_in_names)}):

{indent(input_storage_block, " " * 4)}

    i = 0
    cond = False
    while i < n_steps and not cond:
        {inner_outputs} = scan_inner_func({inner_inputs})
{indent(inner_out_post_processing_block, " " * 8)}
        {create_tuple_string(inner_out_storage_indexed)} = {inner_outputs}
        i += 1

{indent(output_storage_post_processing_block, " " * 4)}

    return {create_arg_string(output_names)}
    """
    global_env = {
        "scan_inner_func": scan_inner_func,
        "to_numba_scalar": numba_basic.to_scalar,
    }
    global_env["np"] = np

    scalar_op_fn = compile_function_src(
        scan_op_src, "scan", {**globals(), **global_env}
    )
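
To make the string assembly above concrete, the following is roughly what `scan_op_src` renders to for a scan with a single sit-sot output (tap `[-1]`) and one non-sequence. This is an illustrative sketch following the naming scheme above, not captured output; in particular, the exact renderings of helpers such as `create_tuple_string` (assumed to produce `"(x,)"` for a single name) are assumptions:

def scan(n_steps, outer_in_1, outer_in_2):

    # TensorType(float64, (None,))
    outer_in_1_len = outer_in_1.shape[0]
    outer_in_1_sitsot_storage = np.copy(outer_in_1)

    i = 0
    cond = False
    while i < n_steps and not cond:
        (inner_out_0,) = scan_inner_func(outer_in_1_sitsot_storage[(i) % outer_in_1_len], outer_in_2)
        (outer_in_1_sitsot_storage[(i + 1) % outer_in_1_len],) = (inner_out_0,)
        i += 1

    outer_in_1_sitsot_storage_shift = (i + 1) % (outer_in_1_len)
    if outer_in_1_sitsot_storage_shift > 0:
        outer_in_1_sitsot_storage_left = outer_in_1_sitsot_storage[:outer_in_1_sitsot_storage_shift]
        outer_in_1_sitsot_storage_right = outer_in_1_sitsot_storage[outer_in_1_sitsot_storage_shift:]
        outer_in_1_sitsot_storage = np.concatenate((outer_in_1_sitsot_storage_right, outer_in_1_sitsot_storage_left))

    return outer_in_1_sitsot_storage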
...
import numpy as np
import pytest

import aesara.tensor as at
from aesara import config, grad
from aesara.compile.mode import Mode, get_mode
from aesara.graph.fg import FunctionGraph
from aesara.scan.basic import scan
from aesara.scan.utils import until
from tests import unittest_tools as utt
from tests.link.numba.test_basic import compare_numba_and_py

rng = np.random.default_rng(42849)
def test_scan_multiple_output():
    """Test a scan implementation of a SEIR model.

    SEIR model definition:

        S[t+1] = S[t] - B[t]
        E[t+1] = E[t] + B[t] - C[t]
        I[t+1] = I[t] + C[t] - D[t]

        B[t] ~ Binom(S[t], beta)
        C[t] ~ Binom(E[t], gamma)
        D[t] ~ Binom(I[t], delta)
    """

    def binomln(n, k):
@@ -198,3 +200,99 @@ def test_scan_multiple_none_output():

    test_input_vals = (np.array([1.0, 2.0]),)
    compare_numba_and_py(out_fg, test_input_vals)

def test_scan_save_mem_basic():
    """Make sure we can handle storage changes caused by the `scan_save_mem` rewrite."""

    k = at.iscalar("k")
    A = at.dvector("A")

    result, _ = scan(
        fn=lambda prior_result, A: prior_result * A,
        outputs_info=at.ones_like(A),
        non_sequences=A,
        n_steps=k,
    )

    numba_mode = get_mode("NUMBA")  # .including("scan_save_mem")
    py_mode = Mode("py").including("scan_save_mem")

    out_fg = FunctionGraph([A, k], [result])

    test_input_vals = (np.arange(10, dtype=np.int32), 2)
    compare_numba_and_py(
        out_fg, test_input_vals, numba_mode=numba_mode, py_mode=py_mode
    )

    test_input_vals = (np.arange(10, dtype=np.int32), 4)
    compare_numba_and_py(
        out_fg, test_input_vals, numba_mode=numba_mode, py_mode=py_mode
    )

@pytest.mark.parametrize("n_steps_val", [1, 5])
def test_scan_save_mem_2(n_steps_val):
    def f_pow2(x_tm2, x_tm1):
        return 2 * x_tm1 + x_tm2

    init_x = at.dvector("init_x")
    n_steps = at.iscalar("n_steps")
    output, _ = scan(
        f_pow2,
        sequences=[],
        outputs_info=[{"initial": init_x, "taps": [-2, -1]}],
        non_sequences=[],
        n_steps=n_steps,
    )

    state_val = np.array([1.0, 2.0])

    numba_mode = get_mode("NUMBA")  # .including("scan_save_mem")
    py_mode = Mode("py").including("scan_save_mem")

    out_fg = FunctionGraph([init_x, n_steps], [output])

    test_input_vals = (state_val, n_steps_val)
    compare_numba_and_py(
        out_fg, test_input_vals, numba_mode=numba_mode, py_mode=py_mode
    )

def test_grad_sitsot():
    def get_sum_of_grad(inp):
        scan_outputs, updates = scan(
            fn=lambda x: x * 2, outputs_info=[inp], n_steps=5, mode="NUMBA"
        )
        return grad(scan_outputs.sum(), inp).sum()

    floatX = config.floatX
    inputs_test_values = [
        np.random.default_rng(utt.fetch_seed()).random(3).astype(floatX)
    ]
    utt.verify_grad(get_sum_of_grad, inputs_test_values, mode="NUMBA")

def test_mitmots_basic():
    init_x = at.dvector()
    seq = at.dvector()

    def inner_fct(seq, state_old, state_current):
        return state_old * 2 + state_current + seq

    out, _ = scan(
        inner_fct, sequences=seq, outputs_info={"initial": init_x, "taps": [-2, -1]}
    )

    g_outs = grad(out.sum(), [seq, init_x])

    numba_mode = get_mode("NUMBA").including("scan_save_mem")
    py_mode = Mode("py").including("scan_save_mem")

    out_fg = FunctionGraph([seq, init_x], g_outs)

    seq_val = np.arange(3)
    init_x_val = np.r_[-2, -1]
    test_input_vals = (seq_val, init_x_val)
    compare_numba_and_py(
        out_fg, test_input_vals, numba_mode=numba_mode, py_mode=py_mode
    )