added float32 tests of gemm

上级 13db39ca
......@@ -1108,8 +1108,8 @@ class t_gemm(unittest.TestCase):
B = self.rand(4,5)[:,:4]
C = self.rand(4,5)[:,:4]
def t(z,x,y,a=1.0, b=0.0,l=gof.cc.OpWiseCLinker):
z,a,x,y,b = [numpy.asarray(p) for p in z,a,x,y,b]
def t(z,x,y,a=1.0, b=0.0,l=gof.cc.OpWiseCLinker,dt='float64'):
z,a,x,y,b = [numpy.asarray(p,dtype=dt) for p in z,a,x,y,b]
z_orig = z.copy()
z_after = self._gemm(z, a, x, y, b)
......@@ -1123,17 +1123,17 @@ class t_gemm(unittest.TestCase):
t(C,A,B)
t(C.T, A, B)
t(C, A.T, B)
t(C, A.T, B, dt='float32')
t(C, A, B.T)
t(C.T, A.T, B)
t(C, A.T, B.T)
t(C, A.T, B.T, dt='float32')
t(C.T, A, B.T)
t(C.T, A.T, B.T)
t(C.T, A.T, B.T, dt='float32')
t(C, A[:,:2], B[:2, :])
t(C.T, A[:,:2], B[:2, :])
t(C.T, A[:,:2], B[:2, :], dt='float32')
t(C, A[:2,:].T, B[:2, :])
t(C.T, A[:2,:].T, B[:2, :])
t(C.T, A[:2,:].T, B[:2, :], dt='float32')
t(C, A[:2,:].T, B[:, :2].T)
t(C.T, A[:2,:].T, B[:, :2].T)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论