Unverified 提交 378cb406 authored 作者: Tanish's avatar Tanish 提交者: GitHub

Rewriting the kron function using JAX implementation (#684)

* Update the kron function to use numpy implementation and move the function to `tensor.nlinalg.py`
上级 f97d9ea3
...@@ -1010,6 +1010,40 @@ def tensorsolve(a, b, axes=None): ...@@ -1010,6 +1010,40 @@ def tensorsolve(a, b, axes=None):
return TensorSolve(axes)(a, b) return TensorSolve(axes)(a, b)
def kron(a, b):
"""Kronecker product.
Same as np.kron(a, b)
Parameters
----------
a: array_like
b: array_like
Returns
-------
array_like with a.ndim + b.ndim - 2 dimensions
"""
a = as_tensor_variable(a)
b = as_tensor_variable(b)
if a.ndim + b.ndim <= 2:
raise TypeError(
"kron: inputs dimensions must sum to 3 or more. "
f"You passed {int(a.ndim)} and {int(b.ndim)}."
)
if a.ndim < b.ndim:
a = ptb.expand_dims(a, tuple(range(b.ndim - a.ndim)))
elif b.ndim < a.ndim:
b = ptb.expand_dims(b, tuple(range(a.ndim - b.ndim)))
a_reshaped = ptb.expand_dims(a, tuple(range(1, 2 * a.ndim, 2)))
b_reshaped = ptb.expand_dims(b, tuple(range(0, 2 * b.ndim, 2)))
out_shape = tuple(a.shape * b.shape)
output_out_of_shape = a_reshaped * b_reshaped
output_reshaped = output_out_of_shape.reshape(out_shape)
return output_reshaped
__all__ = [ __all__ = [
"pinv", "pinv",
"inv", "inv",
...@@ -1025,4 +1059,5 @@ __all__ = [ ...@@ -1025,4 +1059,5 @@ __all__ = [
"norm", "norm",
"tensorinv", "tensorinv",
"tensorsolve", "tensorsolve",
"kron",
] ]
...@@ -15,7 +15,7 @@ from pytensor.tensor import as_tensor_variable ...@@ -15,7 +15,7 @@ from pytensor.tensor import as_tensor_variable
from pytensor.tensor import basic as ptb from pytensor.tensor import basic as ptb
from pytensor.tensor import math as ptm from pytensor.tensor import math as ptm
from pytensor.tensor.blockwise import Blockwise from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.nlinalg import matrix_dot from pytensor.tensor.nlinalg import kron, matrix_dot
from pytensor.tensor.shape import reshape from pytensor.tensor.shape import reshape
from pytensor.tensor.type import matrix, tensor, vector from pytensor.tensor.type import matrix, tensor, vector
from pytensor.tensor.variable import TensorVariable from pytensor.tensor.variable import TensorVariable
...@@ -559,51 +559,6 @@ def eigvalsh(a, b, lower=True): ...@@ -559,51 +559,6 @@ def eigvalsh(a, b, lower=True):
return Eigvalsh(lower)(a, b) return Eigvalsh(lower)(a, b)
def kron(a, b):
"""Kronecker product.
Same as scipy.linalg.kron(a, b).
Parameters
----------
a: array_like
b: array_like
Returns
-------
array_like with a.ndim + b.ndim - 2 dimensions
Notes
-----
numpy.kron(a, b) != scipy.linalg.kron(a, b)!
They don't have the same shape and order when
a.ndim != b.ndim != 2.
"""
a = as_tensor_variable(a)
b = as_tensor_variable(b)
if a.ndim + b.ndim <= 2:
raise TypeError(
"kron: inputs dimensions must sum to 3 or more. "
f"You passed {int(a.ndim)} and {int(b.ndim)}."
)
o = ptm.outer(a, b)
o = o.reshape(ptb.concatenate((a.shape, b.shape)), ndim=a.ndim + b.ndim)
shf = o.dimshuffle(0, 2, 1, *range(3, o.ndim))
if shf.ndim == 3:
shf = o.dimshuffle(1, 0, 2)
o = shf.flatten()
else:
o = shf.reshape(
(
o.shape[0] * o.shape[2],
o.shape[1] * o.shape[3],
*(o.shape[i] for i in range(4, o.ndim)),
)
)
return o
class Expm(Op): class Expm(Op):
""" """
Compute the matrix exponential of a square array. Compute the matrix exponential of a square array.
...@@ -1021,7 +976,6 @@ __all__ = [ ...@@ -1021,7 +976,6 @@ __all__ = [
"cholesky", "cholesky",
"solve", "solve",
"eigvalsh", "eigvalsh",
"kron",
"expm", "expm",
"solve_discrete_lyapunov", "solve_discrete_lyapunov",
"solve_continuous_lyapunov", "solve_continuous_lyapunov",
......
...@@ -17,6 +17,7 @@ from pytensor.tensor.nlinalg import ( ...@@ -17,6 +17,7 @@ from pytensor.tensor.nlinalg import (
det, det,
eig, eig,
eigh, eigh,
kron,
lstsq, lstsq,
matrix_dot, matrix_dot,
matrix_inverse, matrix_inverse,
...@@ -580,3 +581,42 @@ class TestTensorInv(utt.InferShapeTester): ...@@ -580,3 +581,42 @@ class TestTensorInv(utt.InferShapeTester):
t_binv1 = tf_b1(self.b1) t_binv1 = tf_b1(self.b1)
assert _allclose(t_binv, n_binv) assert _allclose(t_binv, n_binv)
assert _allclose(t_binv1, n_binv1) assert _allclose(t_binv1, n_binv1)
class TestKron(utt.InferShapeTester):
rng = np.random.default_rng(43)
def setup_method(self):
self.op = kron
super().setup_method()
@pytest.mark.parametrize("shp0", [(2,), (2, 3), (2, 3, 4), (2, 3, 4, 5)])
@pytest.mark.parametrize("shp1", [(6,), (6, 7), (6, 7, 8), (6, 7, 8, 9)])
def test_perform(self, shp0, shp1):
if len(shp0) + len(shp1) == 2:
pytest.skip("Sum of shp0 and shp1 must be more than 2")
x = tensor(dtype="floatX", shape=(None,) * len(shp0))
a = np.asarray(self.rng.random(shp0)).astype(config.floatX)
y = tensor(dtype="floatX", shape=(None,) * len(shp1))
f = function([x, y], kron(x, y))
b = self.rng.random(shp1).astype(config.floatX)
out = f(a, b)
# Using the np.kron to compare outputs
np_val = np.kron(a, b)
np.testing.assert_allclose(out, np_val)
@pytest.mark.parametrize(
"i, shp0, shp1",
[(0, (2, 3), (6, 7)), (1, (2, 3), (4, 3, 5)), (2, (2, 4, 3), (4, 3, 5))],
)
def test_kron_commutes_with_inv(self, i, shp0, shp1):
if (pytensor.config.floatX == "float32") & (i == 2):
pytest.skip("Half precision insufficient for test 3 to pass")
x = tensor(dtype="floatX", shape=(None,) * len(shp0))
a = np.asarray(self.rng.random(shp0)).astype(config.floatX)
y = tensor(dtype="floatX", shape=(None,) * len(shp1))
b = self.rng.random(shp1).astype(config.floatX)
lhs_f = function([x, y], pinv(kron(x, y)))
rhs_f = function([x, y], kron(pinv(x), pinv(y)))
atol = 1e-4 if config.floatX == "float32" else 1e-12
np.testing.assert_allclose(lhs_f(a, b), rhs_f(a, b), atol=atol)
...@@ -20,7 +20,6 @@ from pytensor.tensor.slinalg import ( ...@@ -20,7 +20,6 @@ from pytensor.tensor.slinalg import (
cholesky, cholesky,
eigvalsh, eigvalsh,
expm, expm,
kron,
solve, solve,
solve_continuous_lyapunov, solve_continuous_lyapunov,
solve_discrete_are, solve_discrete_are,
...@@ -512,46 +511,6 @@ def test_expm_grad_3(): ...@@ -512,46 +511,6 @@ def test_expm_grad_3():
utt.verify_grad(expm, [A], rng=rng) utt.verify_grad(expm, [A], rng=rng)
class TestKron(utt.InferShapeTester):
rng = np.random.default_rng(43)
def setup_method(self):
self.op = kron
super().setup_method()
def test_perform(self):
for shp0 in [(2,), (2, 3), (2, 3, 4), (2, 3, 4, 5)]:
x = tensor(dtype="floatX", shape=(None,) * len(shp0))
a = np.asarray(self.rng.random(shp0)).astype(config.floatX)
for shp1 in [(6,), (6, 7), (6, 7, 8), (6, 7, 8, 9)]:
if len(shp0) + len(shp1) == 2:
continue
y = tensor(dtype="floatX", shape=(None,) * len(shp1))
f = function([x, y], kron(x, y))
b = self.rng.random(shp1).astype(config.floatX)
out = f(a, b)
# Newer versions of scipy want 4 dimensions at least,
# so we have to add a dimension to a and flatten the result.
if len(shp0) + len(shp1) == 3:
scipy_val = scipy.linalg.kron(a[np.newaxis, :], b).flatten()
else:
scipy_val = scipy.linalg.kron(a, b)
np.testing.assert_allclose(out, scipy_val)
def test_numpy_2d(self):
for shp0 in [(2, 3)]:
x = tensor(dtype="floatX", shape=(None,) * len(shp0))
a = np.asarray(self.rng.random(shp0)).astype(config.floatX)
for shp1 in [(6, 7)]:
if len(shp0) + len(shp1) == 2:
continue
y = tensor(dtype="floatX", shape=(None,) * len(shp1))
f = function([x, y], kron(x, y))
b = self.rng.random(shp1).astype(config.floatX)
out = f(a, b)
assert np.allclose(out, np.kron(a, b))
def test_solve_discrete_lyapunov_via_direct_real(): def test_solve_discrete_lyapunov_via_direct_real():
N = 5 N = 5
rng = np.random.default_rng(utt.fetch_seed()) rng = np.random.default_rng(utt.fetch_seed())
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论