提交 0226809b authored 作者: Frederic Bastien's avatar Frederic Bastien

test gemm.perform on complex type. It's c_code don't support complex.

上级 db03a0e8
...@@ -49,29 +49,35 @@ class t_gemm(TestCase): ...@@ -49,29 +49,35 @@ class t_gemm(TestCase):
def rand(*args): def rand(*args):
return numpy.random.rand(*args) return numpy.random.rand(*args)
def cmp(self, z, a, x, y, b): def cmp(self, z_, a_, x_, y_, b_):
def cmp_linker(z, a, x, y, b, l): for dtype in ['float32', 'float64', 'complex64', 'complex128']:
z,a,x,y,b = [numpy.asarray(p) for p in z,a,x,y,b] z = numpy.asarray(z_, dtype=dtype)
z_orig = z.copy() a = numpy.asarray(a_, dtype=dtype)
tz,ta,tx,ty,tb = [as_tensor_variable(p).type() for p in z,a,x,y,b] x = numpy.asarray(x_, dtype=dtype)
y = numpy.asarray(y_, dtype=dtype)
f = inplace_func([tz,ta,tx,ty,tb], gemm_inplace(tz,ta,tx,ty,tb), mode=compile.Mode(optimizer = None, linker = l)) b = numpy.asarray(b_, dtype=dtype)
new_z = f(z,a,x,y,b) def cmp_linker(z, a, x, y, b, l):
z_after = self._gemm(z_orig, a, x, y, b) z,a,x,y,b = [numpy.asarray(p) for p in z,a,x,y,b]
z_orig = z.copy()
#print z_orig, z_after, z, type(z_orig), type(z_after), type(z) tz,ta,tx,ty,tb = [as_tensor_variable(p).type() for p in z,a,x,y,b]
#_approx_eq.debug = 1
self.failUnless(_approx_eq(z_after, z)) f = inplace_func([tz,ta,tx,ty,tb], gemm_inplace(tz,ta,tx,ty,tb), mode=compile.Mode(optimizer = None, linker = l))
if a == 0.0 and b == 1.0: new_z = f(z,a,x,y,b)
return z_after = self._gemm(z_orig, a, x, y, b)
else:
self.failIf(numpy.all(z_orig == z)) #print z_orig, z_after, z, type(z_orig), type(z_after), type(z)
#_approx_eq.debug = 1
cmp_linker(copy(z), a, x, y, b, 'c|py') self.failUnless(_approx_eq(z_after, z))
cmp_linker(copy(z), a, x, y, b, 'py') if a == 0.0 and b == 1.0:
if config.blas.ldflags: return
# If blas.ldflags is equal to '', the C code will not be generated else:
cmp_linker(copy(z), a, x, y, b, 'c') self.failIf(numpy.all(z_orig == z))
cmp_linker(copy(z), a, x, y, b, 'c|py')
cmp_linker(copy(z), a, x, y, b, 'py')
if config.blas.ldflags and not dtype.startswith("complex"):
# If blas.ldflags is equal to '', the C code will not be generated
cmp_linker(copy(z), a, x, y, b, 'c')
def test0a(self): def test0a(self):
Gemm.debug = True Gemm.debug = True
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论