Unverified 提交 378cb406 authored 作者: Tanish's avatar Tanish 提交者: GitHub

Rewriting the kron function using JAX implementation (#684)

* Update the kron function to use numpy implementation and move the function to `tensor.nlinalg.py`
上级 f97d9ea3
......@@ -1010,6 +1010,40 @@ def tensorsolve(a, b, axes=None):
return TensorSolve(axes)(a, b)
def kron(a, b):
"""Kronecker product.
Same as np.kron(a, b)
Parameters
----------
a: array_like
b: array_like
Returns
-------
array_like with a.ndim + b.ndim - 2 dimensions
"""
a = as_tensor_variable(a)
b = as_tensor_variable(b)
if a.ndim + b.ndim <= 2:
raise TypeError(
"kron: inputs dimensions must sum to 3 or more. "
f"You passed {int(a.ndim)} and {int(b.ndim)}."
)
if a.ndim < b.ndim:
a = ptb.expand_dims(a, tuple(range(b.ndim - a.ndim)))
elif b.ndim < a.ndim:
b = ptb.expand_dims(b, tuple(range(a.ndim - b.ndim)))
a_reshaped = ptb.expand_dims(a, tuple(range(1, 2 * a.ndim, 2)))
b_reshaped = ptb.expand_dims(b, tuple(range(0, 2 * b.ndim, 2)))
out_shape = tuple(a.shape * b.shape)
output_out_of_shape = a_reshaped * b_reshaped
output_reshaped = output_out_of_shape.reshape(out_shape)
return output_reshaped
__all__ = [
"pinv",
"inv",
......@@ -1025,4 +1059,5 @@ __all__ = [
"norm",
"tensorinv",
"tensorsolve",
"kron",
]
......@@ -15,7 +15,7 @@ from pytensor.tensor import as_tensor_variable
from pytensor.tensor import basic as ptb
from pytensor.tensor import math as ptm
from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.nlinalg import matrix_dot
from pytensor.tensor.nlinalg import kron, matrix_dot
from pytensor.tensor.shape import reshape
from pytensor.tensor.type import matrix, tensor, vector
from pytensor.tensor.variable import TensorVariable
......@@ -559,51 +559,6 @@ def eigvalsh(a, b, lower=True):
return Eigvalsh(lower)(a, b)
def kron(a, b):
"""Kronecker product.
Same as scipy.linalg.kron(a, b).
Parameters
----------
a: array_like
b: array_like
Returns
-------
array_like with a.ndim + b.ndim - 2 dimensions
Notes
-----
numpy.kron(a, b) != scipy.linalg.kron(a, b)!
They don't have the same shape and order when
a.ndim != b.ndim != 2.
"""
a = as_tensor_variable(a)
b = as_tensor_variable(b)
if a.ndim + b.ndim <= 2:
raise TypeError(
"kron: inputs dimensions must sum to 3 or more. "
f"You passed {int(a.ndim)} and {int(b.ndim)}."
)
o = ptm.outer(a, b)
o = o.reshape(ptb.concatenate((a.shape, b.shape)), ndim=a.ndim + b.ndim)
shf = o.dimshuffle(0, 2, 1, *range(3, o.ndim))
if shf.ndim == 3:
shf = o.dimshuffle(1, 0, 2)
o = shf.flatten()
else:
o = shf.reshape(
(
o.shape[0] * o.shape[2],
o.shape[1] * o.shape[3],
*(o.shape[i] for i in range(4, o.ndim)),
)
)
return o
class Expm(Op):
"""
Compute the matrix exponential of a square array.
......@@ -1021,7 +976,6 @@ __all__ = [
"cholesky",
"solve",
"eigvalsh",
"kron",
"expm",
"solve_discrete_lyapunov",
"solve_continuous_lyapunov",
......
......@@ -17,6 +17,7 @@ from pytensor.tensor.nlinalg import (
det,
eig,
eigh,
kron,
lstsq,
matrix_dot,
matrix_inverse,
......@@ -580,3 +581,42 @@ class TestTensorInv(utt.InferShapeTester):
t_binv1 = tf_b1(self.b1)
assert _allclose(t_binv, n_binv)
assert _allclose(t_binv1, n_binv1)
class TestKron(utt.InferShapeTester):
rng = np.random.default_rng(43)
def setup_method(self):
self.op = kron
super().setup_method()
@pytest.mark.parametrize("shp0", [(2,), (2, 3), (2, 3, 4), (2, 3, 4, 5)])
@pytest.mark.parametrize("shp1", [(6,), (6, 7), (6, 7, 8), (6, 7, 8, 9)])
def test_perform(self, shp0, shp1):
if len(shp0) + len(shp1) == 2:
pytest.skip("Sum of shp0 and shp1 must be more than 2")
x = tensor(dtype="floatX", shape=(None,) * len(shp0))
a = np.asarray(self.rng.random(shp0)).astype(config.floatX)
y = tensor(dtype="floatX", shape=(None,) * len(shp1))
f = function([x, y], kron(x, y))
b = self.rng.random(shp1).astype(config.floatX)
out = f(a, b)
# Using the np.kron to compare outputs
np_val = np.kron(a, b)
np.testing.assert_allclose(out, np_val)
@pytest.mark.parametrize(
"i, shp0, shp1",
[(0, (2, 3), (6, 7)), (1, (2, 3), (4, 3, 5)), (2, (2, 4, 3), (4, 3, 5))],
)
def test_kron_commutes_with_inv(self, i, shp0, shp1):
if (pytensor.config.floatX == "float32") & (i == 2):
pytest.skip("Half precision insufficient for test 3 to pass")
x = tensor(dtype="floatX", shape=(None,) * len(shp0))
a = np.asarray(self.rng.random(shp0)).astype(config.floatX)
y = tensor(dtype="floatX", shape=(None,) * len(shp1))
b = self.rng.random(shp1).astype(config.floatX)
lhs_f = function([x, y], pinv(kron(x, y)))
rhs_f = function([x, y], kron(pinv(x), pinv(y)))
atol = 1e-4 if config.floatX == "float32" else 1e-12
np.testing.assert_allclose(lhs_f(a, b), rhs_f(a, b), atol=atol)
......@@ -20,7 +20,6 @@ from pytensor.tensor.slinalg import (
cholesky,
eigvalsh,
expm,
kron,
solve,
solve_continuous_lyapunov,
solve_discrete_are,
......@@ -512,46 +511,6 @@ def test_expm_grad_3():
utt.verify_grad(expm, [A], rng=rng)
class TestKron(utt.InferShapeTester):
rng = np.random.default_rng(43)
def setup_method(self):
self.op = kron
super().setup_method()
def test_perform(self):
for shp0 in [(2,), (2, 3), (2, 3, 4), (2, 3, 4, 5)]:
x = tensor(dtype="floatX", shape=(None,) * len(shp0))
a = np.asarray(self.rng.random(shp0)).astype(config.floatX)
for shp1 in [(6,), (6, 7), (6, 7, 8), (6, 7, 8, 9)]:
if len(shp0) + len(shp1) == 2:
continue
y = tensor(dtype="floatX", shape=(None,) * len(shp1))
f = function([x, y], kron(x, y))
b = self.rng.random(shp1).astype(config.floatX)
out = f(a, b)
# Newer versions of scipy want 4 dimensions at least,
# so we have to add a dimension to a and flatten the result.
if len(shp0) + len(shp1) == 3:
scipy_val = scipy.linalg.kron(a[np.newaxis, :], b).flatten()
else:
scipy_val = scipy.linalg.kron(a, b)
np.testing.assert_allclose(out, scipy_val)
def test_numpy_2d(self):
for shp0 in [(2, 3)]:
x = tensor(dtype="floatX", shape=(None,) * len(shp0))
a = np.asarray(self.rng.random(shp0)).astype(config.floatX)
for shp1 in [(6, 7)]:
if len(shp0) + len(shp1) == 2:
continue
y = tensor(dtype="floatX", shape=(None,) * len(shp1))
f = function([x, y], kron(x, y))
b = self.rng.random(shp1).astype(config.floatX)
out = f(a, b)
assert np.allclose(out, np.kron(a, b))
def test_solve_discrete_lyapunov_via_direct_real():
N = 5
rng = np.random.default_rng(utt.fetch_seed())
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论