提交 bcca6bb1 authored 作者: Yann N. Dauphin's avatar Yann N. Dauphin

fix test test_multMatVect

上级 f97acd71
......@@ -877,25 +877,36 @@ def test_gradient_scan():
def test_multMatVect():
A = tensor.lmatrix('A')
s = tensor.ivector('s')
m = tensor.iscalar('m')
A1 = tensor.lmatrix('A1')
s1 = tensor.ivector('s1')
m1 = tensor.iscalar('m1')
A2 = tensor.lmatrix('A2')
s2 = tensor.ivector('s2')
m2 = tensor.iscalar('m2')
g0 = rng_mrg.DotModulo()(A, s, m)
f0 = theano.function([A, s, m], g0)
g0 = rng_mrg.DotModulo()(A1, s1, m1, A2, s2, m2)
f0 = theano.function([A1, s1, m1, A2, s2, m2], g0)
A = numpy.random.randint(0, numpy.iinfo(numpy.int32).max, (3, 3)).astype('int64')
s = numpy.random.randint(0, numpy.iinfo(numpy.int32).max, 3).astype('int32')
m = numpy.asarray(numpy.random.randint(numpy.iinfo(numpy.int32).max), dtype="int32")
A1 = numpy.random.randint(0, numpy.iinfo(numpy.int32).max, (3, 3)).astype('int64')
s1 = numpy.random.randint(0, numpy.iinfo(numpy.int32).max, 3).astype('int32')
m1 = numpy.asarray(numpy.random.randint(numpy.iinfo(numpy.int32).max), dtype="int32")
A2 = numpy.random.randint(0, numpy.iinfo(numpy.int32).max, (3, 3)).astype('int64')
s2 = numpy.random.randint(0, numpy.iinfo(numpy.int32).max, 3).astype('int32')
m2 = numpy.asarray(numpy.random.randint(numpy.iinfo(numpy.int32).max), dtype="int32")
f0.input_storage[0].storage[0] = A
f0.input_storage[1].storage[0] = s
f0.input_storage[2].storage[0] = m
f0.input_storage[0].storage[0] = A1
f0.input_storage[1].storage[0] = s1
f0.input_storage[2].storage[0] = m1
f0.input_storage[3].storage[0] = A2
f0.input_storage[4].storage[0] = s2
f0.input_storage[5].storage[0] = m2
r_a = rng_mrg.matVecModM(A, s, m)
r_b = f0.fn()
assert numpy.allclose(r_a, r_b)
r_a1 = rng_mrg.matVecModM(A1, s1, m1)
r_a2 = rng_mrg.matVecModM(A2, s2, m2)
r_b = f0.fn()[0]
assert numpy.allclose(r_a1, r_b[:3])
assert numpy.allclose(r_a2, r_b[3:])
if __name__ == "__main__":
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论