提交 71d99506 authored 作者: Jeremiah Lowin's avatar Jeremiah Lowin

re-enabled tests of tensor dot products (using dot op)

上级 86eea375
...@@ -4262,20 +4262,42 @@ class t_dot(unittest.TestCase): ...@@ -4262,20 +4262,42 @@ class t_dot(unittest.TestCase):
self.assertTrue(tz.shape == nz.shape) self.assertTrue(tz.shape == nz.shape)
self.assertTrue(_approx_eq(nz, tz)) self.assertTrue(_approx_eq(nz, tz))
#def test_dot_0d_0d(self): self.cmp_dot(1.1, 2.2) def test_dot_0d_0d(self):
#def test_dot_0d_1d(self): self.cmp_dot(1.1, rand(5)) self.cmp_dot(1.1, 2.2)
#def test_dot_0d_2d(self): self.cmp_dot(3.0, rand(6,7))
#def test_dot_0d_3d(self): self.cmp_dot(3.0, rand(8,6,7)) def test_dot_0d_1d(self):
#def test_dot_1d_0d(self): self.cmp_dot(rand(5), 1.1 ) self.cmp_dot(1.1, rand(5))
def test_dot_0d_2d(self):
self.cmp_dot(3.0, rand(6,7))
def test_dot_0d_3d(self):
self.cmp_dot(3.0, rand(8,6,7))
def test_dot_1d_0d(self):
self.cmp_dot(rand(5), 1.1 )
def test_dot_1d_1d(self): def test_dot_1d_1d(self):
self.cmp_dot(rand(5), rand(5)) self.cmp_dot(rand(5), rand(5))
def test_dot_1d0_1d0(self): def test_dot_1d0_1d0(self):
self.cmp_dot(rand(0), rand(0)) self.cmp_dot(rand(0), rand(0))
#numpy return matrix not aligned... #numpy return matrix not aligned...
#def test_dot_1d_1d0(self): self.cmp_dot(rand(5), rand(0)) def test_dot_1d_1d0(self):
try:
self.cmp_dot(rand(5), rand(0))
assert False
except ValueError:
pass
#numpy return matrix not aligned... #numpy return matrix not aligned...
#def test_dot_1d0_1d(self): self.cmp_dot(rand(0), rand(5)) def test_dot_1d0_1d(self):
try:
self.cmp_dot(rand(0), rand(5))
assert False
except ValueError:
pass
def test_dot_1d_2d(self): def test_dot_1d_2d(self):
self.cmp_dot(rand(6), rand(6, 7)) self.cmp_dot(rand(6), rand(6, 7))
...@@ -4288,8 +4310,12 @@ class t_dot(unittest.TestCase): ...@@ -4288,8 +4310,12 @@ class t_dot(unittest.TestCase):
def test_dot_1d0_2d0(self): def test_dot_1d0_2d0(self):
self.cmp_dot(rand(0), rand(0, 0)) self.cmp_dot(rand(0), rand(0, 0))
#def test_dot_1d_3d(self): self.cmp_dot(rand(6), rand(8,6,7))
#def test_dot_2d_0d(self): self.cmp_dot(rand(5,6), 1.0) def test_dot_1d_3d(self):
self.cmp_dot(rand(6), rand(8,6,7))
def test_dot_2d_0d(self):
self.cmp_dot(rand(5,6), 1.0)
def test_dot_2d_1d(self): def test_dot_2d_1d(self):
self.cmp_dot(rand(5, 6), rand(6)) self.cmp_dot(rand(5, 6), rand(6))
...@@ -4320,11 +4346,21 @@ class t_dot(unittest.TestCase): ...@@ -4320,11 +4346,21 @@ class t_dot(unittest.TestCase):
def test_dot_2d0_0_2d0(self): def test_dot_2d0_0_2d0(self):
self.cmp_dot(rand(0, 6), rand(6, 0)) self.cmp_dot(rand(0, 6), rand(6, 0))
#def test_dot_2d_3d(self): self.cmp_dot(rand(5,6), rand(8,6,7))
#def test_dot_3d_0d(self): self.cmp_dot(rand(4,5,6), 1.0) def test_dot_2d_3d(self):
#def test_dot_3d_1d(self): self.cmp_dot(rand(4,5,6), rand(6)) self.cmp_dot(rand(5,6), rand(8,6,7))
#def test_dot_3d_2d(self): self.cmp_dot(rand(4,5,6), rand(6,7))
#def test_dot_3d_3d(self): self.cmp_dot(rand(4,5,6), rand(8,6,7)) def test_dot_3d_0d(self):
self.cmp_dot(rand(4,5,6), 1.0)
def test_dot_3d_1d(self):
self.cmp_dot(rand(4,5,6), rand(6))
def test_dot_3d_2d(self):
self.cmp_dot(rand(4,5,6), rand(6,7))
def test_dot_3d_3d(self):
self.cmp_dot(rand(4,5,6), rand(8,6,7))
def not_aligned(self, x, y): def not_aligned(self, x, y):
ctv_backup = config.compute_test_value ctv_backup = config.compute_test_value
...@@ -4364,7 +4400,8 @@ class t_dot(unittest.TestCase): ...@@ -4364,7 +4400,8 @@ class t_dot(unittest.TestCase):
def test_align_1_2(self): def test_align_1_2(self):
self.not_aligned(rand(5), rand(6, 4)) self.not_aligned(rand(5), rand(6, 4))
#def test_align_1_3(self): self.not_aligned(rand(5), rand(6,4,7)) def test_align_1_3(self):
self.not_aligned(rand(5), rand(6,4,7))
def test_align_2_1(self): def test_align_2_1(self):
self.not_aligned(rand(5, 4), rand(6)) self.not_aligned(rand(5, 4), rand(6))
...@@ -4372,31 +4409,44 @@ class t_dot(unittest.TestCase): ...@@ -4372,31 +4409,44 @@ class t_dot(unittest.TestCase):
def test_align_2_1(self): def test_align_2_1(self):
self.not_aligned(rand(5, 4), rand(6, 7)) self.not_aligned(rand(5, 4), rand(6, 7))
#def test_align_2_3(self): self.not_aligned(rand(5,4), rand(6,7,8)) def test_align_2_3(self):
#def test_align_3_1(self): self.not_aligned(rand(5,4,3), rand(6)) self.not_aligned(rand(5,4), rand(6,7,8))
#def test_align_3_2(self): self.not_aligned(rand(5,4,3), rand(6,7))
#def test_align_3_3(self): self.not_aligned(rand(5,4,3), rand(6,7,8)) def test_align_3_1(self):
self.not_aligned(rand(5,4,3), rand(6))
def test_align_3_2(self):
self.not_aligned(rand(5,4,3), rand(6,7))
def test_align_3_3(self):
self.not_aligned(rand(5,4,3), rand(6,7,8))
def test_grad(self): def test_grad(self):
#utt.verify_grad(dot, [rand(2,3,4), rand(4)])
utt.verify_grad(dot, [rand(2, 3), rand(3, 2)]) utt.verify_grad(dot, [rand(2, 3), rand(3, 2)])
utt.verify_grad(dot, [rand(2), rand(2, 3)]) utt.verify_grad(dot, [rand(2), rand(2, 3)])
utt.verify_grad(dot, [rand(3, 2), rand(2)]) utt.verify_grad(dot, [rand(3, 2), rand(2)])
utt.verify_grad(dot, [rand(2), rand(2)]) utt.verify_grad(dot, [rand(2), rand(2)])
#utt.verify_grad(dot, [rand(), rand(2)]) utt.verify_grad(dot, [rand(), rand(2)])
#utt.verify_grad(dot, [rand(), rand(2,5)]) utt.verify_grad(dot, [rand(), rand(2,5)])
utt.verify_grad(dot, [rand(2), rand()])
utt.verify_grad(dot, [rand(2,5), rand()])
utt.verify_grad(dot, [rand(2,3,4), rand(4)])
utt.verify_grad(dot, [rand(3), rand(2,3,4)])
utt.verify_grad(dot, [rand(4,3), rand(2,3,4)])
utt.verify_grad(dot, [rand(2,3,4), rand(4,5)])
utt.verify_grad(dot, [rand(2,3,4), rand(3,4,5)])
def test_broadcastable_patterns(self): def test_broadcastable_patterns(self):
# #
# These examples hsould all work because we broadcastable or no, all dimensions of all # These examples should all work because we broadcastable or no, all dimensions of all
# results have size 1. # results have size 1.
# #
def val_for(r): def val_for(r):
if r.dtype.startswith('complex'): if r.dtype.startswith('complex'):
# We want to test complex at the same time, so we give a value # We want to test complex at the same time, so we give a value
# To the imaginary component. # To the imaginary component.
# This stange way to doing thing is the only way that worked on # This strange way of doing things is the only way that worked on
# numpy 1.4.1 # numpy 1.4.1
if r.ndim == 0: if r.ndim == 0:
return numpy.asarray(numpy.complex(1.1, 2.1), dtype=r.dtype) return numpy.asarray(numpy.complex(1.1, 2.1), dtype=r.dtype)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论