提交 e8050919 authored 作者: Tanjay94's avatar Tanjay94

Added test for A_Xinv_b function.

上级 9e00f9a8
...@@ -38,6 +38,7 @@ from theano.sandbox.linalg.ops import (cholesky, ...@@ -38,6 +38,7 @@ from theano.sandbox.linalg.ops import (cholesky,
from theano.sandbox.linalg import eig, eigh, eigvalsh from theano.sandbox.linalg import eig, eigh, eigvalsh
from nose.plugins.skip import SkipTest from nose.plugins.skip import SkipTest
from nose.plugins.attrib import attr from nose.plugins.attrib import attr
from nose.tools import assert_raises
def check_lower_triangular(pd, ch_f): def check_lower_triangular(pd, ch_f):
...@@ -672,7 +673,8 @@ def test_eigvalsh_grad(): ...@@ -672,7 +673,8 @@ def test_eigvalsh_grad():
[a, b], rng=numpy.random) [a, b], rng=numpy.random)
def test_A_Xinv_b(): class test_A_Xinv_b():
def test_A_Xinv_b(self):
x = tensor.matrix() x = tensor.matrix()
y = tensor.matrix() y = tensor.matrix()
z = tensor.matrix() z = tensor.matrix()
...@@ -682,3 +684,37 @@ def test_A_Xinv_b(): ...@@ -682,3 +684,37 @@ def test_A_Xinv_b():
Y = [[2, 1], [3, 4]] Y = [[2, 1], [3, 4]]
Z = [[1, 1], [1, 1]] Z = [[1, 1], [1, 1]]
assert numpy.allclose(f(X, Y, Z), [[0.20408163, 0.20408163], [0.20408163, 0.20408163]]) assert numpy.allclose(f(X, Y, Z), [[0.20408163, 0.20408163], [0.20408163, 0.20408163]])
def test_definite_positive(self):
x = tensor.matrix()
y = tensor.matrix()
z = tensor.matrix()
m = A_Xinv_b()(x, y, z)
f = function([x, y, z], m)
X = [[1, 1], [1, 1]]
Y = [[2, -9], [3, 4]]
Z = [[1, 1], [1, 1]]
assert_raises(numpy.linalg.LinAlgError, f, X, Y, Z)
def test_shape_conflict(self):
x = tensor.matrix()
y = tensor.matrix()
z = tensor.matrix()
m = A_Xinv_b()(x, y, z)
f = function([x, y, z], m)
X = [[1, 1, 1], [1, 1, 1]]
Y = [[2, -9], [3, 4]]
Z = [[1, 1], [1, 1]]
assert_raises(numpy.linalg.LinAlgError, f, X, Y, Z)
def test_grad(self):
x = tensor.matrix()
y = tensor.matrix()
z = tensor.matrix()
m = A_Xinv_b()(x, y, z)
f = function([x, y, z], m)
X = numpy.asarray([[1, 1], [1, 1]], dtype='float32')
Y = numpy.asarray([[2, 1], [3, 4]], dtype='float32')
Z = numpy.asarray([[1, 1], [1, 1]], dtype='float32')
theano.tests.unittest_tools.verify_grad(A_Xinv_b(),
[X, Y, Z])
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论