提交 ee47dcc9 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Numba BlockDiag: Fix failure with mixed readable/non-readable arrays

上级 c48a8b3a
......@@ -3,6 +3,7 @@ import warnings
import numpy as np
from pytensor import config
from pytensor.link.numba.cache import compile_numba_function_src
from pytensor.link.numba.dispatch import basic as numba_basic
from pytensor.link.numba.dispatch.basic import (
generate_fallback_impl,
......@@ -30,6 +31,10 @@ from pytensor.link.numba.dispatch.linalg.solve.posdef import _solve_psd
from pytensor.link.numba.dispatch.linalg.solve.symmetric import _solve_symmetric
from pytensor.link.numba.dispatch.linalg.solve.triangular import _solve_triangular
from pytensor.link.numba.dispatch.linalg.solve.tridiagonal import _solve_tridiagonal
from pytensor.link.numba.dispatch.string_codegen import (
CODE_TOKEN,
build_source_code,
)
from pytensor.tensor.slinalg import (
LU,
QR,
......@@ -222,24 +227,69 @@ def numba_funcify_LUFactor(op, node, **kwargs):
@register_funcify_default_op_cache_key(BlockDiagonal)
def numba_funcify_BlockDiagonal(op, node, **kwargs):
    """Build the Numba implementation of ``BlockDiagonal`` for ``node``.

    Because we have variadic arguments we need to use codegen: the function is
    generated with one explicit positional parameter per input of ``node``
    instead of a single ``*arrs`` parameter (numba would complain about
    ``*args`` containing both read-only and regular arrays).

    The generated code looks something like:

    def block_diagonal(arr0, arr1, arr2):
        out_r = arr0.shape[0] + arr1.shape[0] + arr2.shape[0]
        out_c = arr0.shape[1] + arr1.shape[1] + arr2.shape[1]
        out = np.zeros((out_r, out_c), dtype=np.float64)

        r, c = 0, 0
        rr, cc = arr0.shape
        out[r: r + rr, c: c + cc] = arr0
        r += rr
        c += cc

        rr, cc = arr1.shape
        out[r: r + rr, c: c + cc] = arr1
        r += rr
        c += cc

        rr, cc = arr2.shape
        out[r: r + rr, c: c + cc] = arr2
        r += rr
        c += cc

        return out

    Parameters
    ----------
    op
        The Op being dispatched on (not read by this implementation).
    node
        The node whose number of inputs and output dtype drive the codegen.
    **kwargs
        Accepted for dispatcher compatibility; unused here.

    Returns
    -------
    tuple
        The njit-wrapped generated function and an integer ``cache_key``.
    """
    dtype = node.outputs[0].dtype
    n_inp = len(node.inputs)

    # One named parameter per input, so each array keeps its own numba type.
    arg_names = [f"arr{i}" for i in range(n_inp)]
    code = [
        f"def block_diagonal({', '.join(arg_names)}):",
        CODE_TOKEN.INDENT,
        # The output shape is the sum of the input shapes along each axis.
        f"out_r = {' + '.join(f'{a}.shape[0]' for a in arg_names)}",
        f"out_c = {' + '.join(f'{a}.shape[1]' for a in arg_names)}",
        f"out = np.zeros((out_r, out_c), dtype=np.{dtype})",
        CODE_TOKEN.EMPTY_LINE,
        "r, c = 0, 0",
    ]
    # Unrolled per-input copy: write each block at (r, c), then advance the
    # offsets past it so the next block lands diagonally below-right.
    for arg_name in arg_names:
        code.extend(
            [
                f"rr, cc = {arg_name}.shape",
                f"out[r: r + rr, c: c + cc] = {arg_name}",
                "r += rr",
                "c += cc",
                CODE_TOKEN.EMPTY_LINE,
            ]
        )
    code.append("return out")

    code_txt = build_source_code(code)
    block_diag = compile_numba_function_src(
        code_txt,
        "block_diagonal",
        # The generated source refers to ``np``; make sure it resolves.
        globals() | {"np": np},
    )

    # NOTE(review): presumably an implementation version that is combined
    # with the default Op cache key - confirm against
    # ``register_funcify_default_op_cache_key``.
    cache_key = 1
    return numba_basic.numba_njit(block_diag), cache_key
@register_funcify_default_op_cache_key(Solve)
......
from collections.abc import Sequence
from enum import Enum, auto
def create_tuple_string(x):
    """Format the strings in ``x`` as the source text of a tuple literal.

    A single element is rendered with a trailing comma (``"(a,)"``) so the
    text still denotes a tuple; any other length is rendered as the elements
    joined by ``", "`` inside parentheses.
    """
    if len(x) != 1:
        return f"({', '.join(x)})"
    # One-element tuples need the trailing comma to remain tuples.
    return f"({x[0]},)"
class CODE_TOKEN(Enum):
    """Layout markers that may be mixed with code lines for ``build_source_code``."""

    # Increase the indentation of the lines that follow by one level.
    INDENT = auto()
    # Decrease the indentation of the lines that follow by one level.
    DEDENT = auto()
    # Emit an empty line in the generated source.
    EMPTY_LINE = auto()
def build_source_code(code: Sequence[str | CODE_TOKEN]) -> str:
    """Render code lines and layout tokens into a single source string.

    Parameters
    ----------
    code
        Items processed in order. ``CODE_TOKEN.INDENT`` / ``CODE_TOKEN.DEDENT``
        raise / lower the current indentation level, ``CODE_TOKEN.EMPTY_LINE``
        emits an empty line, and anything else is emitted as a line prefixed
        with the current indentation.

    Returns
    -------
    str
        The emitted lines joined with newlines (no trailing newline).

    Raises
    ------
    AssertionError
        If a ``DEDENT`` would take the indentation level below zero.
    """
    rendered = []
    depth = 0
    for item in code:
        if item is CODE_TOKEN.EMPTY_LINE:
            rendered.append("")
            continue
        if item is CODE_TOKEN.INDENT:
            depth += 1
            continue
        if item is CODE_TOKEN.DEDENT:
            depth -= 1
            assert depth >= 0
            continue
        # NOTE(review): one space per level as shown in this view; whitespace
        # may have been collapsed here - confirm the intended indent width.
        prefix = " " * depth
        rendered.append(f"{prefix}{item}")
    return "\n".join(rendered)
......@@ -811,6 +811,24 @@ def test_block_diag():
compare_numba_and_py([A, B, C, D], [X], [A_val, B_val, C_val, D_val])
def test_block_diag_with_read_only_inp():
    """Check ``block_diag`` with a mix of writeable and read-only inputs."""
    # Regression test where numba would complain about *args containing both
    # read-only and regular inputs.
    # Currently, constants are read-only for numba, but for future-proofing we
    # add an explicitly read-only input as well.
    x = pt.tensor("x", shape=(2, 2))
    # Give each variable a name matching its role, so graph printouts and
    # error messages are not ambiguous.
    x_read_only = pt.tensor("x_read_only", shape=(2, 2))
    x_const = pt.constant(np.ones((2, 2), dtype=x.type.dtype), name="x_const")
    out = pt.linalg.block_diag(x, x_read_only, x_const)

    x_test = np.ones((2, 2), dtype=x.type.dtype)
    x_read_only_test = x_test.copy()
    # Make the second test value explicitly read-only.
    x_read_only_test.flags.writeable = False
    compare_numba_and_py(
        [x, x_read_only],
        [out],
        [x_test, x_read_only_test],
    )
@pytest.mark.parametrize("inverse", [True, False], ids=["p_inv", "p"])
def test_pivot_to_permutation(inverse):
from pytensor.tensor.slinalg import pivot_to_permutation
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论