提交 f1f333d0 authored 作者: yobibyte's avatar yobibyte

Add tests of tensorsolve for type upcast.

上级 fb202ecb
...@@ -821,7 +821,7 @@ class TensorSolve(Op): ...@@ -821,7 +821,7 @@ class TensorSolve(Op):
def tensorsolve(a, b, axes=None): def tensorsolve(a, b, axes=None):
""" """
Theano utilization of numpy.linalg.tensorsolve Theano utilization of numpy.linalg.tensorsolve. Does not run on GPU!
Solve the tensor equation ``a x = b`` for x. Solve the tensor equation ``a x = b`` for x.
It is assumed that all indices of `x` are summed over in the product, It is assumed that all indices of `x` are summed over in the product,
......
...@@ -165,6 +165,7 @@ def test_svd(): ...@@ -165,6 +165,7 @@ def test_svd():
def test_tensorsolve(): def test_tensorsolve():
rng = numpy.random.RandomState(utt.fetch_seed()) rng = numpy.random.RandomState(utt.fetch_seed())
A = tensor.tensor4("A", dtype=theano.config.floatX) A = tensor.tensor4("A", dtype=theano.config.floatX)
B = tensor.matrix("B", dtype=theano.config.floatX) B = tensor.matrix("B", dtype=theano.config.floatX)
X = tensorsolve(A, B) X = tensorsolve(A, B)
...@@ -179,6 +180,34 @@ def test_tensorsolve(): ...@@ -179,6 +180,34 @@ def test_tensorsolve():
t_x = fn(a, b) t_x = fn(a, b)
assert _allclose(n_x, t_x) assert _allclose(n_x, t_x)
# check the type upcast now
C = tensor.tensor4("C", dtype='float32')
D = tensor.matrix("D", dtype='float64')
Y = tensorsolve(C, D)
fn = function([C, D], [Y])
c = numpy.eye(2*3*4).astype('float32')
c.shape = (2*3, 4, 2, 3*4)
d = rng.rand(2*3, 4).astype('float64')
n_y = numpy.linalg.tensorsolve(c, d)
t_y = fn(c, d)
assert _allclose(n_y, t_y)
assert n_y.dtype == Y.dtype
# check the type upcast now
E = tensor.tensor4("E", dtype='int32')
F = tensor.matrix("F", dtype='float64')
Z = tensorsolve(E, F)
fn = function([E, F], [Z])
e = numpy.eye(2*3*4).astype('int32')
e.shape = (2*3, 4, 2, 3*4)
f = rng.rand(2*3, 4).astype('float64')
n_z = numpy.linalg.tensorsolve(e, f)
t_z = fn(e, f)
assert _allclose(n_z, t_z)
assert n_z.dtype == Z.dtype
def test_inverse_singular(): def test_inverse_singular():
singular = numpy.array([[1, 0, 0]] + [[0, 1, 0]] * 2, singular = numpy.array([[1, 0, 0]] + [[0, 1, 0]] * 2,
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论