提交 8de08adf authored 作者: Frederic's avatar Frederic

Use dtype specific comparison

This should remove dtype specific to float32 in the nightly buildbot.
上级 ec5eaf4d
...@@ -5954,7 +5954,7 @@ class test_tensordot(unittest.TestCase): ...@@ -5954,7 +5954,7 @@ class test_tensordot(unittest.TestCase):
bval = rand(5) bval = rand(5)
out0 = numpy.tensordot(aval, bval, axes) out0 = numpy.tensordot(aval, bval, axes)
out1 = f1(aval, bval) out1 = f1(aval, bval)
self.assertTrue(numpy.allclose(out0, out1), (out0, out1)) utt.assert_allclose(out0, out1)
utt.verify_grad(self.TensorDot(axes), [aval, bval]) utt.verify_grad(self.TensorDot(axes), [aval, bval])
# Test matrix-vector # Test matrix-vector
...@@ -5964,8 +5964,8 @@ class test_tensordot(unittest.TestCase): ...@@ -5964,8 +5964,8 @@ class test_tensordot(unittest.TestCase):
f2 = inplace_func([avec, bmat], c) f2 = inplace_func([avec, bmat], c)
aval = rand(5) aval = rand(5)
bval = rand(8, 5) bval = rand(8, 5)
self.assertTrue(numpy.allclose(numpy.tensordot(aval, bval, axes), utt.assert_allclose(numpy.tensordot(aval, bval, axes),
f2(aval, bval))) f2(aval, bval))
utt.verify_grad(self.TensorDot(axes), [aval, bval]) utt.verify_grad(self.TensorDot(axes), [aval, bval])
# Test matrix-matrix # Test matrix-matrix
...@@ -5983,8 +5983,8 @@ class test_tensordot(unittest.TestCase): ...@@ -5983,8 +5983,8 @@ class test_tensordot(unittest.TestCase):
f3 = inplace_func([amat, bmat], c) f3 = inplace_func([amat, bmat], c)
aval = rand(*shps[0]) aval = rand(*shps[0])
bval = rand(*shps[1]) bval = rand(*shps[1])
self.assertTrue(numpy.allclose(numpy.tensordot(aval, bval, axes), utt.assert_allclose(numpy.tensordot(aval, bval, axes),
f3(aval, bval))) f3(aval, bval))
utt.verify_grad(self.TensorDot(axes), [aval, bval]) utt.verify_grad(self.TensorDot(axes), [aval, bval])
# Test ndarray-matrix, sum over one dim of matrix # Test ndarray-matrix, sum over one dim of matrix
...@@ -6001,8 +6001,8 @@ class test_tensordot(unittest.TestCase): ...@@ -6001,8 +6001,8 @@ class test_tensordot(unittest.TestCase):
f4 = inplace_func([atens, bmat], c) f4 = inplace_func([atens, bmat], c)
aval = rand(*shps[0]) aval = rand(*shps[0])
bval = rand(*shps[1]) bval = rand(*shps[1])
self.assertTrue(numpy.allclose(numpy.tensordot(aval, bval, axes), utt.assert_allclose(numpy.tensordot(aval, bval, axes),
f4(aval, bval))) f4(aval, bval))
utt.verify_grad(self.TensorDot(axes), [aval, bval]) utt.verify_grad(self.TensorDot(axes), [aval, bval])
# Test ndarray-ndarray # Test ndarray-ndarray
...@@ -6013,15 +6013,15 @@ class test_tensordot(unittest.TestCase): ...@@ -6013,15 +6013,15 @@ class test_tensordot(unittest.TestCase):
f5 = inplace_func([atens, btens], c) f5 = inplace_func([atens, btens], c)
aval = rand(4, 3, 5, 2) aval = rand(4, 3, 5, 2)
bval = rand(3, 4, 2) bval = rand(3, 4, 2)
self.assertTrue(numpy.allclose(numpy.tensordot(aval, bval, axes), utt.assert_allclose(numpy.tensordot(aval, bval, axes),
f5(aval, bval))) f5(aval, bval))
utt.verify_grad(self.TensorDot(axes), [aval, bval]) utt.verify_grad(self.TensorDot(axes), [aval, bval])
axes = (axes[1], axes[0]) axes = (axes[1], axes[0])
c = tensordot(btens, atens, axes) c = tensordot(btens, atens, axes)
f6 = inplace_func([btens, atens], c) f6 = inplace_func([btens, atens], c)
self.assertTrue(numpy.allclose(numpy.tensordot(bval, aval, axes), utt.assert_allclose(numpy.tensordot(bval, aval, axes),
f6(bval, aval))) f6(bval, aval))
utt.verify_grad(self.TensorDot(axes), [bval, aval]) utt.verify_grad(self.TensorDot(axes), [bval, aval])
def test_raise_error(self): def test_raise_error(self):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论