提交 ed6ca162 authored 作者: jessegrabowski's avatar jessegrabowski 提交者: Ricardo Vieira

Inplace Blockwise and core versions of Cholesky and Solve Ops.

上级 b8dbd4ca
......@@ -583,6 +583,12 @@ class Op(MetaObject):
)
return self.make_py_thunk(node, storage_map, compute_map, no_recycling)
def inplace_on_inputs(self, allowed_inplace_inputs: list[int]) -> "Op":
"""Try to return a version of self that tries to inplace in as many as `allowed_inplace_inputs`."""
# TODO: Document this in the Create your own Op docs
# By default, do nothing
return self
def __str__(self):
return getattr(type(self), "__name__", super().__str__())
......
......@@ -45,6 +45,7 @@ class Blockwise(Op):
signature: str | None = None,
name: str | None = None,
gufunc_spec: tuple[str, int, int] | None = None,
destroy_map=None,
**kwargs,
):
"""
......@@ -79,6 +80,15 @@ class Blockwise(Op):
self.inputs_sig, self.outputs_sig = _parse_gufunc_signature(signature)
self.gufunc_spec = gufunc_spec
self._gufunc = None
if destroy_map is not None:
self.destroy_map = destroy_map
if self.destroy_map != core_op.destroy_map:
# Note: Should be fine for destroy_map of Blockwise to be more extensive than that of core_op
# But we are not using that anywhere yet, so this check is fine for now
raise ValueError(
f"Blockwise destroy_map {self.destroy_map} must be the same as that of the core_op {core_op} {core_op.destroy_map}"
)
super().__init__(**kwargs)
def __getstate__(self):
......
import itertools
from pytensor.compile import Supervisor
from pytensor.compile.mode import optdb
from pytensor.graph import Constant, node_rewriter
from pytensor.graph.replace import vectorize_node
from pytensor.graph.rewriting.basic import copy_stack_trace, out2in
from pytensor.graph.rewriting.basic import copy_stack_trace, in2out, out2in
from pytensor.tensor.basic import Alloc, ARange, alloc, shape_padleft
from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.math import Dot
......@@ -50,13 +53,14 @@ def local_useless_unbatched_blockwise(fgraph, node):
# We register this rewrite late, so that other rewrites need only target Blockwise Ops
# We do it after position>=60 so that Blockwise inplace rewrites will work also on useless Blockwise Ops
optdb.register(
"local_useless_unbatched_blockwise",
out2in(local_useless_unbatched_blockwise, ignore_newtrees=True),
"fast_run",
"fast_compile",
"blockwise",
position=49,
position=60,
)
......@@ -225,3 +229,77 @@ def local_blockwise_reshape(fgraph, node):
new_out = x.reshape([*tuple(batched_shape), *tuple(core_reshape)])
copy_stack_trace(node.outputs[0], new_out)
return [new_out]
@node_rewriter(tracks=[Blockwise], inplace=True)
def blockwise_inplace(fgraph, node):
blockwise_op = node.op
if blockwise_op.destroy_map:
# Op already has inplace
return
# Find out valid inputs for inplacing
batch_ndim = blockwise_op.batch_ndim(node)
out_batch_bcast = node.outputs[0].type.broadcastable[:batch_ndim]
protected_inputs = [
f.protected for f in fgraph._features if isinstance(f, Supervisor)
]
protected_inputs = list(itertools.chain.from_iterable(protected_inputs))
protected_inputs.extend(fgraph.outputs)
allowed_inplace_inputs = [
idx
for idx, inp in enumerate(node.inputs)
if
(
# Constants would need to be recreated every time if inplaced
not isinstance(inp, Constant)
# We can only inplace on inputs that are not being broadcasted
# As those are reused across iterations of Blockwise
and node.inputs[idx].type.broadcastable[:batch_ndim] == out_batch_bcast
# Inputs that are marked as protected or destroyed can't be inplaced
and not fgraph.has_destroyers([inp])
and inp not in protected_inputs
)
]
if not allowed_inplace_inputs:
return None
inplace_core_op = blockwise_op.core_op.inplace_on_inputs(
allowed_inplace_inputs=allowed_inplace_inputs
)
if not inplace_core_op.destroy_map:
return None
# Check Op is not trying to inplace on non-candidate inputs
for destroyed_inputs in inplace_core_op.destroy_map.values():
for destroyed_input in destroyed_inputs:
if destroyed_input not in allowed_inplace_inputs:
raise ValueError(
f"Op {blockwise_op.core_op} destroy_map does not respect allowed_inplace_inputs {allowed_inplace_inputs}"
)
# Recreate core_op with inplace
inplace_blockwise_op = Blockwise(
core_op=inplace_core_op,
signature=blockwise_op.signature,
name=blockwise_op.name,
gufunc_spec=blockwise_op.gufunc_spec,
destroy_map=inplace_core_op.destroy_map,
)
out = inplace_blockwise_op.make_node(*node.inputs).outputs
copy_stack_trace(node.outputs, out)
return out
optdb.register(
"blockwise_inplace",
in2out(blockwise_inplace),
"fast_run",
"inplace",
position=50.1,
)
......@@ -28,57 +28,68 @@ logger = logging.getLogger(__name__)
class Cholesky(Op):
"""
Return a triangular matrix square root of positive semi-definite `x`.
L = cholesky(X, lower=True) implies dot(L, L.T) == X.
Parameters
----------
lower : bool, default=True
Whether to return the lower or upper cholesky factor
on_error : ['raise', 'nan']
If on_error is set to 'raise', this Op will raise a
`scipy.linalg.LinAlgError` if the matrix is not positive definite.
If on_error is set to 'nan', it will return a matrix containing
nans instead.
"""
# TODO: inplace
# TODO: for specific dtypes
# TODO: LAPACK wrapper with in-place behavior, for solve also
__props__ = ("lower", "destructive", "on_error")
__props__ = ("lower", "check_finite", "on_error", "overwrite_a")
gufunc_signature = "(m,m)->(m,m)"
def __init__(self, *, lower=True, check_finite=True, on_error="raise"):
def __init__(
self,
*,
lower: bool = True,
check_finite: bool = True,
on_error: Literal["raise", "nan"] = "raise",
overwrite_a: bool = False,
):
self.lower = lower
self.destructive = False
self.check_finite = check_finite
if on_error not in ("raise", "nan"):
raise ValueError('on_error must be one of "raise" or ""nan"')
self.on_error = on_error
self.overwrite_a = overwrite_a
if self.overwrite_a:
self.destroy_map = {0: [0]}
def infer_shape(self, fgraph, node, shapes):
return [shapes[0]]
def make_node(self, x):
x = as_tensor_variable(x)
assert x.ndim == 2
return Apply(self, [x], [x.type()])
if x.type.ndim != 2:
raise TypeError(
f"Cholesky only allowed on matrix (2-D) inputs, got {x.type.ndim}-D input"
)
# Call scipy to find output dtype
dtype = scipy.linalg.cholesky(np.eye(1, dtype=x.type.dtype)).dtype
return Apply(self, [x], [tensor(shape=x.type.shape, dtype=dtype)])
def perform(self, node, inputs, outputs):
x = inputs[0]
z = outputs[0]
[x] = inputs
[out] = outputs
try:
z[0] = scipy.linalg.cholesky(
x, lower=self.lower, check_finite=self.check_finite
).astype(x.dtype)
# Scipy cholesky only makes use of overwrite_a when it is F_CONTIGUOUS
# If we have a `C_CONTIGUOUS` array we transpose to benefit from it
if self.overwrite_a and x.flags["C_CONTIGUOUS"]:
out[0] = scipy.linalg.cholesky(
x.T,
lower=not self.lower,
check_finite=self.check_finite,
overwrite_a=True,
).T
else:
out[0] = scipy.linalg.cholesky(
x,
lower=self.lower,
check_finite=self.check_finite,
overwrite_a=self.overwrite_a,
)
except scipy.linalg.LinAlgError:
if self.on_error == "raise":
raise
else:
z[0] = (np.zeros(x.shape) * np.nan).astype(x.dtype)
out[0] = np.full(x.shape, np.nan, dtype=node.outputs[0].type.dtype)
def L_op(self, inputs, outputs, gradients):
"""
......@@ -131,11 +142,66 @@ class Cholesky(Op):
else:
return [grad]
def inplace_on_inputs(self, allowed_inplace_inputs: list[int]) -> "Op":
if not allowed_inplace_inputs:
return self
new_props = self._props_dict() # type: ignore
new_props["overwrite_a"] = True
return type(self)(**new_props)
def cholesky(x, lower=True, on_error="raise", check_finite=False):
return Blockwise(
Cholesky(lower=lower, on_error=on_error, check_finite=check_finite)
)(x)
def cholesky(
x: "TensorLike",
lower: bool = True,
*,
check_finite: bool = False,
overwrite_a: bool = False,
on_error: Literal["raise", "nan"] = "raise",
):
"""
Return a triangular matrix square root of positive semi-definite `x`.
L = cholesky(X, lower=True) implies dot(L, L.T) == X.
Parameters
----------
x: tensor_like
lower : bool, default=True
Whether to return the lower or upper cholesky factor
check_finite : bool, default=False
Whether to check that the input matrix contains only finite numbers.
overwrite_a: bool, ignored
Whether to use the same memory for the output as `a`. This argument is ignored, and is present here only
for consistency with scipy.linalg.cholesky.
on_error : ['raise', 'nan']
If on_error is set to 'raise', this Op will raise a `scipy.linalg.LinAlgError` if the matrix is not positive definite.
If on_error is set to 'nan', it will return a matrix containing nans instead.
Returns
-------
TensorVariable
Lower or upper triangular Cholesky factor of `x`
Example
-------
.. testcode::
import pytensor
import pytensor.tensor as pt
import numpy as np
x = pt.tensor('x', shape=(5, 5), dtype='float64')
L = pt.linalg.cholesky(x)
f = pytensor.function([x], L)
x_value = np.random.normal(size=(5, 5))
x_value = x_value @ x_value.T # Ensures x is positive definite
L_value = f(x_value)
assert np.allclose(L_value @ L_value.T, x_value)
"""
return Blockwise(Cholesky(lower=lower, on_error=on_error))(x)
class SolveBase(Op):
......@@ -145,6 +211,8 @@ class SolveBase(Op):
"lower",
"check_finite",
"b_ndim",
"overwrite_a",
"overwrite_b",
)
def __init__(
......@@ -153,6 +221,8 @@ class SolveBase(Op):
lower=False,
check_finite=True,
b_ndim,
overwrite_a=False,
overwrite_b=False,
):
self.lower = lower
self.check_finite = check_finite
......@@ -162,9 +232,25 @@ class SolveBase(Op):
self.gufunc_signature = "(m,m),(m)->(m)"
else:
self.gufunc_signature = "(m,m),(m,n)->(m,n)"
self.overwrite_a = overwrite_a
self.overwrite_b = overwrite_b
destroy_map = {}
if self.overwrite_a and self.overwrite_b:
# An output destroying two inputs is not yet supported
# destroy_map[0] = [0, 1]
raise NotImplementedError(
"It's not yet possible to overwrite_a and overwrite_b simultaneously"
)
elif self.overwrite_a:
destroy_map[0] = [0]
elif self.overwrite_b:
destroy_map[0] = [1]
self.destroy_map = destroy_map
def perform(self, node, inputs, outputs):
pass
raise NotImplementedError(
"SolveBase should be subclassed with an perform method"
)
def make_node(self, A, b):
A = as_tensor_variable(A)
......@@ -235,7 +321,16 @@ def _default_b_ndim(b, b_ndim):
class CholeskySolve(SolveBase):
__props__ = (
"lower",
"check_finite",
"b_ndim",
"overwrite_b",
)
def __init__(self, **kwargs):
if kwargs.get("overwrite_a", False):
raise ValueError("overwrite_a is not supported for CholeskySolve")
kwargs.setdefault("lower", True)
super().__init__(**kwargs)
......@@ -245,13 +340,23 @@ class CholeskySolve(SolveBase):
(C, self.lower),
b,
check_finite=self.check_finite,
overwrite_b=self.overwrite_b,
)
output_storage[0][0] = rval
def L_op(self, *args, **kwargs):
# TODO: Base impl should work, let's try it
raise NotImplementedError()
def inplace_on_inputs(self, allowed_inplace_inputs: list[int]) -> "Op":
if 1 in allowed_inplace_inputs:
new_props = self._props_dict() # type: ignore
new_props["overwrite_b"] = True
return type(self)(**new_props)
else:
return self
def cho_solve(c_and_lower, b, *, check_finite=True, b_ndim: int | None = None):
"""Solve the linear equations A x = b, given the Cholesky factorization of A.
......@@ -286,9 +391,12 @@ class SolveTriangular(SolveBase):
"lower",
"check_finite",
"b_ndim",
"overwrite_b",
)
def __init__(self, *, trans=0, unit_diagonal=False, **kwargs):
if kwargs.get("overwrite_a", False):
raise ValueError("overwrite_a is not supported for SolverTriangulare")
super().__init__(**kwargs)
self.trans = trans
self.unit_diagonal = unit_diagonal
......@@ -302,6 +410,7 @@ class SolveTriangular(SolveBase):
trans=self.trans,
unit_diagonal=self.unit_diagonal,
check_finite=self.check_finite,
overwrite_b=self.overwrite_b,
)
def L_op(self, inputs, outputs, output_gradients):
......@@ -314,6 +423,14 @@ class SolveTriangular(SolveBase):
return res
def inplace_on_inputs(self, allowed_inplace_inputs: list[int]) -> "Op":
if 1 in allowed_inplace_inputs:
new_props = self._props_dict() # type: ignore
new_props["overwrite_b"] = True
return type(self)(**new_props)
else:
return self
def solve_triangular(
a: TensorVariable,
......@@ -374,6 +491,8 @@ class Solve(SolveBase):
"lower",
"check_finite",
"b_ndim",
"overwrite_a",
"overwrite_b",
)
def __init__(self, *, assume_a="gen", **kwargs):
......@@ -391,8 +510,24 @@ class Solve(SolveBase):
lower=self.lower,
check_finite=self.check_finite,
assume_a=self.assume_a,
overwrite_a=self.overwrite_a,
overwrite_b=self.overwrite_b,
)
def inplace_on_inputs(self, allowed_inplace_inputs: list[int]) -> "Op":
if not allowed_inplace_inputs:
return self
new_props = self._props_dict() # type: ignore
# PyTensor doesn't allow an output to destroy two inputs yet
# new_props["overwrite_a"] = 0 in allowed_inplace_inputs
# new_props["overwrite_b"] = 1 in allowed_inplace_inputs
if 1 in allowed_inplace_inputs:
# Give preference to overwrite_b
new_props["overwrite_b"] = True
else: # allowed inputs == [0]
new_props["overwrite_a"] = True
return type(self)(**new_props)
def solve(
a,
......
......@@ -3,10 +3,11 @@ from itertools import product
import numpy as np
import pytest
import scipy.linalg
import pytensor
from pytensor import config, function
from pytensor.compile import get_mode
from pytensor import In, config, function
from pytensor.compile import get_default_mode, get_mode
from pytensor.gradient import grad
from pytensor.graph import Apply, Op
from pytensor.graph.replace import vectorize_node
......@@ -15,7 +16,15 @@ from pytensor.tensor import diagonal, log, tensor
from pytensor.tensor.blockwise import Blockwise, vectorize_node_fallback
from pytensor.tensor.nlinalg import MatrixInverse
from pytensor.tensor.rewriting.blas import specialize_matmul_to_batched_dot
from pytensor.tensor.slinalg import Cholesky, Solve, cholesky, solve_triangular
from pytensor.tensor.slinalg import (
Cholesky,
Solve,
SolveBase,
cho_solve,
cholesky,
solve,
solve_triangular,
)
from pytensor.tensor.utils import _parse_gufunc_signature
......@@ -398,3 +407,105 @@ def test_cop_with_params():
with pytest.raises(AssertionError):
fn(np.zeros((5, 3, 2)) - 1)
@pytest.mark.skipif(
config.mode == "FAST_COMPILE",
reason="inplace rewrites disabled when mode is FAST_COMPILE",
)
class TestInplace:
@pytest.mark.parametrize("is_batched", (False, True))
def test_cholesky(self, is_batched):
X = tensor("X", shape=(5, None, None) if is_batched else (None, None))
L = cholesky(X, lower=True)
f = function([In(X, mutable=True)], L)
assert not L.owner.op.core_op.destroy_map
if is_batched:
[cholesky_op] = [
node.op.core_op
for node in f.maker.fgraph.apply_nodes
if isinstance(node.op, Blockwise)
and isinstance(node.op.core_op, Cholesky)
]
else:
[cholesky_op] = [
node.op
for node in f.maker.fgraph.apply_nodes
if isinstance(node.op, Cholesky)
]
assert cholesky_op.destroy_map == {0: [0]}
rng = np.random.default_rng(441 + is_batched)
X_val = rng.normal(size=(10, 10)).astype(config.floatX)
X_val_in = X_val @ X_val.T
if is_batched:
X_val_in = np.broadcast_to(X_val_in, (5, *X_val_in.shape)).copy()
X_val_in_copy = X_val_in.copy()
f(X_val_in)
np.testing.assert_allclose(
X_val_in,
np.linalg.cholesky(X_val_in_copy),
atol=1e-5 if config.floatX == "float32" else 0,
)
@pytest.mark.parametrize("batched_A", (False, True))
@pytest.mark.parametrize("batched_b", (False, True))
@pytest.mark.parametrize("solve_fn", (solve, solve_triangular, cho_solve))
def test_solve(self, solve_fn, batched_A, batched_b):
A = tensor("A", shape=(5, 3, 3) if batched_A else (3, 3))
b = tensor("b", shape=(5, 3) if batched_b else (3,))
if solve_fn == cho_solve:
# Special signature for cho_solve
x = solve_fn((A, True), b, b_ndim=1)
else:
x = solve_fn(A, b, b_ndim=1)
mode = get_default_mode().excluding("batched_vector_b_solve_to_matrix_b_solve")
fn = function([In(A, mutable=True), In(b, mutable=True)], x, mode=mode)
op = fn.maker.fgraph.outputs[0].owner.op
if batched_A or batched_b:
assert isinstance(op, Blockwise) and isinstance(op.core_op, SolveBase)
if batched_A and not batched_b:
if solve_fn == solve:
assert op.destroy_map == {0: [0]}
else:
# SolveTriangular does not destroy A
assert op.destroy_map == {}
else:
assert op.destroy_map == {0: [1]}
else:
assert isinstance(op, SolveBase)
assert op.destroy_map == {0: [1]}
# We test with an F_CONTIGUOUS (core) A as only that will be destroyed by scipy
rng = np.random.default_rng(
487 + batched_A + 2 * batched_b + sum(map(ord, solve_fn.__name__))
)
A_val = np.swapaxes(rng.normal(size=A.type.shape).astype(A.type.dtype), -1, -2)
b_val = np.random.normal(size=b.type.shape).astype(b.type.dtype)
A_val_copy = A_val.copy()
b_val_copy = b_val.copy()
out = fn(A_val, b_val)
if solve_fn == cho_solve:
def core_scipy_fn(A, b):
return scipy.linalg.cho_solve((A, True), b)
else:
core_scipy_fn = getattr(scipy.linalg, solve_fn.__name__)
expected_out = np.vectorize(core_scipy_fn, signature="(m,m),(m)->(m)")(
A_val_copy, b_val_copy
)
np.testing.assert_allclose(
out, expected_out, atol=1e-6 if config.floatX == "float32" else 0
)
# Confirm input was destroyed
assert (A_val == A_val_copy).all() == (op.destroy_map.get(0, None) != [0])
assert (b_val == b_val_copy).all() == (op.destroy_map.get(0, None) != [1])
......@@ -197,7 +197,10 @@ class TestSolveBase(utt.InferShapeTester):
A = matrix()
b = matrix()
y = SolveBase(b_ndim=2)(A, b)
assert y.__repr__() == "SolveBase{lower=False, check_finite=True, b_ndim=2}.0"
assert (
y.__repr__()
== "SolveBase{lower=False, check_finite=True, b_ndim=2, overwrite_a=False, overwrite_b=False}.0"
)
class TestSolve(utt.InferShapeTester):
......@@ -361,7 +364,7 @@ class TestCholeskySolve(utt.InferShapeTester):
def test_repr(self):
assert (
repr(CholeskySolve(lower=True, b_ndim=1))
== "CholeskySolve(lower=True,check_finite=True,b_ndim=1)"
== "CholeskySolve(lower=True,check_finite=True,b_ndim=1,overwrite_b=False)"
)
def test_infer_shape(self):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论