Commit 92420c8f authored by Ricardo Vieira, committed by Ricardo Vieira

Do not try to save initial values buffer size in Scan

This will always require a roll at the end, for a minimal gain
Parent a37de8a7
......@@ -1420,9 +1420,18 @@ def scan_save_mem(fgraph, node):
store_steps[i] = 0
break
if isinstance(this_slice[0], slice) and this_slice[0].start is None:
store_steps[i] = 0
break
if isinstance(this_slice[0], slice):
start = this_slice[0].start
if isinstance(start, Constant):
start = start.data
# Don't do anything if the subtensor is starting from the beginning of the buffer
# Or just skipping the initial values (default output returned to the user).
# Trimming the initial values would require a roll to align the buffer once scan is done
# As it always starts writing at position [0+max(taps)], and ends up at position [:max(taps)]
# It's cheaper to just keep the initial values in the buffer and slice them away (default output)
if start in (0, None, init_l[i]):
store_steps[i] = 0
break
# Special case for recurrent outputs where only the last result
# is requested. This is needed for this rewrite to apply to
......
......@@ -474,7 +474,7 @@ class TestScanSITSOTBuffer:
expected_buffer_size = 3
elif buffer_size == "whole":
xs_kept = xs # What users think is the whole buffer
expected_buffer_size = n_steps - 1
expected_buffer_size = n_steps
elif buffer_size == "whole+init":
xs_kept = xs.owner.inputs[0] # Whole buffer actually used by Scan
expected_buffer_size = n_steps
......
......@@ -643,35 +643,37 @@ def test_debugprint_compiled_fn():
# (i.e. from `Scan._fn`)
out = pytensor.function([M], out, updates=updates, mode="FAST_RUN")
expected_output = """Scan{scan_fn, while_loop=False, inplace=all} [id A] 2 (outer_out_sit_sot-0)
├─ 20000 [id B] (n_steps)
├─ [ 0 ... 998 19999] [id C] (outer_in_seqs-0)
├─ SetSubtensor{:stop} [id D] 1 (outer_in_sit_sot-0)
│ ├─ AllocEmpty{dtype='int64'} [id E] 0
│ │ └─ 20000 [id B]
│ ├─ [0] [id F]
│ └─ 1 [id G]
└─ <Tensor3(float64, shape=(20000, 2, 2))> [id H] (outer_in_non_seqs-0)
Inner graphs:
Scan{scan_fn, while_loop=False, inplace=all} [id A]
← Composite{switch(lt(0, i0), 1, 0)} [id I] (inner_out_sit_sot-0)
└─ Subtensor{i, j, k} [id J]
├─ *2-<Tensor3(float64, shape=(20000, 2, 2))> [id K] -> [id H] (inner_in_non_seqs-0)
├─ ScalarFromTensor [id L]
│ └─ *0-<Scalar(int64, shape=())> [id M] -> [id C] (inner_in_seqs-0)
├─ ScalarFromTensor [id N]
│ └─ *1-<Scalar(int64, shape=())> [id O] -> [id D] (inner_in_sit_sot-0)
└─ 0 [id P]
Composite{switch(lt(0, i0), 1, 0)} [id I]
← Switch [id Q] 'o0'
├─ LT [id R]
│ ├─ 0 [id S]
│ └─ i0 [id T]
├─ 1 [id U]
└─ 0 [id S]
expected_output = """Subtensor{start:} [id A] 3
├─ Scan{scan_fn, while_loop=False, inplace=all} [id B] 2 (outer_out_sit_sot-0)
│ ├─ 20000 [id C] (n_steps)
│ ├─ [ 0 ... 998 19999] [id D] (outer_in_seqs-0)
│ ├─ SetSubtensor{:stop} [id E] 1 (outer_in_sit_sot-0)
│ │ ├─ AllocEmpty{dtype='int64'} [id F] 0
│ │ │ └─ 20001 [id G]
│ │ ├─ [0] [id H]
│ │ └─ 1 [id I]
│ └─ <Tensor3(float64, shape=(20000, 2, 2))> [id J] (outer_in_non_seqs-0)
└─ 1 [id I]
Inner graphs:
Scan{scan_fn, while_loop=False, inplace=all} [id B]
← Composite{switch(lt(0, i0), 1, 0)} [id K] (inner_out_sit_sot-0)
└─ Subtensor{i, j, k} [id L]
├─ *2-<Tensor3(float64, shape=(20000, 2, 2))> [id M] -> [id J] (inner_in_non_seqs-0)
├─ ScalarFromTensor [id N]
│ └─ *0-<Scalar(int64, shape=())> [id O] -> [id D] (inner_in_seqs-0)
├─ ScalarFromTensor [id P]
│ └─ *1-<Scalar(int64, shape=())> [id Q] -> [id E] (inner_in_sit_sot-0)
└─ 0 [id R]
Composite{switch(lt(0, i0), 1, 0)} [id K]
← Switch [id S] 'o0'
├─ LT [id T]
│ ├─ 0 [id U]
│ └─ i0 [id V]
├─ 1 [id W]
└─ 0 [id U]
"""
output_str = debugprint(out, file="str", print_op_info=True)
......
Markdown is supported
0%
You are adding 0 people to this discussion. Please proceed with caution.
Please finish editing this comment first!
Register or sign in to post a comment