提交 6e93976e authored 作者: Ilya Kulikov's avatar Ilya Kulikov

infer_shape for TensorInv added, test for it added

上级 dfd24e90
...@@ -742,7 +742,7 @@ class TensorInv(Op): ...@@ -742,7 +742,7 @@ class TensorInv(Op):
def make_node(self, a): def make_node(self, a):
a = as_tensor_variable(a) a = as_tensor_variable(a)
out_dtype = a.dtype out_dtype = a.dtype
out = theano.tensor.matrix(dtype=out_dtype) out = theano.tensor.tensor4(dtype=out_dtype)
return Apply(self, [a], [out]) return Apply(self, [a], [out])
def perform(self, node, inputs, outputs): def perform(self, node, inputs, outputs):
...@@ -750,6 +750,10 @@ class TensorInv(Op): ...@@ -750,6 +750,10 @@ class TensorInv(Op):
(x,) = outputs (x,) = outputs
x[0] = self._numop(a, self.ind) x[0] = self._numop(a, self.ind)
def infer_shape(self, node, shapes):
sp = shapes[0][self.ind:] + shapes[0][:self.ind]
return [sp]
def tensorinv(a, ind=2): def tensorinv(a, ind=2):
""" """
......
...@@ -38,6 +38,7 @@ from theano.tensor.nlinalg import ( MatrixInverse, ...@@ -38,6 +38,7 @@ from theano.tensor.nlinalg import ( MatrixInverse,
matrix_power, matrix_power,
norm, norm,
svd, svd,
TensorInv,
tensorinv tensorinv
) )
from nose.plugins.attrib import attr from nose.plugins.attrib import attr
...@@ -520,15 +521,24 @@ class T_NormTests(unittest.TestCase): ...@@ -520,15 +521,24 @@ class T_NormTests(unittest.TestCase):
assert _allclose(n_n, t_n) assert _allclose(n_n, t_n)
def test_tensorinv(): class test_TensorInv(utt.InferShapeTester):
A = tensor.tensor4("A", dtype=theano.config.floatX) def setUp(self):
X = tensorinv(A) super(test_TensorInv, self).setUp()
tf = function([A], [X]) self.A = tensor.tensor4("A", dtype=theano.config.floatX)
self.a = numpy.random.rand(4, 6, 8, 3).astype(theano.config.floatX)
a = numpy.eye(4 * 6).astype(theano.config.floatX)
a.shape = (4, 6, 8, 3)
n_ainv = numpy.linalg.tensorinv(a) def test_infer_shape(self):
t_ainv = tf(a) A = self.A
Ai = tensorinv(A)
self._compile_and_check([A], # theano.function inputs
[Ai], # theano.function outputs
[self.a], # value to substitute
TensorInv)
assert _allclose(n_ainv, t_ainv) def test_eval(self):
A = self.A
Ai = tensorinv(A)
n_ainv = numpy.linalg.tensorinv(self.a)
tf = function([A], [Ai])
t_ainv = tf(self.a)
assert _allclose(n_ainv, t_ainv)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论