提交 cd8c9c69 authored 作者: Tanjay94's avatar Tanjay94

Fixed A_Xinv_b function to work properly and added a test code for it.

上级 50cb3c06
...@@ -879,13 +879,8 @@ class A_Xinv_b(Op): ...@@ -879,13 +879,8 @@ class A_Xinv_b(Op):
def perform(self, ndoe, inputs, outstor): def perform(self, ndoe, inputs, outstor):
a, X, b = inputs a, X, b = inputs
if 1: iX = numpy.linalg.inv(X)
L_factor = scipy.linalg.cho_factor(X) z = numpy.dot(numpy.dot(a, iX), b)
xb = scipy.linalg.cho_solve(L_factor, b)
xa = scipy.linalg.cho_solve(L_factor, a.T)
z = numpy.dot(xa.T, xb)
else:
raise NotImplementedError(self.X_structure)
outstor[0][0] = z outstor[0][0] = z
def grad(self, inputs, g_outputs): def grad(self, inputs, g_outputs):
......
...@@ -686,18 +686,7 @@ class test_A_Xinv_b(): ...@@ -686,18 +686,7 @@ class test_A_Xinv_b():
X = [[1, 1], [1, 1]] X = [[1, 1], [1, 1]]
Y = [[2, 1], [3, 4]] Y = [[2, 1], [3, 4]]
Z = [[1, 1], [1, 1]] Z = [[1, 1], [1, 1]]
assert numpy.allclose(f(X, Y, Z), [[0.20408163, 0.20408163], [0.20408163, 0.20408163]]) assert numpy.allclose(f(X, Y, Z), [[0.4, 0.4], [0.4, 0.4]])
def test_definite_positive(self):
x = tensor.matrix()
y = tensor.matrix()
z = tensor.matrix()
m = A_Xinv_b()(x, y, z)
f = function([x, y, z], m)
X = [[1, 1], [1, 1]]
Y = [[2, -9], [3, 4]]
Z = [[1, 1], [1, 1]]
assert_raises(numpy.linalg.LinAlgError, f, X, Y, Z)
def test_shape_conflict(self): def test_shape_conflict(self):
x = tensor.matrix() x = tensor.matrix()
...@@ -708,7 +697,7 @@ class test_A_Xinv_b(): ...@@ -708,7 +697,7 @@ class test_A_Xinv_b():
X = [[1, 1, 1], [1, 1, 1]] X = [[1, 1, 1], [1, 1, 1]]
Y = [[2, -9], [3, 4]] Y = [[2, -9], [3, 4]]
Z = [[1, 1], [1, 1]] Z = [[1, 1], [1, 1]]
assert_raises(numpy.linalg.LinAlgError, f, X, Y, Z) assert_raises(ValueError, f, X, Y, Z)
def test_grad(self): def test_grad(self):
x = tensor.matrix() x = tensor.matrix()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论