Unverified 提交 61506810 authored 作者: Abdalaziz Rashid's avatar Abdalaziz Rashid 提交者: GitHub

Implement exponentiation by squaring for matrix_power

Closes #117.
上级 ae2919bf
...@@ -10,6 +10,7 @@ from numpy.testing import assert_array_almost_equal ...@@ -10,6 +10,7 @@ from numpy.testing import assert_array_almost_equal
from theano import tensor, function from theano import tensor, function
from theano.tensor.basic import _allclose from theano.tensor.basic import _allclose
from theano import config from theano import config
from theano.configparser import change_flags
from theano.tensor.nlinalg import ( from theano.tensor.nlinalg import (
MatrixInverse, MatrixInverse,
matrix_inverse, matrix_inverse,
...@@ -532,24 +533,30 @@ class TestLstsq: ...@@ -532,24 +533,30 @@ class TestLstsq:
f([2, 1], [2, 1], [2, 1]) f([2, 1], [2, 1], [2, 1])
class TestMatrix_power: class TestMatrixPower:
def test_numpy_compare(self): @change_flags(compute_test_value="raise")
rng = np.random.RandomState(utt.fetch_seed()) @pytest.mark.parametrize("n", [-1, 0, 1, 2, 3, 4, 5, 11])
def test_numpy_compare(self, n):
a = np.array([[0.1231101, 0.72381381], [0.28748201, 0.43036511]]).astype(
theano.config.floatX
)
A = tensor.matrix("A", dtype=theano.config.floatX) A = tensor.matrix("A", dtype=theano.config.floatX)
Q = matrix_power(A, 3) A.tag.test_value = a
fn = function([A], [Q]) Q = matrix_power(A, n)
a = rng.rand(4, 4).astype(theano.config.floatX) n_p = np.linalg.matrix_power(a, n)
assert np.allclose(n_p, Q.get_test_value())
n_p = np.linalg.matrix_power(a, 3)
t_p = fn(a)
assert np.allclose(n_p, t_p)
def test_non_square_matrix(self): def test_non_square_matrix(self):
rng = np.random.RandomState(utt.fetch_seed())
A = tensor.matrix("A", dtype=theano.config.floatX) A = tensor.matrix("A", dtype=theano.config.floatX)
Q = matrix_power(A, 3) Q = matrix_power(A, 3)
f = function([A], [Q]) f = function([A], [Q])
a = rng.rand(4, 3).astype(theano.config.floatX) a = np.array(
[
[0.47497769, 0.81869379],
[0.74387558, 0.31780172],
[0.54381007, 0.28153101],
]
).astype(theano.config.floatX)
with pytest.raises(ValueError): with pytest.raises(ValueError):
f(a) f(a)
......
...@@ -663,17 +663,43 @@ class lstsq(Op): ...@@ -663,17 +663,43 @@ class lstsq(Op):
def matrix_power(M, n): def matrix_power(M, n):
""" r"""
Raise a square matrix to the (integer) power n. Raise a square matrix to the (integer) power n.
This implementation uses exponentiation by squaring which is
significantly faster than the naive implementation.
The time complexity for exponentiation by squaring is
:math: `\mathcal{O}((n \log M)^k)`
Parameters Parameters
---------- ----------
M : Tensor variable M : Tensor variable
n : Python int n : Python int
""" """
result = 1 if n < 0:
for i in range(n): M = pinv(M)
result = theano.dot(result, M) n = abs(n)
# Shortcuts when 0 < n <= 3
if n == 0:
return tensor.eye(M.shape[-2])
elif n == 1:
return M
elif n == 2:
return theano.dot(M, M)
elif n == 3:
return theano.dot(theano.dot(M, M), M)
result = z = None
while n > 0:
z = M if z is None else theano.dot(z, z)
n, bit = divmod(n, 2)
if bit:
result = z if result is None else theano.dot(result, z)
return result return result
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论