提交 bbf66863 authored 作者: Ilya Kulikov's avatar Ilya Kulikov

out=a.type() fix and tests for 3D tensors added

上级 6e93976e
...@@ -741,8 +741,7 @@ class TensorInv(Op): ...@@ -741,8 +741,7 @@ class TensorInv(Op):
def make_node(self, a): def make_node(self, a):
a = as_tensor_variable(a) a = as_tensor_variable(a)
out_dtype = a.dtype out = a.type()
out = theano.tensor.tensor4(dtype=out_dtype)
return Apply(self, [a], [out]) return Apply(self, [a], [out])
def perform(self, node, inputs, outputs): def perform(self, node, inputs, outputs):
......
...@@ -525,7 +525,10 @@ class test_TensorInv(utt.InferShapeTester): ...@@ -525,7 +525,10 @@ class test_TensorInv(utt.InferShapeTester):
def setUp(self): def setUp(self):
super(test_TensorInv, self).setUp() super(test_TensorInv, self).setUp()
self.A = tensor.tensor4("A", dtype=theano.config.floatX) self.A = tensor.tensor4("A", dtype=theano.config.floatX)
self.B = tensor.tensor3("B", dtype=theano.config.floatX)
self.a = numpy.random.rand(4, 6, 8, 3).astype(theano.config.floatX) self.a = numpy.random.rand(4, 6, 8, 3).astype(theano.config.floatX)
self.b = numpy.random.rand(2, 15, 30).astype(theano.config.floatX)
self.b1 = numpy.random.rand(30, 2, 15).astype(theano.config.floatX) # for ind=1 since we need prod(b1.shape[:ind]) == prod(b1.shape[ind:])
def test_infer_shape(self): def test_infer_shape(self):
A = self.A A = self.A
...@@ -539,6 +542,18 @@ class test_TensorInv(utt.InferShapeTester): ...@@ -539,6 +542,18 @@ class test_TensorInv(utt.InferShapeTester):
A = self.A A = self.A
Ai = tensorinv(A) Ai = tensorinv(A)
n_ainv = numpy.linalg.tensorinv(self.a) n_ainv = numpy.linalg.tensorinv(self.a)
tf = function([A], [Ai]) tf_a = function([A], [Ai])
t_ainv = tf(self.a) t_ainv = tf_a(self.a)
assert _allclose(n_ainv, t_ainv) assert _allclose(n_ainv, t_ainv)
B = self.B
Bi = tensorinv(B)
Bi1 = tensorinv(B, ind=1)
n_binv = numpy.linalg.tensorinv(self.b)
n_binv1 = numpy.linalg.tensorinv(self.b1, ind=1)
tf_b = function([B], [Bi])
tf_b1 = function([B], [Bi1])
t_binv = tf_b(self.b)
t_binv1 = tf_b1(self.b1)
assert _allclose(t_binv, n_binv)
assert _allclose(t_binv1, n_binv1)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论