提交 25e41c39 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Exclude `numba` rewrites from JAX Scan rewrites

上级 90c6f980
...@@ -29,7 +29,10 @@ def jax_funcify_Scan(op: Scan, **kwargs): ...@@ -29,7 +29,10 @@ def jax_funcify_Scan(op: Scan, **kwargs):
# Optimize inner graph (exclude any defalut rewrites that are incompatible with JAX mode) # Optimize inner graph (exclude any defalut rewrites that are incompatible with JAX mode)
rewriter = ( rewriter = (
get_mode(op.mode).including("jax").excluding(*JAX._optimizer.exclude).optimizer get_mode(op.mode)
.including("jax")
.excluding("numba", *JAX._optimizer.exclude)
.optimizer
) )
rewriter(op.fgraph) rewriter(op.fgraph)
scan_inner_func = jax_funcify(op.fgraph, **kwargs) scan_inner_func = jax_funcify(op.fgraph, **kwargs)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论