提交 a37de8a7 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Benchmark Scan buffer optimization in Numba

上级 d7edde21
......@@ -339,39 +339,6 @@ def test_scan_multiple_none_output():
compare_numba_and_py([A], result, test_input_vals)
@pytest.mark.parametrize("n_steps_val", [1, 5])
def test_scan_save_mem_basic(n_steps_val):
"""Make sure we can handle storage changes caused by the `scan_save_mem` rewrite."""
def f_pow2(x_tm2, x_tm1):
return 2 * x_tm1 + x_tm2
init_x = pt.dvector("init_x")
n_steps = pt.iscalar("n_steps")
output, _ = scan(
f_pow2,
sequences=[],
outputs_info=[{"initial": init_x, "taps": [-2, -1]}],
non_sequences=[],
n_steps=n_steps,
)
state_val = np.array([1.0, 2.0])
numba_mode = get_mode("NUMBA").including("scan_save_mem")
py_mode = Mode("py").including("scan_save_mem")
test_input_vals = (state_val, n_steps_val)
compare_numba_and_py(
[init_x, n_steps],
[output],
test_input_vals,
numba_mode=numba_mode,
py_mode=py_mode,
)
def test_grad_sitsot():
def get_sum_of_grad(inp):
scan_outputs, updates = scan(
......@@ -482,3 +449,120 @@ def test_vector_taps_benchmark(benchmark):
np.testing.assert_array_almost_equal(numba_r, ref_r)
benchmark(numba_fn, *test.values())
@pytest.mark.parametrize(
"buffer_size", ("unit", "aligned", "misaligned", "whole", "whole+init")
)
@pytest.mark.parametrize("n_steps, op_size", [(10, 2), (512, 2), (512, 256)])
class TestScanSITSOTBuffer:
def buffer_tester(self, n_steps, op_size, buffer_size, benchmark=None):
x0 = pt.vector(shape=(op_size,), dtype="float64")
xs, _ = pytensor.scan(
fn=lambda xtm1: (xtm1 + 1),
outputs_info=[x0],
n_steps=n_steps - 1, # 1- makes it easier to align/misalign
)
if buffer_size == "unit":
xs_kept = xs[-1] # Only last state is used
expected_buffer_size = 2
elif buffer_size == "aligned":
xs_kept = xs[-2:] # The buffer will be aligned at the end of the 9 steps
expected_buffer_size = 2
elif buffer_size == "misaligned":
xs_kept = xs[-3:] # The buffer will be misaligned at the end of the 9 steps
expected_buffer_size = 3
elif buffer_size == "whole":
xs_kept = xs # What users think is the whole buffer
expected_buffer_size = n_steps - 1
elif buffer_size == "whole+init":
xs_kept = xs.owner.inputs[0] # Whole buffer actually used by Scan
expected_buffer_size = n_steps
x_test = np.zeros(x0.type.shape)
numba_fn, _ = compare_numba_and_py(
[x0],
[xs_kept],
test_inputs=[x_test],
numba_mode="NUMBA", # Default doesn't include optimizations
eval_obj_mode=False,
)
[scan_node] = [
node
for node in numba_fn.maker.fgraph.toposort()
if isinstance(node.op, Scan)
]
buffer = scan_node.inputs[1]
assert buffer.type.shape[0] == expected_buffer_size
if benchmark is not None:
numba_fn.trust_input = True
benchmark(numba_fn, x_test)
def test_sit_sot_buffer(self, n_steps, op_size, buffer_size):
self.buffer_tester(n_steps, op_size, buffer_size, benchmark=None)
def test_sit_sot_buffer_benchmark(self, n_steps, op_size, buffer_size, benchmark):
self.buffer_tester(n_steps, op_size, buffer_size, benchmark=benchmark)
@pytest.mark.parametrize("constant_n_steps", [False, True])
@pytest.mark.parametrize("n_steps_val", [1, 1000])
class TestScanMITSOTBuffer:
def buffer_tester(self, constant_n_steps, n_steps_val, benchmark=None):
"""Make sure we can handle storage changes caused by the `scan_save_mem` rewrite."""
def f_pow2(x_tm2, x_tm1):
return 2 * x_tm1 + x_tm2
init_x = pt.vector("init_x", shape=(2,))
n_steps = pt.iscalar("n_steps")
output, _ = scan(
f_pow2,
sequences=[],
outputs_info=[{"initial": init_x, "taps": [-2, -1]}],
non_sequences=[],
n_steps=n_steps_val if constant_n_steps else n_steps,
)
init_x_val = np.array([1.0, 2.0], dtype=init_x.type.dtype)
test_vals = (
[init_x_val]
if constant_n_steps
else [init_x_val, np.asarray(n_steps_val, dtype=n_steps.type.dtype)]
)
numba_fn, _ = compare_numba_and_py(
[init_x] if constant_n_steps else [init_x, n_steps],
[output[-1]],
test_vals,
numba_mode="NUMBA",
eval_obj_mode=False,
)
if n_steps_val == 1 and constant_n_steps:
# There's no Scan in the graph when nsteps=constant(1)
return
# Check the buffer size as been optimized
[scan_node] = [
node
for node in numba_fn.maker.fgraph.toposort()
if isinstance(node.op, Scan)
]
[mitsot_buffer] = scan_node.op.outer_mitsot(scan_node.inputs)
mitsot_buffer_shape = mitsot_buffer.shape.eval(
{init_x: init_x_val, n_steps: n_steps_val},
accept_inplace=True,
on_unused_input="ignore",
)
assert tuple(mitsot_buffer_shape) == (3,)
if benchmark is not None:
numba_fn.trust_input = True
benchmark(numba_fn, *test_vals)
def test_mit_sot_buffer(self, constant_n_steps, n_steps_val):
self.buffer_tester(constant_n_steps, n_steps_val, benchmark=None)
def test_mit_sot_buffer_benchmark(self, constant_n_steps, n_steps_val, benchmark):
self.buffer_tester(constant_n_steps, n_steps_val, benchmark=benchmark)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论