提交 41ebb2de authored 作者: Tanjay94's avatar Tanjay94

Removed unnecessary A_Xinv_B function.

上级 cd8c9c69
...@@ -863,36 +863,6 @@ def spectral_radius_bound(X, log2_exponent): ...@@ -863,36 +863,6 @@ def spectral_radius_bound(X, log2_exponent):
2 ** (-log2_exponent)) 2 ** (-log2_exponent))
class A_Xinv_b(Op):
"""Product of form a inv(X) b"""
def make_node(self, a, X, b):
assert imported_scipy, (
"Scipy not available. Scipy is needed for the A_Xinv_b op")
a = as_tensor_variable(a)
X = as_tensor_variable(X)
b = as_tensor_variable(b)
assert a.ndim == 2
assert X.ndim == 2
assert b.ndim == 2
o = theano.tensor.matrix(dtype=X.dtype)
return Apply(self, [a, X, b], [o])
def perform(self, ndoe, inputs, outstor):
a, X, b = inputs
iX = numpy.linalg.inv(X)
z = numpy.dot(numpy.dot(a, iX), b)
outstor[0][0] = z
def grad(self, inputs, g_outputs):
gz, = g_outputs
a, X, b = inputs
iX = matrix_inverse(X)
ga = matrix_dot(gz, b.T, iX.T)
gX = -matrix_dot(iX.T, a, gz, b.T, iX.T)
gb = matrix_dot(iX.T, a.T, gz)
return [ga, gX, gb]
class Eig(Op): class Eig(Op):
"""Compute the eigenvalues and right eigenvectors of a square array. """Compute the eigenvalues and right eigenvectors of a square array.
......
...@@ -33,7 +33,7 @@ from theano.sandbox.linalg.ops import (cholesky, ...@@ -33,7 +33,7 @@ from theano.sandbox.linalg.ops import (cholesky,
imported_scipy, imported_scipy,
Eig, Eig,
inv_as_solve, inv_as_solve,
A_Xinv_b A_Xinv_B
) )
from theano.sandbox.linalg import eig, eigh, eigvalsh from theano.sandbox.linalg import eig, eigh, eigvalsh
from nose.plugins.skip import SkipTest from nose.plugins.skip import SkipTest
...@@ -674,39 +674,3 @@ def test_eigvalsh_grad(): ...@@ -674,39 +674,3 @@ def test_eigvalsh_grad():
b = 10 * numpy.eye(5, 5) + rng.randn(5, 5) b = 10 * numpy.eye(5, 5) + rng.randn(5, 5)
tensor.verify_grad(lambda a, b: eigvalsh(a, b).dot([1, 2, 3, 4, 5]), tensor.verify_grad(lambda a, b: eigvalsh(a, b).dot([1, 2, 3, 4, 5]),
[a, b], rng=numpy.random) [a, b], rng=numpy.random)
class test_A_Xinv_b():
def test_A_Xinv_b(self):
x = tensor.matrix()
y = tensor.matrix()
z = tensor.matrix()
m = A_Xinv_b()(x, y, z)
f = function([x, y, z], m)
X = [[1, 1], [1, 1]]
Y = [[2, 1], [3, 4]]
Z = [[1, 1], [1, 1]]
assert numpy.allclose(f(X, Y, Z), [[0.4, 0.4], [0.4, 0.4]])
def test_shape_conflict(self):
x = tensor.matrix()
y = tensor.matrix()
z = tensor.matrix()
m = A_Xinv_b()(x, y, z)
f = function([x, y, z], m)
X = [[1, 1, 1], [1, 1, 1]]
Y = [[2, -9], [3, 4]]
Z = [[1, 1], [1, 1]]
assert_raises(ValueError, f, X, Y, Z)
def test_grad(self):
x = tensor.matrix()
y = tensor.matrix()
z = tensor.matrix()
m = A_Xinv_b()(x, y, z)
f = function([x, y, z], m)
X = numpy.asarray([[1, 1], [1, 1]], dtype='float32')
Y = numpy.asarray([[2, 1], [3, 4]], dtype='float32')
Z = numpy.asarray([[1, 1], [1, 1]], dtype='float32')
theano.tests.unittest_tools.verify_grad(A_Xinv_b(),
[X, Y, Z])
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论