提交 c0291c58 authored 作者: Frédéric Bastien's avatar Frédéric Bastien

Merge pull request #1447 from nouiz/err_msg

Use dtype specific comparison
......@@ -5954,7 +5954,7 @@ class test_tensordot(unittest.TestCase):
bval = rand(5)
out0 = numpy.tensordot(aval, bval, axes)
out1 = f1(aval, bval)
self.assertTrue(numpy.allclose(out0, out1), (out0, out1))
utt.assert_allclose(out0, out1)
utt.verify_grad(self.TensorDot(axes), [aval, bval])
# Test matrix-vector
......@@ -5964,8 +5964,8 @@ class test_tensordot(unittest.TestCase):
f2 = inplace_func([avec, bmat], c)
aval = rand(5)
bval = rand(8, 5)
self.assertTrue(numpy.allclose(numpy.tensordot(aval, bval, axes),
f2(aval, bval)))
utt.assert_allclose(numpy.tensordot(aval, bval, axes),
f2(aval, bval))
utt.verify_grad(self.TensorDot(axes), [aval, bval])
# Test matrix-matrix
......@@ -5983,8 +5983,8 @@ class test_tensordot(unittest.TestCase):
f3 = inplace_func([amat, bmat], c)
aval = rand(*shps[0])
bval = rand(*shps[1])
self.assertTrue(numpy.allclose(numpy.tensordot(aval, bval, axes),
f3(aval, bval)))
utt.assert_allclose(numpy.tensordot(aval, bval, axes),
f3(aval, bval))
utt.verify_grad(self.TensorDot(axes), [aval, bval])
# Test ndarray-matrix, sum over one dim of matrix
......@@ -6001,8 +6001,8 @@ class test_tensordot(unittest.TestCase):
f4 = inplace_func([atens, bmat], c)
aval = rand(*shps[0])
bval = rand(*shps[1])
self.assertTrue(numpy.allclose(numpy.tensordot(aval, bval, axes),
f4(aval, bval)))
utt.assert_allclose(numpy.tensordot(aval, bval, axes),
f4(aval, bval))
utt.verify_grad(self.TensorDot(axes), [aval, bval])
# Test ndarray-ndarray
......@@ -6013,15 +6013,15 @@ class test_tensordot(unittest.TestCase):
f5 = inplace_func([atens, btens], c)
aval = rand(4, 3, 5, 2)
bval = rand(3, 4, 2)
self.assertTrue(numpy.allclose(numpy.tensordot(aval, bval, axes),
f5(aval, bval)))
utt.assert_allclose(numpy.tensordot(aval, bval, axes),
f5(aval, bval))
utt.verify_grad(self.TensorDot(axes), [aval, bval])
axes = (axes[1], axes[0])
c = tensordot(btens, atens, axes)
f6 = inplace_func([btens, atens], c)
self.assertTrue(numpy.allclose(numpy.tensordot(bval, aval, axes),
f6(bval, aval)))
utt.assert_allclose(numpy.tensordot(bval, aval, axes),
f6(bval, aval))
utt.verify_grad(self.TensorDot(axes), [bval, aval])
def test_raise_error(self):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论