提交 0226809b authored 作者: Frederic Bastien's avatar Frederic Bastien

test gemm.perform on complex type. It's c_code don't support complex.

上级 db03a0e8
......@@ -49,29 +49,35 @@ class t_gemm(TestCase):
def rand(*args):
return numpy.random.rand(*args)
def cmp(self, z, a, x, y, b):
def cmp_linker(z, a, x, y, b, l):
z,a,x,y,b = [numpy.asarray(p) for p in z,a,x,y,b]
z_orig = z.copy()
tz,ta,tx,ty,tb = [as_tensor_variable(p).type() for p in z,a,x,y,b]
f = inplace_func([tz,ta,tx,ty,tb], gemm_inplace(tz,ta,tx,ty,tb), mode=compile.Mode(optimizer = None, linker = l))
new_z = f(z,a,x,y,b)
z_after = self._gemm(z_orig, a, x, y, b)
#print z_orig, z_after, z, type(z_orig), type(z_after), type(z)
#_approx_eq.debug = 1
self.failUnless(_approx_eq(z_after, z))
if a == 0.0 and b == 1.0:
return
else:
self.failIf(numpy.all(z_orig == z))
cmp_linker(copy(z), a, x, y, b, 'c|py')
cmp_linker(copy(z), a, x, y, b, 'py')
if config.blas.ldflags:
# If blas.ldflags is equal to '', the C code will not be generated
cmp_linker(copy(z), a, x, y, b, 'c')
def cmp(self, z_, a_, x_, y_, b_):
for dtype in ['float32', 'float64', 'complex64', 'complex128']:
z = numpy.asarray(z_, dtype=dtype)
a = numpy.asarray(a_, dtype=dtype)
x = numpy.asarray(x_, dtype=dtype)
y = numpy.asarray(y_, dtype=dtype)
b = numpy.asarray(b_, dtype=dtype)
def cmp_linker(z, a, x, y, b, l):
z,a,x,y,b = [numpy.asarray(p) for p in z,a,x,y,b]
z_orig = z.copy()
tz,ta,tx,ty,tb = [as_tensor_variable(p).type() for p in z,a,x,y,b]
f = inplace_func([tz,ta,tx,ty,tb], gemm_inplace(tz,ta,tx,ty,tb), mode=compile.Mode(optimizer = None, linker = l))
new_z = f(z,a,x,y,b)
z_after = self._gemm(z_orig, a, x, y, b)
#print z_orig, z_after, z, type(z_orig), type(z_after), type(z)
#_approx_eq.debug = 1
self.failUnless(_approx_eq(z_after, z))
if a == 0.0 and b == 1.0:
return
else:
self.failIf(numpy.all(z_orig == z))
cmp_linker(copy(z), a, x, y, b, 'c|py')
cmp_linker(copy(z), a, x, y, b, 'py')
if config.blas.ldflags and not dtype.startswith("complex"):
# If blas.ldflags is equal to '', the C code will not be generated
cmp_linker(copy(z), a, x, y, b, 'c')
def test0a(self):
Gemm.debug = True
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论