Commit ed6ca162, authored by jessegrabowski, committed by Ricardo Vieira

Inplace Blockwise and core versions of Cholesky and Solve Ops.

Parent commit: b8dbd4ca
...@@ -583,6 +583,12 @@ class Op(MetaObject): ...@@ -583,6 +583,12 @@ class Op(MetaObject):
) )
return self.make_py_thunk(node, storage_map, compute_map, no_recycling) return self.make_py_thunk(node, storage_map, compute_map, no_recycling)
def inplace_on_inputs(self, allowed_inplace_inputs: list[int]) -> "Op":
"""Try to return a version of self that tries to inplace in as many as `allowed_inplace_inputs`."""
# TODO: Document this in the Create your own Op docs
# By default, do nothing
return self
def __str__(self): def __str__(self):
return getattr(type(self), "__name__", super().__str__()) return getattr(type(self), "__name__", super().__str__())
......
...@@ -45,6 +45,7 @@ class Blockwise(Op): ...@@ -45,6 +45,7 @@ class Blockwise(Op):
signature: str | None = None, signature: str | None = None,
name: str | None = None, name: str | None = None,
gufunc_spec: tuple[str, int, int] | None = None, gufunc_spec: tuple[str, int, int] | None = None,
destroy_map=None,
**kwargs, **kwargs,
): ):
""" """
...@@ -79,6 +80,15 @@ class Blockwise(Op): ...@@ -79,6 +80,15 @@ class Blockwise(Op):
self.inputs_sig, self.outputs_sig = _parse_gufunc_signature(signature) self.inputs_sig, self.outputs_sig = _parse_gufunc_signature(signature)
self.gufunc_spec = gufunc_spec self.gufunc_spec = gufunc_spec
self._gufunc = None self._gufunc = None
if destroy_map is not None:
self.destroy_map = destroy_map
if self.destroy_map != core_op.destroy_map:
# Note: Should be fine for destroy_map of Blockwise to be more extensive than that of core_op
# But we are not using that anywhere yet, so this check is fine for now
raise ValueError(
f"Blockwise destroy_map {self.destroy_map} must be the same as that of the core_op {core_op} {core_op.destroy_map}"
)
super().__init__(**kwargs) super().__init__(**kwargs)
def __getstate__(self): def __getstate__(self):
......
import itertools
from pytensor.compile import Supervisor
from pytensor.compile.mode import optdb from pytensor.compile.mode import optdb
from pytensor.graph import Constant, node_rewriter from pytensor.graph import Constant, node_rewriter
from pytensor.graph.replace import vectorize_node from pytensor.graph.replace import vectorize_node
from pytensor.graph.rewriting.basic import copy_stack_trace, out2in from pytensor.graph.rewriting.basic import copy_stack_trace, in2out, out2in
from pytensor.tensor.basic import Alloc, ARange, alloc, shape_padleft from pytensor.tensor.basic import Alloc, ARange, alloc, shape_padleft
from pytensor.tensor.blockwise import Blockwise from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.math import Dot from pytensor.tensor.math import Dot
...@@ -50,13 +53,14 @@ def local_useless_unbatched_blockwise(fgraph, node): ...@@ -50,13 +53,14 @@ def local_useless_unbatched_blockwise(fgraph, node):
# We register this rewrite late, so that other rewrites need only target Blockwise Ops # We register this rewrite late, so that other rewrites need only target Blockwise Ops
# We do it after position>=60 so that Blockwise inplace rewrites will work also on useless Blockwise Ops
optdb.register( optdb.register(
"local_useless_unbatched_blockwise", "local_useless_unbatched_blockwise",
out2in(local_useless_unbatched_blockwise, ignore_newtrees=True), out2in(local_useless_unbatched_blockwise, ignore_newtrees=True),
"fast_run", "fast_run",
"fast_compile", "fast_compile",
"blockwise", "blockwise",
position=49, position=60,
) )
...@@ -225,3 +229,77 @@ def local_blockwise_reshape(fgraph, node): ...@@ -225,3 +229,77 @@ def local_blockwise_reshape(fgraph, node):
new_out = x.reshape([*tuple(batched_shape), *tuple(core_reshape)]) new_out = x.reshape([*tuple(batched_shape), *tuple(core_reshape)])
copy_stack_trace(node.outputs[0], new_out) copy_stack_trace(node.outputs[0], new_out)
return [new_out] return [new_out]
@node_rewriter(tracks=[Blockwise], inplace=True)
def blockwise_inplace(fgraph, node):
    """Swap a Blockwise node for an in-place variant when its core Op allows it.

    Asks the core Op (via ``inplace_on_inputs``) for a version that destroys
    some of the inputs that are safe to overwrite, and rebuilds the Blockwise
    around it. Returns ``None`` when no in-place replacement is possible.
    """
    blockwise_op = node.op

    if blockwise_op.destroy_map:
        # Op already has inplace
        return

    batch_ndim = blockwise_op.batch_ndim(node)
    out_batch_bcast = node.outputs[0].type.broadcastable[:batch_ndim]

    # Variables that must never be destroyed: anything guarded by a Supervisor
    # feature, plus the graph outputs themselves.
    protected_inputs = list(
        itertools.chain.from_iterable(
            feature.protected
            for feature in fgraph._features
            if isinstance(feature, Supervisor)
        )
    )
    protected_inputs.extend(fgraph.outputs)

    # Collect the indices of inputs that are valid candidates for inplacing.
    allowed_inplace_inputs = []
    for idx, inp in enumerate(node.inputs):
        if isinstance(inp, Constant):
            # Constants would need to be recreated every time if inplaced
            continue
        if inp.type.broadcastable[:batch_ndim] != out_batch_bcast:
            # We can only inplace on inputs that are not being broadcasted,
            # as those are reused across iterations of Blockwise
            continue
        if fgraph.has_destroyers([inp]) or inp in protected_inputs:
            # Inputs that are marked as protected or destroyed can't be inplaced
            continue
        allowed_inplace_inputs.append(idx)

    if not allowed_inplace_inputs:
        return None

    inplace_core_op = blockwise_op.core_op.inplace_on_inputs(
        allowed_inplace_inputs=allowed_inplace_inputs
    )

    if not inplace_core_op.destroy_map:
        return None

    # Check Op is not trying to inplace on non-candidate inputs
    for destroyed_inputs in inplace_core_op.destroy_map.values():
        for destroyed_input in destroyed_inputs:
            if destroyed_input not in allowed_inplace_inputs:
                raise ValueError(
                    f"Op {blockwise_op.core_op} destroy_map does not respect allowed_inplace_inputs {allowed_inplace_inputs}"
                )

    # Recreate core_op with inplace
    inplace_blockwise_op = Blockwise(
        core_op=inplace_core_op,
        signature=blockwise_op.signature,
        name=blockwise_op.name,
        gufunc_spec=blockwise_op.gufunc_spec,
        destroy_map=inplace_core_op.destroy_map,
    )

    out = inplace_blockwise_op.make_node(*node.inputs).outputs
    copy_stack_trace(node.outputs, out)
    return out
# Register the rewrite under the usual inplace tags so it only runs in
# optimizing modes. NOTE(review): position 50.1 presumably slots it right
# after the other inplace/destroy-handling rewrites around position 50 —
# confirm against the optdb ordering.
optdb.register(
    "blockwise_inplace",
    in2out(blockwise_inplace),
    "fast_run",
    "inplace",
    position=50.1,
)
差异被折叠。
...@@ -3,10 +3,11 @@ from itertools import product ...@@ -3,10 +3,11 @@ from itertools import product
import numpy as np import numpy as np
import pytest import pytest
import scipy.linalg
import pytensor import pytensor
from pytensor import config, function from pytensor import In, config, function
from pytensor.compile import get_mode from pytensor.compile import get_default_mode, get_mode
from pytensor.gradient import grad from pytensor.gradient import grad
from pytensor.graph import Apply, Op from pytensor.graph import Apply, Op
from pytensor.graph.replace import vectorize_node from pytensor.graph.replace import vectorize_node
...@@ -15,7 +16,15 @@ from pytensor.tensor import diagonal, log, tensor ...@@ -15,7 +16,15 @@ from pytensor.tensor import diagonal, log, tensor
from pytensor.tensor.blockwise import Blockwise, vectorize_node_fallback from pytensor.tensor.blockwise import Blockwise, vectorize_node_fallback
from pytensor.tensor.nlinalg import MatrixInverse from pytensor.tensor.nlinalg import MatrixInverse
from pytensor.tensor.rewriting.blas import specialize_matmul_to_batched_dot from pytensor.tensor.rewriting.blas import specialize_matmul_to_batched_dot
from pytensor.tensor.slinalg import Cholesky, Solve, cholesky, solve_triangular from pytensor.tensor.slinalg import (
Cholesky,
Solve,
SolveBase,
cho_solve,
cholesky,
solve,
solve_triangular,
)
from pytensor.tensor.utils import _parse_gufunc_signature from pytensor.tensor.utils import _parse_gufunc_signature
...@@ -398,3 +407,105 @@ def test_cop_with_params(): ...@@ -398,3 +407,105 @@ def test_cop_with_params():
with pytest.raises(AssertionError): with pytest.raises(AssertionError):
fn(np.zeros((5, 3, 2)) - 1) fn(np.zeros((5, 3, 2)) - 1)
@pytest.mark.skipif(
    config.mode == "FAST_COMPILE",
    reason="inplace rewrites disabled when mode is FAST_COMPILE",
)
class TestInplace:
    """Check that compiled Cholesky/Solve graphs destroy mutable inputs in-place."""

    @pytest.mark.parametrize("is_batched", (False, True))
    def test_cholesky(self, is_batched):
        # Mark X as mutable so the inplace rewrite is allowed to destroy it
        X = tensor("X", shape=(5, None, None) if is_batched else (None, None))
        L = cholesky(X, lower=True)
        f = function([In(X, mutable=True)], L)

        # The user-facing Op must never be inplace; only the compiled one is
        assert not L.owner.op.core_op.destroy_map

        # Locate the (possibly Blockwise-wrapped) Cholesky in the compiled graph
        if is_batched:
            [cholesky_op] = [
                node.op.core_op
                for node in f.maker.fgraph.apply_nodes
                if isinstance(node.op, Blockwise)
                and isinstance(node.op.core_op, Cholesky)
            ]
        else:
            [cholesky_op] = [
                node.op
                for node in f.maker.fgraph.apply_nodes
                if isinstance(node.op, Cholesky)
            ]
        # After rewriting, the Op destroys its (only) input
        assert cholesky_op.destroy_map == {0: [0]}

        rng = np.random.default_rng(441 + is_batched)
        X_val = rng.normal(size=(10, 10)).astype(config.floatX)
        X_val_in = X_val @ X_val.T  # symmetric positive-definite input
        if is_batched:
            X_val_in = np.broadcast_to(X_val_in, (5, *X_val_in.shape)).copy()
        X_val_in_copy = X_val_in.copy()

        f(X_val_in)
        # The input buffer was overwritten with the Cholesky factor itself
        np.testing.assert_allclose(
            X_val_in,
            np.linalg.cholesky(X_val_in_copy),
            atol=1e-5 if config.floatX == "float32" else 0,
        )

    @pytest.mark.parametrize("batched_A", (False, True))
    @pytest.mark.parametrize("batched_b", (False, True))
    @pytest.mark.parametrize("solve_fn", (solve, solve_triangular, cho_solve))
    def test_solve(self, solve_fn, batched_A, batched_b):
        A = tensor("A", shape=(5, 3, 3) if batched_A else (3, 3))
        b = tensor("b", shape=(5, 3) if batched_b else (3,))

        if solve_fn == cho_solve:
            # Special signature for cho_solve
            x = solve_fn((A, True), b, b_ndim=1)
        else:
            x = solve_fn(A, b, b_ndim=1)

        mode = get_default_mode().excluding("batched_vector_b_solve_to_matrix_b_solve")
        fn = function([In(A, mutable=True), In(b, mutable=True)], x, mode=mode)

        # Verify which input the compiled Op destroys, per parametrization
        op = fn.maker.fgraph.outputs[0].owner.op
        if batched_A or batched_b:
            assert isinstance(op, Blockwise) and isinstance(op.core_op, SolveBase)
            if batched_A and not batched_b:
                if solve_fn == solve:
                    assert op.destroy_map == {0: [0]}
                else:
                    # SolveTriangular does not destroy A
                    assert op.destroy_map == {}
            else:
                assert op.destroy_map == {0: [1]}
        else:
            assert isinstance(op, SolveBase)
            assert op.destroy_map == {0: [1]}

        # We test with an F_CONTIGUOUS (core) A as only that will be destroyed by scipy
        rng = np.random.default_rng(
            487 + batched_A + 2 * batched_b + sum(map(ord, solve_fn.__name__))
        )
        A_val = np.swapaxes(rng.normal(size=A.type.shape).astype(A.type.dtype), -1, -2)
        # BUGFIX: draw b from the seeded generator rather than the global
        # np.random state, so the test data is reproducible
        b_val = rng.normal(size=b.type.shape).astype(b.type.dtype)

        A_val_copy = A_val.copy()
        b_val_copy = b_val.copy()
        out = fn(A_val, b_val)

        if solve_fn == cho_solve:

            def core_scipy_fn(A, b):
                return scipy.linalg.cho_solve((A, True), b)

        else:
            core_scipy_fn = getattr(scipy.linalg, solve_fn.__name__)
        expected_out = np.vectorize(core_scipy_fn, signature="(m,m),(m)->(m)")(
            A_val_copy, b_val_copy
        )
        np.testing.assert_allclose(
            out, expected_out, atol=1e-6 if config.floatX == "float32" else 0
        )

        # Confirm each input was destroyed exactly when destroy_map says so
        assert (A_val == A_val_copy).all() == (op.destroy_map.get(0, None) != [0])
        assert (b_val == b_val_copy).all() == (op.destroy_map.get(0, None) != [1])
...@@ -197,7 +197,10 @@ class TestSolveBase(utt.InferShapeTester): ...@@ -197,7 +197,10 @@ class TestSolveBase(utt.InferShapeTester):
A = matrix() A = matrix()
b = matrix() b = matrix()
y = SolveBase(b_ndim=2)(A, b) y = SolveBase(b_ndim=2)(A, b)
assert y.__repr__() == "SolveBase{lower=False, check_finite=True, b_ndim=2}.0" assert (
y.__repr__()
== "SolveBase{lower=False, check_finite=True, b_ndim=2, overwrite_a=False, overwrite_b=False}.0"
)
class TestSolve(utt.InferShapeTester): class TestSolve(utt.InferShapeTester):
...@@ -361,7 +364,7 @@ class TestCholeskySolve(utt.InferShapeTester): ...@@ -361,7 +364,7 @@ class TestCholeskySolve(utt.InferShapeTester):
def test_repr(self): def test_repr(self):
assert ( assert (
repr(CholeskySolve(lower=True, b_ndim=1)) repr(CholeskySolve(lower=True, b_ndim=1))
== "CholeskySolve(lower=True,check_finite=True,b_ndim=1)" == "CholeskySolve(lower=True,check_finite=True,b_ndim=1,overwrite_b=False)"
) )
def test_infer_shape(self): def test_infer_shape(self):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论