提交 ad68c7f9 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Exclude backend incompatible rewrites in Scan dispatch

上级 8dc67eae
import jax import jax
import jax.numpy as jnp import jax.numpy as jnp
from pytensor.compile.mode import JAX
from pytensor.link.jax.dispatch.basic import jax_funcify from pytensor.link.jax.dispatch.basic import jax_funcify
from pytensor.scan.op import Scan from pytensor.scan.op import Scan
...@@ -17,8 +18,8 @@ def jax_funcify_Scan(op: Scan, **kwargs): ...@@ -17,8 +18,8 @@ def jax_funcify_Scan(op: Scan, **kwargs):
"Scan with MIT-MOT (gradients of scan) cannot yet be converted to JAX" "Scan with MIT-MOT (gradients of scan) cannot yet be converted to JAX"
) )
# Optimize inner graph # Optimize inner graph (exclude any defalut rewrites that are incompatible with JAX mode)
rewriter = op.mode_instance.optimizer rewriter = op.mode_instance.excluding(*JAX._optimizer.exclude).optimizer
rewriter(op.fgraph) rewriter(op.fgraph)
scan_inner_func = jax_funcify(op.fgraph, **kwargs) scan_inner_func = jax_funcify(op.fgraph, **kwargs)
......
...@@ -5,6 +5,7 @@ import numpy as np ...@@ -5,6 +5,7 @@ import numpy as np
from numba import types from numba import types
from numba.extending import overload from numba.extending import overload
from pytensor.compile.mode import NUMBA
from pytensor.link.numba.dispatch import basic as numba_basic from pytensor.link.numba.dispatch import basic as numba_basic
from pytensor.link.numba.dispatch.basic import ( from pytensor.link.numba.dispatch.basic import (
create_arg_string, create_arg_string,
...@@ -58,7 +59,7 @@ def numba_funcify_Scan(op, node, **kwargs): ...@@ -58,7 +59,7 @@ def numba_funcify_Scan(op, node, **kwargs):
# TODO: Not sure this is the right place to do this, should we have a rewrite that # TODO: Not sure this is the right place to do this, should we have a rewrite that
# explicitly triggers the optimization of the inner graphs of Scan? # explicitly triggers the optimization of the inner graphs of Scan?
# The C-code defers it to the make_thunk phase # The C-code defers it to the make_thunk phase
rewriter = op.mode_instance.optimizer rewriter = op.mode_instance.excluding(*NUMBA._optimizer.exclude).optimizer
rewriter(op.fgraph) rewriter(op.fgraph)
scan_inner_func = numba_basic.numba_njit(numba_funcify(op.fgraph)) scan_inner_func = numba_basic.numba_njit(numba_funcify(op.fgraph))
......
...@@ -13,7 +13,7 @@ from pytensor.scan.basic import scan ...@@ -13,7 +13,7 @@ from pytensor.scan.basic import scan
from pytensor.scan.op import Scan from pytensor.scan.op import Scan
from pytensor.tensor import random from pytensor.tensor import random
from pytensor.tensor.math import gammaln, log from pytensor.tensor.math import gammaln, log
from pytensor.tensor.type import dmatrix, dvector, lscalar, scalar, vector from pytensor.tensor.type import dmatrix, dvector, lscalar, matrix, scalar, vector
from tests.link.jax.test_basic import compare_jax_and_py from tests.link.jax.test_basic import compare_jax_and_py
...@@ -418,3 +418,12 @@ def test_nd_scan_sit_sot_with_carry(): ...@@ -418,3 +418,12 @@ def test_nd_scan_sit_sot_with_carry():
test_input_vals = [x0_val, A_val] test_input_vals = [x0_val, A_val]
compare_jax_and_py(fg, test_input_vals) compare_jax_and_py(fg, test_input_vals)
def test_default_mode_excludes_incompatible_rewrites():
# See issue #426
A = matrix("A")
B = matrix("B")
out, _ = scan(lambda a, b: a @ b, outputs_info=[A], non_sequences=[B], n_steps=2)
fg = FunctionGraph([A, B], [out])
compare_jax_and_py(fg, [np.eye(3), np.eye(3)])
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论