added float32 tests of gemm

上级 13db39ca
...@@ -1108,8 +1108,8 @@ class t_gemm(unittest.TestCase): ...@@ -1108,8 +1108,8 @@ class t_gemm(unittest.TestCase):
B = self.rand(4,5)[:,:4] B = self.rand(4,5)[:,:4]
C = self.rand(4,5)[:,:4] C = self.rand(4,5)[:,:4]
def t(z,x,y,a=1.0, b=0.0,l=gof.cc.OpWiseCLinker): def t(z,x,y,a=1.0, b=0.0,l=gof.cc.OpWiseCLinker,dt='float64'):
z,a,x,y,b = [numpy.asarray(p) for p in z,a,x,y,b] z,a,x,y,b = [numpy.asarray(p,dtype=dt) for p in z,a,x,y,b]
z_orig = z.copy() z_orig = z.copy()
z_after = self._gemm(z, a, x, y, b) z_after = self._gemm(z, a, x, y, b)
...@@ -1123,17 +1123,17 @@ class t_gemm(unittest.TestCase): ...@@ -1123,17 +1123,17 @@ class t_gemm(unittest.TestCase):
t(C,A,B) t(C,A,B)
t(C.T, A, B) t(C.T, A, B)
t(C, A.T, B) t(C, A.T, B, dt='float32')
t(C, A, B.T) t(C, A, B.T)
t(C.T, A.T, B) t(C.T, A.T, B)
t(C, A.T, B.T) t(C, A.T, B.T, dt='float32')
t(C.T, A, B.T) t(C.T, A, B.T)
t(C.T, A.T, B.T) t(C.T, A.T, B.T, dt='float32')
t(C, A[:,:2], B[:2, :]) t(C, A[:,:2], B[:2, :])
t(C.T, A[:,:2], B[:2, :]) t(C.T, A[:,:2], B[:2, :], dt='float32')
t(C, A[:2,:].T, B[:2, :]) t(C, A[:2,:].T, B[:2, :])
t(C.T, A[:2,:].T, B[:2, :]) t(C.T, A[:2,:].T, B[:2, :], dt='float32')
t(C, A[:2,:].T, B[:, :2].T) t(C, A[:2,:].T, B[:, :2].T)
t(C.T, A[:2,:].T, B[:, :2].T) t(C.T, A[:2,:].T, B[:, :2].T)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论