提交 e78065ec authored 作者: abalkin's avatar abalkin

Use numpy's assert_array_equal in tests.

上级 7254a4b1
...@@ -14,7 +14,7 @@ builtin_min = __builtin__.min ...@@ -14,7 +14,7 @@ builtin_min = __builtin__.min
from nose.plugins.skip import SkipTest from nose.plugins.skip import SkipTest
import numpy import numpy
from numpy.testing import dec from numpy.testing import dec, assert_array_equal
from numpy.testing.noseclasses import KnownFailureTest from numpy.testing.noseclasses import KnownFailureTest
import theano import theano
...@@ -7009,73 +7009,73 @@ class TestTensorInstanceMethods(unittest.TestCase): ...@@ -7009,73 +7009,73 @@ class TestTensorInstanceMethods(unittest.TestCase):
def test_argmin(self): def test_argmin(self):
X, _ = self.vars X, _ = self.vars
x, _ = self.vals x, _ = self.vals
self.assertTrue(numpy.all(X.argmin().eval({X: x}) == x.argmin())) assert_array_equal(X.argmin().eval({X: x}), x.argmin())
def test_argmax(self): def test_argmax(self):
X, _ = self.vars X, _ = self.vars
x, _ = self.vals x, _ = self.vals
self.assertTrue(numpy.all(X.argmax().eval({X: x}) == x.argmax())) assert_array_equal(X.argmax().eval({X: x}), x.argmax())
def test_argsort(self): def test_argsort(self):
X, _ = self.vars X, _ = self.vars
x, _ = self.vals x, _ = self.vals
self.assertTrue(numpy.all(X.argsort().eval({X: x}) == x.argsort())) assert_array_equal(X.argsort().eval({X: x}), x.argsort())
self.assertTrue(numpy.all(X.argsort(1).eval({X: x}) == x.argsort(1))) assert_array_equal(X.argsort(1).eval({X: x}), x.argsort(1))
def test_dot(self): def test_dot(self):
X, Y = self.vars X, Y = self.vars
x, y = self.vals x, y = self.vals
Z = X.clip(0.5 - Y, 0.5 + Y) Z = X.clip(0.5 - Y, 0.5 + Y)
z = x.clip(0.5 - y, 0.5 + y) z = x.clip(0.5 - y, 0.5 + y)
self.assertTrue(numpy.all(Z.eval({X: x, Y: y}) == z)) assert_array_equal(Z.eval({X: x, Y: y}), z)
def test_dot(self): def test_dot(self):
X, Y = self.vars X, Y = self.vars
x, y = self.vals x, y = self.vals
self.assertTrue(numpy.all(x.dot(y) == X.dot(Y).eval({X: x, Y: y}))) assert_array_equal(x.dot(y), X.dot(Y).eval({X: x, Y: y}))
Z = X.dot(Y) Z = X.dot(Y)
z = x.dot(y) z = x.dot(y)
self.assertTrue(numpy.all(x.dot(z) == X.dot(Z).eval({X: x, Z: z}))) assert_array_equal(x.dot(z), X.dot(Z).eval({X: x, Z: z}))
def test_real_imag(self): def test_real_imag(self):
X, Y = self.vars X, Y = self.vars
x, y = self.vals x, y = self.vals
Z = X + Y * 1j Z = X + Y * 1j
z = x + y * 1j z = x + y * 1j
self.assertTrue(numpy.all(Z.real.eval({Z: z}) == x)) assert_array_equal(Z.real.eval({Z: z}), x)
self.assertTrue(numpy.all(Z.imag.eval({Z: z}) == y)) assert_array_equal(Z.imag.eval({Z: z}), y)
def test_conj(self): def test_conj(self):
X, Y = self.vars X, Y = self.vars
x, y = self.vals x, y = self.vals
Z = X + Y * 1j Z = X + Y * 1j
z = x + y * 1j z = x + y * 1j
self.assertTrue(numpy.all(Z.conj().eval({Z: z}) == z.conj())) assert_array_equal(Z.conj().eval({Z: z}), z.conj())
def test_round(self): def test_round(self):
X, _ = self.vars X, _ = self.vars
x, _ = self.vals x, _ = self.vals
self.assertTrue(numpy.all(X.round().eval({X: x}) == x.round())) assert_array_equal(X.round().eval({X: x}), x.round())
def test_std(self): def test_std(self):
X, _ = self.vars X, _ = self.vars
x, _ = self.vals x, _ = self.vals
self.assertTrue(numpy.all(X.std().eval({X: x}) == x.std())) assert_array_equal(X.std().eval({X: x}), x.std())
def test_repeat(self): def test_repeat(self):
X, _ = self.vars X, _ = self.vars
x, _ = self.vals x, _ = self.vals
self.assertTrue(numpy.all(X.repeat(2).eval({X: x}) == x.repeat(2))) assert_array_equal(X.repeat(2).eval({X: x}), x.repeat(2))
def test_trace(self): def test_trace(self):
X, _ = self.vars X, _ = self.vars
x, _ = self.vals x, _ = self.vals
self.assertTrue(numpy.all(X.trace().eval({X: x}) == x.trace())) assert_array_equal(X.trace().eval({X: x}), x.trace())
def test_ravel(self): def test_ravel(self):
X, _ = self.vars X, _ = self.vars
x, _ = self.vals x, _ = self.vals
self.assertTrue(numpy.all(X.ravel().eval({X: x}) == x.ravel())) assert_array_equal(X.ravel().eval({X: x}), x.ravel())
if __name__ == '__main__': if __name__ == '__main__':
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论