提交 bf1d0950 authored 作者: Amjad Almahairi's avatar Amjad Almahairi

added unit tests for transinv_to_invtrans and tag_solve_triangular

上级 b9595fe3
...@@ -12,6 +12,8 @@ from theano.tensor.basic import _allclose ...@@ -12,6 +12,8 @@ from theano.tensor.basic import _allclose
from theano.tests.test_rop import break_op from theano.tests.test_rop import break_op
from theano.tests import unittest_tools as utt from theano.tests import unittest_tools as utt
from theano import config from theano import config
from theano.tensor.nlinalg import MatrixInverse
from theano.tensor import DimShuffle
# The one in comment are not tested... # The one in comment are not tested...
from theano.sandbox.linalg.ops import (cholesky, from theano.sandbox.linalg.ops import (cholesky,
...@@ -20,6 +22,7 @@ from theano.sandbox.linalg.ops import (cholesky, ...@@ -20,6 +22,7 @@ from theano.sandbox.linalg.ops import (cholesky,
matrix_inverse, matrix_inverse,
pinv, pinv,
Solve, Solve,
solve,
diag, diag,
ExtractDiag, ExtractDiag,
extract_diag, extract_diag,
...@@ -137,3 +140,31 @@ def test_spectral_radius_bound(): ...@@ -137,3 +140,31 @@ def test_spectral_radius_bound():
except ValueError: except ValueError:
ok = True ok = True
assert ok assert ok
def test_transinv_to_invtrans():
X = tensor.matrix('X')
Y = tensor.nlinalg.matrix_inverse(X)
Z = Y.transpose()
f = theano.function([X], Z)
for node in f.maker.fgraph.toposort():
if isinstance(node.op, MatrixInverse):
assert isinstance(node.inputs[0].owner.op, DimShuffle)
if isinstance(node.op, DimShuffle):
assert node.inputs[0].name == 'X'
def test_tag_solve_triangular():
cholesky_lower = Cholesky(lower=True)
cholesky_upper = Cholesky(lower=False)
A = tensor.matrix('A')
x = tensor.vector('x')
L = cholesky_lower(A)
U = cholesky_upper(A)
b1 = solve(L, x)
b2 = solve(U, x)
b = b1 + b2
f = theano.function([A,x], b)
for node in f.maker.fgraph.toposort():
if isinstance(node.op, Solve):
assert node.op.A_structure in ['lower_triangular', 'upper_triangular']
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论