提交 20194cfc authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Fix JAX list index error in Scan conversion

上级 2e1808a5
......@@ -481,7 +481,7 @@ def jax_funcify_Scan(op):
# + inner_in_non_seqs
inner_in_mit_sot_flatten = []
for array, index in zip(inner_in_mit_sot, scan_args.mit_sot_in_slices):
inner_in_mit_sot_flatten.extend(array[index])
inner_in_mit_sot_flatten.extend(array[jnp.array(index)])
inner_scan_inputs = sum(
[
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论