提交 28a7199a authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Update FunctionGraph result usage in JAX Scan conversion

上级 0ee5acfe
...@@ -492,7 +492,7 @@ def jax_funcify_Scan(op, **kwargs): ...@@ -492,7 +492,7 @@ def jax_funcify_Scan(op, **kwargs):
def jax_inner_func(carry, x): def jax_inner_func(carry, x):
inner_args = jax_args_to_inner_scan(op, carry, x) inner_args = jax_args_to_inner_scan(op, carry, x)
inner_scan_outs = [fn(*inner_args) for fn in jax_aet_inner_func] inner_scan_outs = list(jax_aet_inner_func(*inner_args))
new_carry = inner_scan_outs_to_jax_outs(op, carry, inner_scan_outs) new_carry = inner_scan_outs_to_jax_outs(op, carry, inner_scan_outs)
return new_carry, inner_scan_outs return new_carry, inner_scan_outs
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论