提交 ed6ca162 authored 作者: jessegrabowski's avatar jessegrabowski 提交者: Ricardo Vieira

Inplace Blockwise and core versions of Cholesky and Solve Ops.

上级 b8dbd4ca
...@@ -583,6 +583,12 @@ class Op(MetaObject): ...@@ -583,6 +583,12 @@ class Op(MetaObject):
) )
return self.make_py_thunk(node, storage_map, compute_map, no_recycling) return self.make_py_thunk(node, storage_map, compute_map, no_recycling)
def inplace_on_inputs(self, allowed_inplace_inputs: list[int]) -> "Op":
    """Return, when possible, a version of ``self`` that may operate inplace
    on the inputs whose indices appear in `allowed_inplace_inputs`.

    The base implementation makes no attempt at inplacing and returns
    ``self`` unchanged. Subclasses that support destructive (inplace)
    operation should override this and return a copy of the Op with an
    appropriate ``destroy_map``.
    """
    # TODO: Document this in the Create your own Op docs
    # By default, do nothing: inplacing is strictly opt-in per Op.
    return self
def __str__(self): def __str__(self):
return getattr(type(self), "__name__", super().__str__()) return getattr(type(self), "__name__", super().__str__())
......
...@@ -45,6 +45,7 @@ class Blockwise(Op): ...@@ -45,6 +45,7 @@ class Blockwise(Op):
signature: str | None = None, signature: str | None = None,
name: str | None = None, name: str | None = None,
gufunc_spec: tuple[str, int, int] | None = None, gufunc_spec: tuple[str, int, int] | None = None,
destroy_map=None,
**kwargs, **kwargs,
): ):
""" """
...@@ -79,6 +80,15 @@ class Blockwise(Op): ...@@ -79,6 +80,15 @@ class Blockwise(Op):
self.inputs_sig, self.outputs_sig = _parse_gufunc_signature(signature) self.inputs_sig, self.outputs_sig = _parse_gufunc_signature(signature)
self.gufunc_spec = gufunc_spec self.gufunc_spec = gufunc_spec
self._gufunc = None self._gufunc = None
if destroy_map is not None:
self.destroy_map = destroy_map
if self.destroy_map != core_op.destroy_map:
# Note: Should be fine for destroy_map of Blockwise to be more extensive than that of core_op
# But we are not using that anywhere yet, so this check is fine for now
raise ValueError(
f"Blockwise destroy_map {self.destroy_map} must be the same as that of the core_op {core_op} {core_op.destroy_map}"
)
super().__init__(**kwargs) super().__init__(**kwargs)
def __getstate__(self): def __getstate__(self):
......
import itertools
from pytensor.compile import Supervisor
from pytensor.compile.mode import optdb from pytensor.compile.mode import optdb
from pytensor.graph import Constant, node_rewriter from pytensor.graph import Constant, node_rewriter
from pytensor.graph.replace import vectorize_node from pytensor.graph.replace import vectorize_node
from pytensor.graph.rewriting.basic import copy_stack_trace, out2in from pytensor.graph.rewriting.basic import copy_stack_trace, in2out, out2in
from pytensor.tensor.basic import Alloc, ARange, alloc, shape_padleft from pytensor.tensor.basic import Alloc, ARange, alloc, shape_padleft
from pytensor.tensor.blockwise import Blockwise from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.math import Dot from pytensor.tensor.math import Dot
...@@ -50,13 +53,14 @@ def local_useless_unbatched_blockwise(fgraph, node): ...@@ -50,13 +53,14 @@ def local_useless_unbatched_blockwise(fgraph, node):
# We register this rewrite late, so that other rewrites need only target Blockwise Ops # We register this rewrite late, so that other rewrites need only target Blockwise Ops
# We do it after position>=60 so that Blockwise inplace rewrites will work also on useless Blockwise Ops
optdb.register( optdb.register(
"local_useless_unbatched_blockwise", "local_useless_unbatched_blockwise",
out2in(local_useless_unbatched_blockwise, ignore_newtrees=True), out2in(local_useless_unbatched_blockwise, ignore_newtrees=True),
"fast_run", "fast_run",
"fast_compile", "fast_compile",
"blockwise", "blockwise",
position=49, position=60,
) )
...@@ -225,3 +229,77 @@ def local_blockwise_reshape(fgraph, node): ...@@ -225,3 +229,77 @@ def local_blockwise_reshape(fgraph, node):
new_out = x.reshape([*tuple(batched_shape), *tuple(core_reshape)]) new_out = x.reshape([*tuple(batched_shape), *tuple(core_reshape)])
copy_stack_trace(node.outputs[0], new_out) copy_stack_trace(node.outputs[0], new_out)
return [new_out] return [new_out]
@node_rewriter(tracks=[Blockwise], inplace=True)
def blockwise_inplace(fgraph, node):
    """Try to replace a Blockwise node by an inplace (destructive) variant.

    Asks the core Op (via ``inplace_on_inputs``) for a destructive version
    restricted to inputs that are safe to destroy, then rebuilds the
    Blockwise around it. Returns the new outputs, or None when no inplacing
    is possible.
    """
    blockwise_op = node.op

    if blockwise_op.destroy_map:
        # Op already has inplace
        return

    # Find out valid inputs for inplacing
    batch_ndim = blockwise_op.batch_ndim(node)
    out_batch_bcast = node.outputs[0].type.broadcastable[:batch_ndim]

    # Inputs protected by a Supervisor feature (e.g. user-provided function
    # inputs not marked mutable) or that are graph outputs must be preserved.
    protected_inputs = [
        f.protected for f in fgraph._features if isinstance(f, Supervisor)
    ]
    protected_inputs = list(itertools.chain.from_iterable(protected_inputs))
    protected_inputs.extend(fgraph.outputs)
    allowed_inplace_inputs = [
        idx
        for idx, inp in enumerate(node.inputs)
        if (
            # Constants would need to be recreated every time if inplaced
            not isinstance(inp, Constant)
            # We can only inplace on inputs that are not being broadcasted
            # As those are reused across iterations of Blockwise
            and node.inputs[idx].type.broadcastable[:batch_ndim] == out_batch_bcast
            # Inputs that are marked as protected or destroyed can't be inplaced
            and not fgraph.has_destroyers([inp])
            and inp not in protected_inputs
        )
    ]

    if not allowed_inplace_inputs:
        return None

    inplace_core_op = blockwise_op.core_op.inplace_on_inputs(
        allowed_inplace_inputs=allowed_inplace_inputs
    )

    # Core Op declined to inplace on any of the candidates
    if not inplace_core_op.destroy_map:
        return None

    # Check Op is not trying to inplace on non-candidate inputs
    for destroyed_inputs in inplace_core_op.destroy_map.values():
        for destroyed_input in destroyed_inputs:
            if destroyed_input not in allowed_inplace_inputs:
                raise ValueError(
                    f"Op {blockwise_op.core_op} destroy_map does not respect allowed_inplace_inputs {allowed_inplace_inputs}"
                )

    # Recreate core_op with inplace
    inplace_blockwise_op = Blockwise(
        core_op=inplace_core_op,
        signature=blockwise_op.signature,
        name=blockwise_op.name,
        gufunc_spec=blockwise_op.gufunc_spec,
        destroy_map=inplace_core_op.destroy_map,
    )

    out = inplace_blockwise_op.make_node(*node.inputs).outputs
    copy_stack_trace(node.outputs, out)
    return out
# Register the Blockwise inplacing pass under the standard inplace phase.
# NOTE(review): position=50.1 presumably slots this just after rewrites
# registered at 50 and before later inplace passes — confirm against the
# overall optdb ordering.
optdb.register(
    "blockwise_inplace",
    in2out(blockwise_inplace),
    "fast_run",
    "inplace",
    position=50.1,
)
...@@ -28,57 +28,68 @@ logger = logging.getLogger(__name__) ...@@ -28,57 +28,68 @@ logger = logging.getLogger(__name__)
class Cholesky(Op): class Cholesky(Op):
"""
Return a triangular matrix square root of positive semi-definite `x`.
L = cholesky(X, lower=True) implies dot(L, L.T) == X.
Parameters
----------
lower : bool, default=True
Whether to return the lower or upper cholesky factor
on_error : ['raise', 'nan']
If on_error is set to 'raise', this Op will raise a
`scipy.linalg.LinAlgError` if the matrix is not positive definite.
If on_error is set to 'nan', it will return a matrix containing
nans instead.
"""
# TODO: inplace
# TODO: for specific dtypes
# TODO: LAPACK wrapper with in-place behavior, for solve also # TODO: LAPACK wrapper with in-place behavior, for solve also
__props__ = ("lower", "destructive", "on_error") __props__ = ("lower", "check_finite", "on_error", "overwrite_a")
gufunc_signature = "(m,m)->(m,m)" gufunc_signature = "(m,m)->(m,m)"
def __init__(
    self,
    *,
    lower: bool = True,
    check_finite: bool = True,
    on_error: Literal["raise", "nan"] = "raise",
    overwrite_a: bool = False,
):
    """
    Parameters
    ----------
    lower : bool, default=True
        Whether to return the lower or upper cholesky factor.
    check_finite : bool, default=True
        Whether to check that the input matrix contains only finite numbers.
    on_error : {'raise', 'nan'}
        'raise' propagates `scipy.linalg.LinAlgError` when the matrix is not
        positive definite; 'nan' returns a NaN-filled matrix instead.
    overwrite_a : bool, default=False
        Whether `perform` may destroy the input buffer to hold the result.
    """
    self.lower = lower
    self.check_finite = check_finite
    if on_error not in ("raise", "nan"):
        # BUG FIX: message previously contained a doubled quote: ""nan"
        raise ValueError('on_error must be one of "raise" or "nan"')
    self.on_error = on_error
    self.overwrite_a = overwrite_a

    if self.overwrite_a:
        # Declare to the graph machinery that output 0 reuses input 0's memory
        self.destroy_map = {0: [0]}
def infer_shape(self, fgraph, node, shapes):
    """The Cholesky factor has exactly the shape of its (square) input."""
    (x_shape,) = shapes
    return [x_shape]
def make_node(self, x):
    """Build the Apply node, validating that `x` is a matrix."""
    x = as_tensor_variable(x)
    if x.type.ndim != 2:
        raise TypeError(
            f"Cholesky only allowed on matrix (2-D) inputs, got {x.type.ndim}-D input"
        )
    # Call scipy to find output dtype
    # (a 1x1 probe keeps the symbolic dtype in sync with the runtime result)
    dtype = scipy.linalg.cholesky(np.eye(1, dtype=x.type.dtype)).dtype
    return Apply(self, [x], [tensor(shape=x.type.shape, dtype=dtype)])
def perform(self, node, inputs, outputs):
    """Compute the Cholesky factor with scipy, optionally reusing `x`'s buffer."""
    (x,) = inputs
    (out,) = outputs
    try:
        # Scipy cholesky only makes use of overwrite_a when the array is
        # F_CONTIGUOUS. For a C_CONTIGUOUS x we factor x.T (which is
        # F-contiguous) with the opposite triangle and transpose back, so
        # the input buffer is still reused.
        if self.overwrite_a and x.flags["C_CONTIGUOUS"]:
            result = scipy.linalg.cholesky(
                x.T,
                lower=not self.lower,
                check_finite=self.check_finite,
                overwrite_a=True,
            ).T
        else:
            result = scipy.linalg.cholesky(
                x,
                lower=self.lower,
                check_finite=self.check_finite,
                overwrite_a=self.overwrite_a,
            )
        out[0] = result
    except scipy.linalg.LinAlgError:
        if self.on_error == "raise":
            raise
        # on_error == "nan": emit a NaN-filled array of the declared dtype
        out[0] = np.full(x.shape, np.nan, dtype=node.outputs[0].type.dtype)
def L_op(self, inputs, outputs, gradients): def L_op(self, inputs, outputs, gradients):
""" """
...@@ -131,11 +142,66 @@ class Cholesky(Op): ...@@ -131,11 +142,66 @@ class Cholesky(Op):
else: else:
return [grad] return [grad]
def inplace_on_inputs(self, allowed_inplace_inputs: list[int]) -> "Op":
    """Return a variant of this Op that destroys its only input (`x`).

    Cholesky has a single input (index 0), so we check for it explicitly
    rather than for mere non-emptiness of `allowed_inplace_inputs` — this
    is more robust and consistent with the Solve Ops' implementations.
    """
    if 0 not in allowed_inplace_inputs:
        return self
    new_props = self._props_dict()  # type: ignore
    new_props["overwrite_a"] = True
    return type(self)(**new_props)
def cholesky(
    x: "TensorLike",
    lower: bool = True,
    *,
    check_finite: bool = False,
    overwrite_a: bool = False,
    on_error: Literal["raise", "nan"] = "raise",
):
    """
    Return a triangular matrix square root of positive semi-definite `x`.

    L = cholesky(X, lower=True) implies dot(L, L.T) == X.

    Parameters
    ----------
    x: tensor_like
    lower : bool, default=True
        Whether to return the lower or upper cholesky factor
    check_finite : bool, default=False
        Whether to check that the input matrix contains only finite numbers.
    overwrite_a: bool, ignored
        Whether to use the same memory for the output as `a`. This argument is ignored, and is present here only
        for consistency with scipy.linalg.cholesky. Inplacing is handled by a graph rewrite instead.
    on_error : ['raise', 'nan']
        If on_error is set to 'raise', this Op will raise a `scipy.linalg.LinAlgError` if the matrix is not positive definite.
        If on_error is set to 'nan', it will return a matrix containing nans instead.

    Returns
    -------
    TensorVariable
        Lower or upper triangular Cholesky factor of `x`

    Example
    -------
    .. testcode::

        import pytensor
        import pytensor.tensor as pt
        import numpy as np

        x = pt.tensor('x', shape=(5, 5), dtype='float64')
        L = pt.linalg.cholesky(x)

        f = pytensor.function([x], L)
        x_value = np.random.normal(size=(5, 5))
        x_value = x_value @ x_value.T # Ensures x is positive definite
        L_value = f(x_value)
        assert np.allclose(L_value @ L_value.T, x_value)

    """
    # BUG FIX: `check_finite` was documented (default False) but never
    # forwarded, so the core Op silently used its own default of True.
    return Blockwise(
        Cholesky(lower=lower, on_error=on_error, check_finite=check_finite)
    )(x)
class SolveBase(Op): class SolveBase(Op):
...@@ -145,6 +211,8 @@ class SolveBase(Op): ...@@ -145,6 +211,8 @@ class SolveBase(Op):
"lower", "lower",
"check_finite", "check_finite",
"b_ndim", "b_ndim",
"overwrite_a",
"overwrite_b",
) )
def __init__( def __init__(
...@@ -153,6 +221,8 @@ class SolveBase(Op): ...@@ -153,6 +221,8 @@ class SolveBase(Op):
lower=False, lower=False,
check_finite=True, check_finite=True,
b_ndim, b_ndim,
overwrite_a=False,
overwrite_b=False,
): ):
self.lower = lower self.lower = lower
self.check_finite = check_finite self.check_finite = check_finite
...@@ -162,9 +232,25 @@ class SolveBase(Op): ...@@ -162,9 +232,25 @@ class SolveBase(Op):
self.gufunc_signature = "(m,m),(m)->(m)" self.gufunc_signature = "(m,m),(m)->(m)"
else: else:
self.gufunc_signature = "(m,m),(m,n)->(m,n)" self.gufunc_signature = "(m,m),(m,n)->(m,n)"
self.overwrite_a = overwrite_a
self.overwrite_b = overwrite_b
destroy_map = {}
if self.overwrite_a and self.overwrite_b:
# An output destroying two inputs is not yet supported
# destroy_map[0] = [0, 1]
raise NotImplementedError(
"It's not yet possible to overwrite_a and overwrite_b simultaneously"
)
elif self.overwrite_a:
destroy_map[0] = [0]
elif self.overwrite_b:
destroy_map[0] = [1]
self.destroy_map = destroy_map
def perform(self, node, inputs, outputs):
    """Abstract: concrete Solve Ops must implement the actual solve."""
    # BUG FIX: message grammar ("an perform" -> "a perform")
    raise NotImplementedError(
        "SolveBase should be subclassed with a perform method"
    )
def make_node(self, A, b): def make_node(self, A, b):
A = as_tensor_variable(A) A = as_tensor_variable(A)
...@@ -235,7 +321,16 @@ def _default_b_ndim(b, b_ndim): ...@@ -235,7 +321,16 @@ def _default_b_ndim(b, b_ndim):
class CholeskySolve(SolveBase): class CholeskySolve(SolveBase):
__props__ = (
"lower",
"check_finite",
"b_ndim",
"overwrite_b",
)
def __init__(self, **kwargs):
    """Construct a CholeskySolve; `overwrite_a` is rejected, `lower` defaults to True."""
    # scipy.linalg.cho_solve never destroys the factor, only `b`.
    overwrite_a = kwargs.get("overwrite_a", False)
    if overwrite_a:
        raise ValueError("overwrite_a is not supported for CholeskySolve")
    if "lower" not in kwargs:
        # Cholesky factors are lower-triangular by default here
        kwargs["lower"] = True
    super().__init__(**kwargs)
...@@ -245,13 +340,23 @@ class CholeskySolve(SolveBase): ...@@ -245,13 +340,23 @@ class CholeskySolve(SolveBase):
(C, self.lower), (C, self.lower),
b, b,
check_finite=self.check_finite, check_finite=self.check_finite,
overwrite_b=self.overwrite_b,
) )
output_storage[0][0] = rval output_storage[0][0] = rval
def L_op(self, *args, **kwargs): def L_op(self, *args, **kwargs):
# TODO: Base impl should work, let's try it
raise NotImplementedError() raise NotImplementedError()
def inplace_on_inputs(self, allowed_inplace_inputs: list[int]) -> "Op":
    """Return a copy of this Op that overwrites `b` (input 1), when allowed."""
    # The factor (input 0) can never be destroyed; `b` is the only candidate.
    if 1 not in allowed_inplace_inputs:
        return self
    props = self._props_dict()  # type: ignore
    props["overwrite_b"] = True
    return type(self)(**props)
def cho_solve(c_and_lower, b, *, check_finite=True, b_ndim: int | None = None): def cho_solve(c_and_lower, b, *, check_finite=True, b_ndim: int | None = None):
"""Solve the linear equations A x = b, given the Cholesky factorization of A. """Solve the linear equations A x = b, given the Cholesky factorization of A.
...@@ -286,9 +391,12 @@ class SolveTriangular(SolveBase): ...@@ -286,9 +391,12 @@ class SolveTriangular(SolveBase):
"lower", "lower",
"check_finite", "check_finite",
"b_ndim", "b_ndim",
"overwrite_b",
) )
def __init__(self, *, trans=0, unit_diagonal=False, **kwargs):
    """Construct a SolveTriangular.

    `trans` and `unit_diagonal` mirror the scipy.linalg.solve_triangular
    arguments of the same names; remaining kwargs go to SolveBase.
    `overwrite_a` is rejected: only `b` may be destroyed.
    """
    if kwargs.get("overwrite_a", False):
        # BUG FIX: message previously misspelled the class name
        # ("SolverTriangulare")
        raise ValueError("overwrite_a is not supported for SolveTriangular")
    super().__init__(**kwargs)
    self.trans = trans
    self.unit_diagonal = unit_diagonal
...@@ -302,6 +410,7 @@ class SolveTriangular(SolveBase): ...@@ -302,6 +410,7 @@ class SolveTriangular(SolveBase):
trans=self.trans, trans=self.trans,
unit_diagonal=self.unit_diagonal, unit_diagonal=self.unit_diagonal,
check_finite=self.check_finite, check_finite=self.check_finite,
overwrite_b=self.overwrite_b,
) )
def L_op(self, inputs, outputs, output_gradients): def L_op(self, inputs, outputs, output_gradients):
...@@ -314,6 +423,14 @@ class SolveTriangular(SolveBase): ...@@ -314,6 +423,14 @@ class SolveTriangular(SolveBase):
return res return res
def inplace_on_inputs(self, allowed_inplace_inputs: list[int]) -> "Op":
    """Return an Op variant that may overwrite `b` (input 1), if permitted."""
    if 1 in allowed_inplace_inputs:
        # Rebuild with identical props except overwrite_b
        return type(self)(**{**self._props_dict(), "overwrite_b": True})  # type: ignore
    return self
def solve_triangular( def solve_triangular(
a: TensorVariable, a: TensorVariable,
...@@ -374,6 +491,8 @@ class Solve(SolveBase): ...@@ -374,6 +491,8 @@ class Solve(SolveBase):
"lower", "lower",
"check_finite", "check_finite",
"b_ndim", "b_ndim",
"overwrite_a",
"overwrite_b",
) )
def __init__(self, *, assume_a="gen", **kwargs): def __init__(self, *, assume_a="gen", **kwargs):
...@@ -391,8 +510,24 @@ class Solve(SolveBase): ...@@ -391,8 +510,24 @@ class Solve(SolveBase):
lower=self.lower, lower=self.lower,
check_finite=self.check_finite, check_finite=self.check_finite,
assume_a=self.assume_a, assume_a=self.assume_a,
overwrite_a=self.overwrite_a,
overwrite_b=self.overwrite_b,
) )
def inplace_on_inputs(self, allowed_inplace_inputs: list[int]) -> "Op":
    """Return a version of this Op that destroys one of its inputs.

    PyTensor doesn't yet allow a single output to destroy two inputs, so at
    most one of overwrite_a / overwrite_b is enabled, preferring `b` when
    both are permitted.
    """
    if not allowed_inplace_inputs:
        return self
    props = self._props_dict()  # type: ignore
    # Give preference to overwrite_b; otherwise allowed inputs == [0]
    destroyed = "overwrite_b" if 1 in allowed_inplace_inputs else "overwrite_a"
    props[destroyed] = True
    return type(self)(**props)
def solve( def solve(
a, a,
......
...@@ -3,10 +3,11 @@ from itertools import product ...@@ -3,10 +3,11 @@ from itertools import product
import numpy as np import numpy as np
import pytest import pytest
import scipy.linalg
import pytensor import pytensor
from pytensor import config, function from pytensor import In, config, function
from pytensor.compile import get_mode from pytensor.compile import get_default_mode, get_mode
from pytensor.gradient import grad from pytensor.gradient import grad
from pytensor.graph import Apply, Op from pytensor.graph import Apply, Op
from pytensor.graph.replace import vectorize_node from pytensor.graph.replace import vectorize_node
...@@ -15,7 +16,15 @@ from pytensor.tensor import diagonal, log, tensor ...@@ -15,7 +16,15 @@ from pytensor.tensor import diagonal, log, tensor
from pytensor.tensor.blockwise import Blockwise, vectorize_node_fallback from pytensor.tensor.blockwise import Blockwise, vectorize_node_fallback
from pytensor.tensor.nlinalg import MatrixInverse from pytensor.tensor.nlinalg import MatrixInverse
from pytensor.tensor.rewriting.blas import specialize_matmul_to_batched_dot from pytensor.tensor.rewriting.blas import specialize_matmul_to_batched_dot
from pytensor.tensor.slinalg import Cholesky, Solve, cholesky, solve_triangular from pytensor.tensor.slinalg import (
Cholesky,
Solve,
SolveBase,
cho_solve,
cholesky,
solve,
solve_triangular,
)
from pytensor.tensor.utils import _parse_gufunc_signature from pytensor.tensor.utils import _parse_gufunc_signature
...@@ -398,3 +407,105 @@ def test_cop_with_params(): ...@@ -398,3 +407,105 @@ def test_cop_with_params():
with pytest.raises(AssertionError): with pytest.raises(AssertionError):
fn(np.zeros((5, 3, 2)) - 1) fn(np.zeros((5, 3, 2)) - 1)
@pytest.mark.skipif(
    config.mode == "FAST_COMPILE",
    reason="inplace rewrites disabled when mode is FAST_COMPILE",
)
class TestInplace:
    # Checks that the blockwise_inplace rewrite introduces destructive
    # (inplace) core Ops for Cholesky and the Solve family, and that the
    # inplace results are still numerically correct.

    @pytest.mark.parametrize("is_batched", (False, True))
    def test_cholesky(self, is_batched):
        X = tensor("X", shape=(5, None, None) if is_batched else (None, None))
        L = cholesky(X, lower=True)
        # In(..., mutable=True) tells the Supervisor the input may be destroyed
        f = function([In(X, mutable=True)], L)

        # The user-facing Op must not be inplace; only the rewrite adds it
        assert not L.owner.op.core_op.destroy_map

        if is_batched:
            [cholesky_op] = [
                node.op.core_op
                for node in f.maker.fgraph.apply_nodes
                if isinstance(node.op, Blockwise)
                and isinstance(node.op.core_op, Cholesky)
            ]
        else:
            [cholesky_op] = [
                node.op
                for node in f.maker.fgraph.apply_nodes
                if isinstance(node.op, Cholesky)
            ]
        # After rewriting, the compiled Op destroys its input
        assert cholesky_op.destroy_map == {0: [0]}

        rng = np.random.default_rng(441 + is_batched)
        X_val = rng.normal(size=(10, 10)).astype(config.floatX)
        X_val_in = X_val @ X_val.T  # positive semi-definite input
        if is_batched:
            X_val_in = np.broadcast_to(X_val_in, (5, *X_val_in.shape)).copy()
        X_val_in_copy = X_val_in.copy()

        f(X_val_in)

        # The input buffer itself should now hold the Cholesky factor
        np.testing.assert_allclose(
            X_val_in,
            np.linalg.cholesky(X_val_in_copy),
            atol=1e-5 if config.floatX == "float32" else 0,
        )

    @pytest.mark.parametrize("batched_A", (False, True))
    @pytest.mark.parametrize("batched_b", (False, True))
    @pytest.mark.parametrize("solve_fn", (solve, solve_triangular, cho_solve))
    def test_solve(self, solve_fn, batched_A, batched_b):
        A = tensor("A", shape=(5, 3, 3) if batched_A else (3, 3))
        b = tensor("b", shape=(5, 3) if batched_b else (3,))

        if solve_fn == cho_solve:
            # Special signature for cho_solve
            x = solve_fn((A, True), b, b_ndim=1)
        else:
            x = solve_fn(A, b, b_ndim=1)

        # Disable the vector->matrix b rewrite so the tested Op stays as built
        mode = get_default_mode().excluding("batched_vector_b_solve_to_matrix_b_solve")
        fn = function([In(A, mutable=True), In(b, mutable=True)], x, mode=mode)

        op = fn.maker.fgraph.outputs[0].owner.op
        if batched_A or batched_b:
            assert isinstance(op, Blockwise) and isinstance(op.core_op, SolveBase)
            if batched_A and not batched_b:
                # b is broadcast across the batch, so only A is a candidate
                if solve_fn == solve:
                    assert op.destroy_map == {0: [0]}
                else:
                    # SolveTriangular does not destroy A
                    assert op.destroy_map == {}
            else:
                assert op.destroy_map == {0: [1]}
        else:
            assert isinstance(op, SolveBase)
            assert op.destroy_map == {0: [1]}

        # We test with an F_CONTIGUOUS (core) A as only that will be destroyed by scipy
        rng = np.random.default_rng(
            487 + batched_A + 2 * batched_b + sum(map(ord, solve_fn.__name__))
        )
        A_val = np.swapaxes(rng.normal(size=A.type.shape).astype(A.type.dtype), -1, -2)
        # NOTE(review): uses the global np.random rather than the seeded rng —
        # presumably unintended; confirm.
        b_val = np.random.normal(size=b.type.shape).astype(b.type.dtype)
        A_val_copy = A_val.copy()
        b_val_copy = b_val.copy()
        out = fn(A_val, b_val)

        if solve_fn == cho_solve:

            def core_scipy_fn(A, b):
                return scipy.linalg.cho_solve((A, True), b)

        else:
            core_scipy_fn = getattr(scipy.linalg, solve_fn.__name__)

        # Reference result computed on untouched copies of the inputs
        expected_out = np.vectorize(core_scipy_fn, signature="(m,m),(m)->(m)")(
            A_val_copy, b_val_copy
        )
        np.testing.assert_allclose(
            out, expected_out, atol=1e-6 if config.floatX == "float32" else 0
        )

        # Confirm input was destroyed exactly when the destroy_map says so
        assert (A_val == A_val_copy).all() == (op.destroy_map.get(0, None) != [0])
        assert (b_val == b_val_copy).all() == (op.destroy_map.get(0, None) != [1])
...@@ -197,7 +197,10 @@ class TestSolveBase(utt.InferShapeTester): ...@@ -197,7 +197,10 @@ class TestSolveBase(utt.InferShapeTester):
A = matrix() A = matrix()
b = matrix() b = matrix()
y = SolveBase(b_ndim=2)(A, b) y = SolveBase(b_ndim=2)(A, b)
assert y.__repr__() == "SolveBase{lower=False, check_finite=True, b_ndim=2}.0" assert (
y.__repr__()
== "SolveBase{lower=False, check_finite=True, b_ndim=2, overwrite_a=False, overwrite_b=False}.0"
)
class TestSolve(utt.InferShapeTester): class TestSolve(utt.InferShapeTester):
...@@ -361,7 +364,7 @@ class TestCholeskySolve(utt.InferShapeTester): ...@@ -361,7 +364,7 @@ class TestCholeskySolve(utt.InferShapeTester):
def test_repr(self): def test_repr(self):
assert ( assert (
repr(CholeskySolve(lower=True, b_ndim=1)) repr(CholeskySolve(lower=True, b_ndim=1))
== "CholeskySolve(lower=True,check_finite=True,b_ndim=1)" == "CholeskySolve(lower=True,check_finite=True,b_ndim=1,overwrite_b=False)"
) )
def test_infer_shape(self): def test_infer_shape(self):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论