Unverified 提交 61506810 authored 作者: Abdalaziz Rashid's avatar Abdalaziz Rashid 提交者: GitHub

Implement exponentiation by squaring for matrix_power

Closes #117.
上级 ae2919bf
......@@ -10,6 +10,7 @@ from numpy.testing import assert_array_almost_equal
from theano import tensor, function
from theano.tensor.basic import _allclose
from theano import config
from theano.configparser import change_flags
from theano.tensor.nlinalg import (
MatrixInverse,
matrix_inverse,
......@@ -532,24 +533,30 @@ class TestLstsq:
f([2, 1], [2, 1], [2, 1])
class TestMatrix_power:
def test_numpy_compare(self):
rng = np.random.RandomState(utt.fetch_seed())
class TestMatrixPower:
@change_flags(compute_test_value="raise")
@pytest.mark.parametrize("n", [-1, 0, 1, 2, 3, 4, 5, 11])
def test_numpy_compare(self, n):
a = np.array([[0.1231101, 0.72381381], [0.28748201, 0.43036511]]).astype(
theano.config.floatX
)
A = tensor.matrix("A", dtype=theano.config.floatX)
Q = matrix_power(A, 3)
fn = function([A], [Q])
a = rng.rand(4, 4).astype(theano.config.floatX)
n_p = np.linalg.matrix_power(a, 3)
t_p = fn(a)
assert np.allclose(n_p, t_p)
A.tag.test_value = a
Q = matrix_power(A, n)
n_p = np.linalg.matrix_power(a, n)
assert np.allclose(n_p, Q.get_test_value())
def test_non_square_matrix(self):
rng = np.random.RandomState(utt.fetch_seed())
A = tensor.matrix("A", dtype=theano.config.floatX)
Q = matrix_power(A, 3)
f = function([A], [Q])
a = rng.rand(4, 3).astype(theano.config.floatX)
a = np.array(
[
[0.47497769, 0.81869379],
[0.74387558, 0.31780172],
[0.54381007, 0.28153101],
]
).astype(theano.config.floatX)
with pytest.raises(ValueError):
f(a)
......
......@@ -663,17 +663,43 @@ class lstsq(Op):
def matrix_power(M, n):
"""
r"""
Raise a square matrix to the (integer) power n.
This implementation uses exponentiation by squaring which is
significantly faster than the naive implementation.
The time complexity for exponentiation by squaring is
:math: `\mathcal{O}((n \log M)^k)`
Parameters
----------
M : Tensor variable
n : Python int
"""
result = 1
for i in range(n):
result = theano.dot(result, M)
if n < 0:
M = pinv(M)
n = abs(n)
# Shortcuts when 0 < n <= 3
if n == 0:
return tensor.eye(M.shape[-2])
elif n == 1:
return M
elif n == 2:
return theano.dot(M, M)
elif n == 3:
return theano.dot(theano.dot(M, M), M)
result = z = None
while n > 0:
z = M if z is None else theano.dot(z, z)
n, bit = divmod(n, 2)
if bit:
result = z if result is None else theano.dot(result, z)
return result
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论