提交 42178734 authored 作者: Frederic's avatar Frederic

In tests, use comparison that give more error message when it fail.

上级 85dc9a34
...@@ -19,9 +19,8 @@ from theano.tensor.blas import (_dot22, _dot22scalar, res_is_a, _as_scalar, ...@@ -19,9 +19,8 @@ from theano.tensor.blas import (_dot22, _dot22scalar, res_is_a, _as_scalar,
gemm_inplace, gemm_no_inplace, gemm_inplace, gemm_no_inplace,
InconsistencyError, Ger, ger, ger_destructive) InconsistencyError, Ger, ger, ger_destructive)
from theano.tests import unittest_tools from theano.tests import unittest_tools
from test_basic import (_approx_eq, as_tensor_variable, inplace_func, from test_basic import (as_tensor_variable, inplace_func,
compile, inplace) compile, inplace)
#, constant, eval_outputs)
import theano.tensor.blas_scipy import theano.tensor.blas_scipy
...@@ -50,7 +49,6 @@ class t_gemm(TestCase): ...@@ -50,7 +49,6 @@ class t_gemm(TestCase):
""" """
def setUp(self): def setUp(self):
unittest_tools.seed_rng() unittest_tools.seed_rng()
_approx_eq.debug = 0
Gemm.debug = False Gemm.debug = False
@staticmethod @staticmethod
...@@ -84,8 +82,7 @@ class t_gemm(TestCase): ...@@ -84,8 +82,7 @@ class t_gemm(TestCase):
z_after = self._gemm(z_orig, a, x, y, b) z_after = self._gemm(z_orig, a, x, y, b)
#print z_orig, z_after, z, type(z_orig), type(z_after), type(z) #print z_orig, z_after, z, type(z_orig), type(z_after), type(z)
#_approx_eq.debug = 1 unittest_tools.assert_allclose(z_after, z)
self.assertTrue(_approx_eq(z_after, z))
if a == 0.0 and b == 1.0: if a == 0.0 and b == 1.0:
return return
elif z_orig.size == 0: elif z_orig.size == 0:
...@@ -150,7 +147,6 @@ class t_gemm(TestCase): ...@@ -150,7 +147,6 @@ class t_gemm(TestCase):
self.rand(3, 5), self.rand(5, 4), -1.0) self.rand(3, 5), self.rand(5, 4), -1.0)
def test10(self): def test10(self):
_approx_eq.debug = 1
self.cmp(self.rand(3, 4), -1.0, self.rand(3, 5), self.rand(5, 4), 0.0) self.cmp(self.rand(3, 4), -1.0, self.rand(3, 5), self.rand(5, 4), 0.0)
def test11(self): def test11(self):
...@@ -281,14 +277,11 @@ class t_gemm(TestCase): ...@@ -281,14 +277,11 @@ class t_gemm(TestCase):
f = inplace_func([], gemm_inplace(tz, ta, tx, ty, tb), f = inplace_func([], gemm_inplace(tz, ta, tx, ty, tb),
mode=compile.Mode(optimizer=None, linker=l)) mode=compile.Mode(optimizer=None, linker=l))
f() f()
self.assertTrue(_approx_eq(z_after, tz.get_value(borrow=True)), unittest_tools.assert_allclose(z_after, tz.get_value(borrow=True))
(z_orig, z_after, z, z_after - z))
f() f()
self.assertTrue(_approx_eq(z_after, tz.get_value(borrow=True)), unittest_tools.assert_allclose(z_after, tz.get_value(borrow=True))
(z_orig, z_after, z, z_after - z))
f() f()
self.assertTrue(_approx_eq(z_after, tz.get_value(borrow=True)), unittest_tools.assert_allclose(z_after, tz.get_value(borrow=True))
(z_orig, z_after, z, z_after - z))
#tz.value *= 0 # clear z's value #tz.value *= 0 # clear z's value
y_T = ty.get_value(borrow=True).T y_T = ty.get_value(borrow=True).T
...@@ -298,7 +291,7 @@ class t_gemm(TestCase): ...@@ -298,7 +291,7 @@ class t_gemm(TestCase):
f() f()
# test that the transposed version of multiplication gives # test that the transposed version of multiplication gives
# same answer # same answer
self.assertTrue(_approx_eq(z_after, tz.get_value(borrow=True).T)) unittest_tools.assert_allclose(z_after, tz.get_value(borrow=True).T)
t(C, A, B) t(C, A, B)
t(C.T, A, B) t(C.T, A, B)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论