提交 cb417fe5 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Thomas Wiecki

Numba scan: reuse scalar arrays for taps from vector inputs

Indexing vector inputs to create taps during scan yields numeric variables, which must be wrapped again into scalar arrays before being passed into the internal function. This commit pre-allocates such arrays and reuses them during looping.
上级 3f5b76b1
...@@ -112,25 +112,46 @@ def numba_funcify_Scan(op, node, **kwargs): ...@@ -112,25 +112,46 @@ def numba_funcify_Scan(op, node, **kwargs):
# Inner-inputs are ordered as follows: # Inner-inputs are ordered as follows:
# sequences + mit-mot-inputs + mit-sot-inputs + sit-sot-inputs + # sequences + mit-mot-inputs + mit-sot-inputs + sit-sot-inputs +
# shared-inputs + non-sequences. # shared-inputs + non-sequences.
temp_scalar_storage_alloc_stmts: List[str] = []
inner_in_exprs_scalar: List[str] = []
inner_in_exprs: List[str] = [] inner_in_exprs: List[str] = []
def add_inner_in_expr( def add_inner_in_expr(
outer_in_name: str, tap_offset: Optional[int], storage_size_var: Optional[str] outer_in_name: str,
tap_offset: Optional[int],
storage_size_var: Optional[str],
vector_slice_opt: bool,
): ):
"""Construct an inner-input expression.""" """Construct an inner-input expression."""
storage_name = outer_in_to_storage_name.get(outer_in_name, outer_in_name) storage_name = outer_in_to_storage_name.get(outer_in_name, outer_in_name)
indexed_inner_in_str = ( if vector_slice_opt:
storage_name indexed_inner_in_str_scalar = idx_to_str(
if tap_offset is None storage_name, tap_offset, size=storage_size_var, allow_scalar=True
else idx_to_str( )
storage_name, tap_offset, size=storage_size_var, allow_scalar=False temp_storage = f"{storage_name}_temp_scalar_{tap_offset}"
storage_dtype = outer_in_var.type.numpy_dtype.name
temp_scalar_storage_alloc_stmts.append(
f"{temp_storage} = np.empty((), dtype=np.{storage_dtype})"
)
inner_in_exprs_scalar.append(
f"{temp_storage}[()] = {indexed_inner_in_str_scalar}"
)
indexed_inner_in_str = temp_storage
else:
indexed_inner_in_str = (
storage_name
if tap_offset is None
else idx_to_str(
storage_name, tap_offset, size=storage_size_var, allow_scalar=False
)
) )
)
inner_in_exprs.append(indexed_inner_in_str) inner_in_exprs.append(indexed_inner_in_str)
for outer_in_name in outer_in_seqs_names: for outer_in_name in outer_in_seqs_names:
# These outer-inputs are indexed without offsets or storage wrap-around # These outer-inputs are indexed without offsets or storage wrap-around
add_inner_in_expr(outer_in_name, 0, None) outer_in_var = outer_in_names_to_vars[outer_in_name]
is_vector = outer_in_var.ndim == 1
add_inner_in_expr(outer_in_name, 0, None, vector_slice_opt=is_vector)
inner_in_names_to_input_taps: Dict[str, Tuple[int, ...]] = dict( inner_in_names_to_input_taps: Dict[str, Tuple[int, ...]] = dict(
zip( zip(
...@@ -232,7 +253,13 @@ def numba_funcify_Scan(op, node, **kwargs): ...@@ -232,7 +253,13 @@ def numba_funcify_Scan(op, node, **kwargs):
for in_tap in input_taps: for in_tap in input_taps:
tap_offset = in_tap + tap_storage_size tap_offset = in_tap + tap_storage_size
assert tap_offset >= 0 assert tap_offset >= 0
add_inner_in_expr(outer_in_name, tap_offset, storage_size_name) is_vector = outer_in_var.ndim == 1
add_inner_in_expr(
outer_in_name,
tap_offset,
storage_size_name,
vector_slice_opt=is_vector,
)
output_taps = inner_in_names_to_output_taps.get( output_taps = inner_in_names_to_output_taps.get(
outer_in_name, [tap_storage_size] outer_in_name, [tap_storage_size]
...@@ -253,7 +280,7 @@ def numba_funcify_Scan(op, node, **kwargs): ...@@ -253,7 +280,7 @@ def numba_funcify_Scan(op, node, **kwargs):
else: else:
storage_size_stmt = "" storage_size_stmt = ""
add_inner_in_expr(outer_in_name, None, None) add_inner_in_expr(outer_in_name, None, None, vector_slice_opt=False)
inner_out_to_outer_in_stmts.append(storage_name) inner_out_to_outer_in_stmts.append(storage_name)
output_idx = outer_output_names.index(storage_name) output_idx = outer_output_names.index(storage_name)
...@@ -325,7 +352,7 @@ def numba_funcify_Scan(op, node, **kwargs): ...@@ -325,7 +352,7 @@ def numba_funcify_Scan(op, node, **kwargs):
) )
for name in outer_in_non_seqs_names: for name in outer_in_non_seqs_names:
add_inner_in_expr(name, None, None) add_inner_in_expr(name, None, None, vector_slice_opt=False)
if op.info.as_while: if op.info.as_while:
# The inner function will return a boolean as the last value # The inner function will return a boolean as the last value
...@@ -333,9 +360,11 @@ def numba_funcify_Scan(op, node, **kwargs): ...@@ -333,9 +360,11 @@ def numba_funcify_Scan(op, node, **kwargs):
assert len(inner_in_exprs) == len(op.fgraph.inputs) assert len(inner_in_exprs) == len(op.fgraph.inputs)
inner_scalar_in_args_to_temp_storage = "\n".join(inner_in_exprs_scalar)
inner_in_args = create_arg_string(inner_in_exprs) inner_in_args = create_arg_string(inner_in_exprs)
inner_outputs = create_tuple_string(inner_output_names) inner_outputs = create_tuple_string(inner_output_names)
input_storage_block = "\n".join(storage_alloc_stmts) input_storage_block = "\n".join(storage_alloc_stmts)
input_temp_scalar_storage_block = "\n".join(temp_scalar_storage_alloc_stmts)
output_storage_post_processing_block = "\n".join(output_storage_post_proc_stmts) output_storage_post_processing_block = "\n".join(output_storage_post_proc_stmts)
inner_out_post_processing_block = "\n".join(inner_out_post_processing_stmts) inner_out_post_processing_block = "\n".join(inner_out_post_processing_stmts)
...@@ -348,9 +377,13 @@ def scan({", ".join(outer_in_names)}): ...@@ -348,9 +377,13 @@ def scan({", ".join(outer_in_names)}):
{indent(input_storage_block, " " * 4)} {indent(input_storage_block, " " * 4)}
{indent(input_temp_scalar_storage_block, " " * 4)}
i = 0 i = 0
cond = np.array(False) cond = np.array(False)
while i < n_steps and not cond.item(): while i < n_steps and not cond.item():
{indent(inner_scalar_in_args_to_temp_storage, " " * 8)}
{inner_outputs} = scan_inner_func({inner_in_args}) {inner_outputs} = scan_inner_func({inner_in_args})
{indent(inner_out_post_processing_block, " " * 8)} {indent(inner_out_post_processing_block, " " * 8)}
{indent(inner_out_to_outer_out_stmts, " " * 8)} {indent(inner_out_to_outer_out_stmts, " " * 8)}
......
import numpy as np import numpy as np
import pytest import pytest
import pytensor
import pytensor.tensor as at import pytensor.tensor as at
from pytensor import config, function, grad from pytensor import config, function, grad
from pytensor.compile.mode import Mode, get_mode from pytensor.compile.mode import Mode, get_mode
...@@ -9,7 +10,7 @@ from pytensor.scalar import Log1p ...@@ -9,7 +10,7 @@ from pytensor.scalar import Log1p
from pytensor.scan.basic import scan from pytensor.scan.basic import scan
from pytensor.scan.op import Scan from pytensor.scan.op import Scan
from pytensor.scan.utils import until from pytensor.scan.utils import until
from pytensor.tensor import log, vector from pytensor.tensor import log, scalar, vector
from pytensor.tensor.elemwise import Elemwise from pytensor.tensor.elemwise import Elemwise
from pytensor.tensor.random.utils import RandomStream from pytensor.tensor.random.utils import RandomStream
from tests import unittest_tools as utt from tests import unittest_tools as utt
...@@ -442,3 +443,54 @@ def test_inner_graph_optimized(): ...@@ -442,3 +443,54 @@ def test_inner_graph_optimized():
assert isinstance(inner_scan_node.op, Elemwise) and isinstance( assert isinstance(inner_scan_node.op, Elemwise) and isinstance(
inner_scan_node.op.scalar_op, Log1p inner_scan_node.op.scalar_op, Log1p
) )
def test_vector_taps_benchmark(benchmark):
    """Benchmark a Scan whose taps are sliced out of vector storage.

    Indexing a vector tap produces a bare numeric value that has to be
    re-wrapped in a scalar array before the inner function is called.
    The Numba Scan backend pre-allocates and reuses those scalar arrays
    across iterations; this test both checks correctness against the
    reference mode and benchmarks the compiled function.
    """
    n_steps = 1000

    # Symbolic inputs: two sequences, a 2-tap mit-sot initial state, and a
    # single-tap sit-sot initial state.
    seq1 = vector("seq1", dtype="float64", shape=(n_steps,))
    seq2 = vector("seq2", dtype="float64", shape=(n_steps,))
    mitsot_init = vector("mitsot_init", dtype="float64", shape=(2,))
    sitsot_init = scalar("sitsot_init", dtype="float64")

    def step(seq1, seq2, mitsot1, mitsot2, sitsot1):
        # Combine both sequence values with both mit-sot taps, then
        # accumulate into the sit-sot state.
        mitsot3 = mitsot1 + seq2 + mitsot2 + seq1
        sitsot2 = sitsot1 + mitsot3
        return mitsot3, sitsot2

    outs, _ = scan(
        fn=step,
        sequences=[seq1, seq2],
        outputs_info=[
            {"initial": mitsot_init, "taps": [-2, -1]},
            {"initial": sitsot_init, "taps": [-1]},
        ],
    )

    # Deterministic numeric inputs keyed by their symbolic variables.
    rng = np.random.default_rng(474)
    inputs = {
        seq1: rng.normal(size=n_steps),
        seq2: rng.normal(size=n_steps),
        mitsot_init: rng.normal(size=(2,)),
        sitsot_init: rng.normal(),
    }

    numba_fn = pytensor.function(list(inputs.keys()), outs, mode=get_mode("NUMBA"))

    # The whole computation must have been fused into exactly one Scan node.
    scan_nodes = [
        n for n in numba_fn.maker.fgraph.apply_nodes if isinstance(n.op, Scan)
    ]
    assert len(scan_nodes) == 1

    numba_res = numba_fn(*inputs.values())

    # Cross-check the Numba results against the un-optimized reference mode.
    ref_fn = pytensor.function(list(inputs.keys()), outs, mode=get_mode("FAST_COMPILE"))
    ref_res = ref_fn(*inputs.values())
    for got, want in zip(numba_res, ref_res):
        np.testing.assert_array_almost_equal(got, want)

    benchmark(numba_fn, *inputs.values())
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论