提交 5c350ab2 authored 作者: ricardoV94's avatar ricardoV94 提交者: Ricardo Vieira

Fix ordering bug between untraced_sit_sot and nit_sot

Internally Scan places those outputs to the right of nit_sot. Helper function reordering to match user definition was not handling this correctly
上级 65cc349a
......@@ -661,7 +661,7 @@ def scan(
else:
actual_n_steps = pt.as_tensor(n_steps, dtype="int64", ndim=0)
# Since we've added all sequences now we need to level them up based on
# Since we've added all sequences now we need to level them off based on
# n_steps or their different shapes
scan_seqs = [seq[:actual_n_steps] for seq in scan_seqs]
......@@ -1212,8 +1212,8 @@ def scan(
rightOrder = (
mit_sot_rightOrder
+ sit_sot_rightOrder
+ untraced_sit_sot_rightOrder
+ nit_sot_rightOrder
+ untraced_sit_sot_rightOrder
)
scan_out_list = [None] * len(rightOrder)
for idx, pos in enumerate(rightOrder):
......
......@@ -4470,3 +4470,49 @@ class ScanCompatibilityTests:
f_infershape = function([seq, init], out_seq_tm4[1].shape)
scan_nodes_infershape = scan_nodes_from_fct(f_infershape)
assert len(scan_nodes_infershape) == 0
@pytest.mark.parametrize("single_step", (True, False))
def test_scan_mapped_and_non_traced_output_ordering(single_step):
# Regression test for https://github.com/pymc-devs/pytensor/issues/1796
rng = random_generator_type("rng")
def x_then_rng(rng):
next_rng, x = pt.random.normal(rng=rng).owner.outputs
return x, next_rng
xs, final_rng = scan(
fn=x_then_rng,
outputs_info=[None, rng],
n_steps=1 if single_step else 5,
return_updates=False,
)
assert isinstance(xs.type, TensorType)
assert isinstance(final_rng.type, RandomGeneratorType)
def rng_then_x(rng):
x, next_rng = x_then_rng(rng)
return next_rng, x
final_rng, xs = scan(
fn=rng_then_x,
outputs_info=[rng, None],
n_steps=1 if single_step else 5,
return_updates=False,
)
assert isinstance(xs.type, TensorType)
assert isinstance(final_rng.type, RandomGeneratorType)
def rng_between_xs(rng):
x, next_rng = x_then_rng(rng)
return x, next_rng, x + 1, x + 2
xs1, final_rng, xs2, xs3 = scan(
fn=rng_between_xs,
outputs_info=[None, rng, None, None],
n_steps=1 if single_step else 5,
return_updates=False,
)
assert all(isinstance(xs.type, TensorType) for xs in (xs1, xs2, xs3))
assert isinstance(final_rng.type, RandomGeneratorType)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论