提交 42178734 authored 作者: Frederic's avatar Frederic

In tests, use comparison that give more error message when it fail.

上级 85dc9a34
......@@ -19,9 +19,8 @@ from theano.tensor.blas import (_dot22, _dot22scalar, res_is_a, _as_scalar,
gemm_inplace, gemm_no_inplace,
InconsistencyError, Ger, ger, ger_destructive)
from theano.tests import unittest_tools
from test_basic import (_approx_eq, as_tensor_variable, inplace_func,
compile, inplace)
#, constant, eval_outputs)
from test_basic import (as_tensor_variable, inplace_func,
compile, inplace)
import theano.tensor.blas_scipy
......@@ -50,7 +49,6 @@ class t_gemm(TestCase):
"""
def setUp(self):
unittest_tools.seed_rng()
_approx_eq.debug = 0
Gemm.debug = False
@staticmethod
......@@ -84,8 +82,7 @@ class t_gemm(TestCase):
z_after = self._gemm(z_orig, a, x, y, b)
#print z_orig, z_after, z, type(z_orig), type(z_after), type(z)
#_approx_eq.debug = 1
self.assertTrue(_approx_eq(z_after, z))
unittest_tools.assert_allclose(z_after, z)
if a == 0.0 and b == 1.0:
return
elif z_orig.size == 0:
......@@ -150,7 +147,6 @@ class t_gemm(TestCase):
self.rand(3, 5), self.rand(5, 4), -1.0)
def test10(self):
_approx_eq.debug = 1
self.cmp(self.rand(3, 4), -1.0, self.rand(3, 5), self.rand(5, 4), 0.0)
def test11(self):
......@@ -281,14 +277,11 @@ class t_gemm(TestCase):
f = inplace_func([], gemm_inplace(tz, ta, tx, ty, tb),
mode=compile.Mode(optimizer=None, linker=l))
f()
self.assertTrue(_approx_eq(z_after, tz.get_value(borrow=True)),
(z_orig, z_after, z, z_after - z))
unittest_tools.assert_allclose(z_after, tz.get_value(borrow=True))
f()
self.assertTrue(_approx_eq(z_after, tz.get_value(borrow=True)),
(z_orig, z_after, z, z_after - z))
unittest_tools.assert_allclose(z_after, tz.get_value(borrow=True))
f()
self.assertTrue(_approx_eq(z_after, tz.get_value(borrow=True)),
(z_orig, z_after, z, z_after - z))
unittest_tools.assert_allclose(z_after, tz.get_value(borrow=True))
#tz.value *= 0 # clear z's value
y_T = ty.get_value(borrow=True).T
......@@ -298,7 +291,7 @@ class t_gemm(TestCase):
f()
# test that the transposed version of multiplication gives
# same answer
self.assertTrue(_approx_eq(z_after, tz.get_value(borrow=True).T))
unittest_tools.assert_allclose(z_after, tz.get_value(borrow=True).T)
t(C, A, B)
t(C.T, A, B)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论