提交 5547eb08 authored 作者: Jesse Grabowski's avatar Jesse Grabowski 提交者: Ricardo Vieira

Don't return tuple from `jax.scipy.linalg.qr` when `mode = 'r'` (only one return)

上级 84252b5b
...@@ -177,7 +177,8 @@ def jax_funcify_QR(op, **kwargs): ...@@ -177,7 +177,8 @@ def jax_funcify_QR(op, **kwargs):
mode = op.mode mode = op.mode
def qr(x, mode=mode): def qr(x, mode=mode):
return jax.scipy.linalg.qr(x, mode=mode) res = jax.scipy.linalg.qr(x, mode=mode)
return res[0] if len(res) == 1 else res
return qr return qr
......
...@@ -370,3 +370,15 @@ def test_jax_expm(): ...@@ -370,3 +370,15 @@ def test_jax_expm():
out = pt_slinalg.expm(A) out = pt_slinalg.expm(A)
compare_jax_and_py([A], [out], [A_val]) compare_jax_and_py([A], [out], [A_val])
@pytest.mark.parametrize("mode", ["full", "r"])
def test_jax_qr(mode):
# "full" and "r" modes are tested because "full" returns two matrices (Q, R), while (R,) returns only one.
# Pytensor does not return a tuple when only one output is expected.
rng = np.random.default_rng(utt.fetch_seed())
A = pt.tensor(name="A", shape=(5, 5))
A_val = rng.normal(size=(5, 5)).astype(config.floatX)
out = pt_slinalg.qr(A, mode=mode)
compare_jax_and_py([A], out, [A_val])
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论