提交 48b4fd31 authored 作者: Yann N. Dauphin's avatar Yann N. Dauphin

fixes to tests

上级 05e5edc5
...@@ -60,7 +60,7 @@ class test_structured_add_s_v(unittest.TestCase): ...@@ -60,7 +60,7 @@ class test_structured_add_s_v(unittest.TestCase):
for format in ['csr', 'csc']: for format in ['csr', 'csc']:
for dtype in ['float32', 'float64']: for dtype in ['float32', 'float64']:
spmat = sp_types[format](random_lil((4, 3), dtype, 3)) spmat = sp_types[format](random_lil((4, 3), dtype, 3))
mat = numpy.ones(3, dtype=dtype) mat = numpy.asarray(numpy.random.rand(3), dtype=dtype)
S.verify_grad_sparse(S2.structured_add_s_v, S.verify_grad_sparse(S2.structured_add_s_v,
[spmat, mat], structured=True) [spmat, mat], structured=True)
...@@ -78,11 +78,11 @@ class test_structured_add_s_v(unittest.TestCase): ...@@ -78,11 +78,11 @@ class test_structured_add_s_v(unittest.TestCase):
spmat = sp_types[format](random_lil((4, 3), dtype, 3)) spmat = sp_types[format](random_lil((4, 3), dtype, 3))
spones = spmat.copy() spones = spmat.copy()
spones.data = numpy.ones_like(spones.data) spones.data = numpy.ones_like(spones.data)
mat = numpy.ones(3, dtype=dtype) mat = numpy.asarray(numpy.random.rand(3), dtype=dtype)
out = f(spmat, mat) out = f(spmat, mat)
assert numpy.all(out.toarray() == spones.multiply(spmat + mat)) assert numpy.allclose(out.toarray(), spones.multiply(spmat + mat))
class test_mul_s_v(unittest.TestCase): class test_mul_s_v(unittest.TestCase):
...@@ -96,7 +96,7 @@ class test_mul_s_v(unittest.TestCase): ...@@ -96,7 +96,7 @@ class test_mul_s_v(unittest.TestCase):
for format in ['csr', 'csc']: for format in ['csr', 'csc']:
for dtype in ['float32', 'float64']: for dtype in ['float32', 'float64']:
spmat = sp_types[format](random_lil((4, 3), dtype, 3)) spmat = sp_types[format](random_lil((4, 3), dtype, 3))
mat = numpy.ones(3, dtype=dtype) mat = numpy.asarray(numpy.random.rand(3), dtype=dtype)
S.verify_grad_sparse(S2.mul_s_v, S.verify_grad_sparse(S2.mul_s_v,
[spmat, mat], structured=True) [spmat, mat], structured=True)
...@@ -112,13 +112,11 @@ class test_mul_s_v(unittest.TestCase): ...@@ -112,13 +112,11 @@ class test_mul_s_v(unittest.TestCase):
f = theano.function([x, y], S2.mul_s_v(x, y)) f = theano.function([x, y], S2.mul_s_v(x, y))
spmat = sp_types[format](random_lil((4, 3), dtype, 3)) spmat = sp_types[format](random_lil((4, 3), dtype, 3))
spones = spmat.copy() mat = numpy.asarray(numpy.random.rand(3), dtype=dtype)
spones.data = numpy.ones_like(spones.data)
mat = numpy.ones(3, dtype=dtype)
out = f(spmat, mat) out = f(spmat, mat)
assert numpy.all(out.toarray() == (spmat.toarray() * mat)) assert numpy.allclose(out.toarray(), spmat.toarray() * mat)
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论