Commit ed6ca162 authored by jessegrabowski, committed by Ricardo Vieira

Inplace Blockwise and core versions of Cholesky and Solve Ops.

Parent b8dbd4ca
......@@ -583,6 +583,12 @@ class Op(MetaObject):
)
return self.make_py_thunk(node, storage_map, compute_map, no_recycling)
def inplace_on_inputs(self, allowed_inplace_inputs: list[int]) -> "Op":
    """Return a version of this Op that may destroy (reuse the storage of) some inputs.

    Parameters
    ----------
    allowed_inplace_inputs
        Indices of the inputs the returned Op is allowed to destroy. The Op
        may choose to operate inplace on any subset of these (including none).

    Returns
    -------
    Op
        An Op whose ``destroy_map`` refers only to entries of
        ``allowed_inplace_inputs``. The base implementation performs no
        inplacing and returns ``self`` unchanged; subclasses opt in by
        overriding this method.
    """
    # TODO: Document this in the Create your own Op docs
    # By default, do nothing
    return self
def __str__(self):
    # Display an Op as its class name (e.g. "Cholesky"). If the type somehow
    # lacks a ``__name__``, fall back to the superclass ``__str__``. Note the
    # fallback is evaluated eagerly as the ``getattr`` default.
    return getattr(type(self), "__name__", super().__str__())
......
......@@ -45,6 +45,7 @@ class Blockwise(Op):
signature: str | None = None,
name: str | None = None,
gufunc_spec: tuple[str, int, int] | None = None,
destroy_map=None,
**kwargs,
):
"""
......@@ -79,6 +80,15 @@ class Blockwise(Op):
self.inputs_sig, self.outputs_sig = _parse_gufunc_signature(signature)
self.gufunc_spec = gufunc_spec
self._gufunc = None
if destroy_map is not None:
self.destroy_map = destroy_map
if self.destroy_map != core_op.destroy_map:
# Note: Should be fine for destroy_map of Blockwise to be more extensive than that of core_op
# But we are not using that anywhere yet, so this check is fine for now
raise ValueError(
f"Blockwise destroy_map {self.destroy_map} must be the same as that of the core_op {core_op} {core_op.destroy_map}"
)
super().__init__(**kwargs)
def __getstate__(self):
......
import itertools
from pytensor.compile import Supervisor
from pytensor.compile.mode import optdb
from pytensor.graph import Constant, node_rewriter
from pytensor.graph.replace import vectorize_node
from pytensor.graph.rewriting.basic import copy_stack_trace, out2in
from pytensor.graph.rewriting.basic import copy_stack_trace, in2out, out2in
from pytensor.tensor.basic import Alloc, ARange, alloc, shape_padleft
from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.math import Dot
......@@ -50,13 +53,14 @@ def local_useless_unbatched_blockwise(fgraph, node):
# We register this rewrite late, so that other rewrites need only target Blockwise Ops
# We do it after position>=60 so that Blockwise inplace rewrites will work also on useless Blockwise Ops
# NOTE: the diff residue previously left two `position=` keywords (49 and 60),
# which is a SyntaxError; only the updated value (60) is kept.
optdb.register(
    "local_useless_unbatched_blockwise",
    out2in(local_useless_unbatched_blockwise, ignore_newtrees=True),
    "fast_run",
    "fast_compile",
    "blockwise",
    position=60,
)
......@@ -225,3 +229,77 @@ def local_blockwise_reshape(fgraph, node):
new_out = x.reshape([*tuple(batched_shape), *tuple(core_reshape)])
copy_stack_trace(node.outputs[0], new_out)
return [new_out]
@node_rewriter(tracks=[Blockwise], inplace=True)
def blockwise_inplace(fgraph, node):
    """Replace a Blockwise node by an inplace variant when its core Op supports it."""
    blockwise_op = node.op

    if blockwise_op.destroy_map:
        # Op already has inplace
        return

    # Find out valid inputs for inplacing
    batch_ndim = blockwise_op.batch_ndim(node)
    out_batch_bcast = node.outputs[0].type.broadcastable[:batch_ndim]

    # Inputs guarded by a Supervisor feature, as well as graph outputs,
    # must never be overwritten.
    protected = list(
        itertools.chain.from_iterable(
            feature.protected
            for feature in fgraph._features
            if isinstance(feature, Supervisor)
        )
    )
    protected.extend(fgraph.outputs)

    allowed_inplace_inputs = []
    for idx, inp in enumerate(node.inputs):
        if isinstance(inp, Constant):
            # Constants would need to be recreated every time if inplaced
            continue
        if inp.type.broadcastable[:batch_ndim] != out_batch_bcast:
            # We can only inplace on inputs that are not being broadcasted,
            # as those are reused across iterations of Blockwise
            continue
        if fgraph.has_destroyers([inp]) or inp in protected:
            # Inputs that are marked as protected or destroyed can't be inplaced
            continue
        allowed_inplace_inputs.append(idx)

    if not allowed_inplace_inputs:
        return None

    inplace_core_op = blockwise_op.core_op.inplace_on_inputs(
        allowed_inplace_inputs=allowed_inplace_inputs
    )

    if not inplace_core_op.destroy_map:
        # Core Op declined to inplace on any candidate input
        return None

    # Check Op is not trying to inplace on non-candidate inputs
    for destroyed_inputs in inplace_core_op.destroy_map.values():
        for destroyed_input in destroyed_inputs:
            if destroyed_input not in allowed_inplace_inputs:
                raise ValueError(
                    f"Op {blockwise_op.core_op} destroy_map does not respect allowed_inplace_inputs {allowed_inplace_inputs}"
                )

    # Recreate core_op with inplace
    inplace_blockwise_op = Blockwise(
        core_op=inplace_core_op,
        signature=blockwise_op.signature,
        name=blockwise_op.name,
        gufunc_spec=blockwise_op.gufunc_spec,
        destroy_map=inplace_core_op.destroy_map,
    )

    out = inplace_blockwise_op.make_node(*node.inputs).outputs
    copy_stack_trace(node.outputs, out)
    return out
# Tagged "inplace" so the rewrite is skipped by modes that disable inplace
# rewrites (e.g. FAST_COMPILE).
# NOTE(review): position 50.1 presumably places this after the destroy-handler
# stage so inplacing decisions respect existing destroyers — confirm against
# the optdb layout.
optdb.register(
    "blockwise_inplace",
    in2out(blockwise_inplace),
    "fast_run",
    "inplace",
    position=50.1,
)
Diff is collapsed.
......@@ -3,10 +3,11 @@ from itertools import product
import numpy as np
import pytest
import scipy.linalg
import pytensor
from pytensor import config, function
from pytensor.compile import get_mode
from pytensor import In, config, function
from pytensor.compile import get_default_mode, get_mode
from pytensor.gradient import grad
from pytensor.graph import Apply, Op
from pytensor.graph.replace import vectorize_node
......@@ -15,7 +16,15 @@ from pytensor.tensor import diagonal, log, tensor
from pytensor.tensor.blockwise import Blockwise, vectorize_node_fallback
from pytensor.tensor.nlinalg import MatrixInverse
from pytensor.tensor.rewriting.blas import specialize_matmul_to_batched_dot
from pytensor.tensor.slinalg import Cholesky, Solve, cholesky, solve_triangular
from pytensor.tensor.slinalg import (
Cholesky,
Solve,
SolveBase,
cho_solve,
cholesky,
solve,
solve_triangular,
)
from pytensor.tensor.utils import _parse_gufunc_signature
......@@ -398,3 +407,105 @@ def test_cop_with_params():
with pytest.raises(AssertionError):
fn(np.zeros((5, 3, 2)) - 1)
@pytest.mark.skipif(
    config.mode == "FAST_COMPILE",
    reason="inplace rewrites disabled when mode is FAST_COMPILE",
)
class TestInplace:
    """Check that Cholesky/Solve Ops destroy (overwrite) inputs marked mutable."""

    @pytest.mark.parametrize("is_batched", (False, True))
    def test_cholesky(self, is_batched):
        X = tensor("X", shape=(5, None, None) if is_batched else (None, None))
        L = cholesky(X, lower=True)
        f = function([In(X, mutable=True)], L)

        # The user-facing Op is never inplace; the rewrite introduces it
        assert not L.owner.op.core_op.destroy_map

        if is_batched:
            [cholesky_op] = [
                node.op.core_op
                for node in f.maker.fgraph.apply_nodes
                if isinstance(node.op, Blockwise)
                and isinstance(node.op.core_op, Cholesky)
            ]
        else:
            [cholesky_op] = [
                node.op
                for node in f.maker.fgraph.apply_nodes
                if isinstance(node.op, Cholesky)
            ]
        # After rewriting, the compiled Op must destroy input 0
        assert cholesky_op.destroy_map == {0: [0]}

        rng = np.random.default_rng(441 + is_batched)
        X_val = rng.normal(size=(10, 10)).astype(config.floatX)
        X_val_in = X_val @ X_val.T  # symmetric positive-definite input
        if is_batched:
            X_val_in = np.broadcast_to(X_val_in, (5, *X_val_in.shape)).copy()
        X_val_in_copy = X_val_in.copy()

        f(X_val_in)

        # The mutable input buffer should now hold the Cholesky factor
        np.testing.assert_allclose(
            X_val_in,
            np.linalg.cholesky(X_val_in_copy),
            atol=1e-5 if config.floatX == "float32" else 0,
        )

    @pytest.mark.parametrize("batched_A", (False, True))
    @pytest.mark.parametrize("batched_b", (False, True))
    @pytest.mark.parametrize("solve_fn", (solve, solve_triangular, cho_solve))
    def test_solve(self, solve_fn, batched_A, batched_b):
        A = tensor("A", shape=(5, 3, 3) if batched_A else (3, 3))
        b = tensor("b", shape=(5, 3) if batched_b else (3,))

        if solve_fn == cho_solve:
            # Special signature for cho_solve
            x = solve_fn((A, True), b, b_ndim=1)
        else:
            x = solve_fn(A, b, b_ndim=1)

        mode = get_default_mode().excluding("batched_vector_b_solve_to_matrix_b_solve")
        fn = function([In(A, mutable=True), In(b, mutable=True)], x, mode=mode)

        op = fn.maker.fgraph.outputs[0].owner.op
        if batched_A or batched_b:
            assert isinstance(op, Blockwise) and isinstance(op.core_op, SolveBase)
            if batched_A and not batched_b:
                if solve_fn == solve:
                    assert op.destroy_map == {0: [0]}
                else:
                    # SolveTriangular does not destroy A
                    assert op.destroy_map == {}
            else:
                assert op.destroy_map == {0: [1]}
        else:
            assert isinstance(op, SolveBase)
            assert op.destroy_map == {0: [1]}

        # We test with an F_CONTIGUOUS (core) A as only that will be destroyed by scipy
        rng = np.random.default_rng(
            487 + batched_A + 2 * batched_b + sum(map(ord, solve_fn.__name__))
        )
        A_val = np.swapaxes(rng.normal(size=A.type.shape).astype(A.type.dtype), -1, -2)
        # BUGFIX: draw b from the seeded generator; the original used
        # np.random.normal, which ignored the seed and made the test
        # non-deterministic.
        b_val = rng.normal(size=b.type.shape).astype(b.type.dtype)

        A_val_copy = A_val.copy()
        b_val_copy = b_val.copy()
        out = fn(A_val, b_val)

        if solve_fn == cho_solve:

            def core_scipy_fn(A, b):
                return scipy.linalg.cho_solve((A, True), b)

        else:
            core_scipy_fn = getattr(scipy.linalg, solve_fn.__name__)

        expected_out = np.vectorize(core_scipy_fn, signature="(m,m),(m)->(m)")(
            A_val_copy, b_val_copy
        )
        np.testing.assert_allclose(
            out, expected_out, atol=1e-6 if config.floatX == "float32" else 0
        )

        # Confirm each input was destroyed exactly when destroy_map says so
        assert (A_val == A_val_copy).all() == (op.destroy_map.get(0, None) != [0])
        assert (b_val == b_val_copy).all() == (op.destroy_map.get(0, None) != [1])
......@@ -197,7 +197,10 @@ class TestSolveBase(utt.InferShapeTester):
A = matrix()
b = matrix()
y = SolveBase(b_ndim=2)(A, b)
assert y.__repr__() == "SolveBase{lower=False, check_finite=True, b_ndim=2}.0"
assert (
y.__repr__()
== "SolveBase{lower=False, check_finite=True, b_ndim=2, overwrite_a=False, overwrite_b=False}.0"
)
class TestSolve(utt.InferShapeTester):
......@@ -361,7 +364,7 @@ class TestCholeskySolve(utt.InferShapeTester):
def test_repr(self):
    # The repr must include all __props__, including the ``overwrite_b``
    # flag added with inplace support.
    # NOTE: the diff residue previously left both the old and new expected
    # strings as a chained comparison (``repr(...) == "old" == "new"``),
    # which always fails; only the updated expectation is kept.
    assert (
        repr(CholeskySolve(lower=True, b_ndim=1))
        == "CholeskySolve(lower=True,check_finite=True,b_ndim=1,overwrite_b=False)"
    )
def test_infer_shape(self):
......
Markdown format
0%
You are adding 0 people to this discussion. Please proceed with caution.
Please finish editing this comment first!
Register or sign in to comment