提交 8d307b48 authored 作者: notoraptor's avatar notoraptor

Add tests for alt gemm (currently missing).

Fix typo in test blas c (useless but for coherence).
上级 6fd65836
......@@ -373,6 +373,98 @@ class t_gemm(TestCase):
1, 0, 2)), dt='float32')
class t_gemm_no_flags(object):
gemm = gemm_no_inplace
M = 4
N = 5
K = 6
slice_step = 3
def setUp(self):
unittest_tools.seed_rng()
def get_variable(self, V, to_transpose, to_slice):
if to_transpose:
V = V.T
if to_slice:
V = V[::self.slice_step]
return V
def get_function(self, dtype,
transpose_A=False, transpose_B=False, transpose_C=False,
slice_A=False, slice_B=False, slice_C=False):
alpha = theano.tensor.scalar(dtype=dtype, name='alpha')
beta = theano.tensor.scalar(dtype=dtype, name='beta')
A = theano.tensor.matrix(dtype=dtype, name='A')
B = theano.tensor.matrix(dtype=dtype, name='B')
C = theano.tensor.matrix(dtype=dtype, name='C')
A1 = self.get_variable(A, transpose_A, slice_A)
B1 = self.get_variable(B, transpose_B, slice_B)
C1 = self.get_variable(C, transpose_C, slice_C)
return theano.function([alpha, A, B, beta, C], self.gemm(C1, alpha, A1, B1, beta))
def generate_value(self, dtype, width, height, to_transpose, to_slice):
if to_slice:
if to_transpose:
shape = (height, width * self.slice_step)
else:
shape = (width * self.slice_step, height)
else:
if to_transpose:
shape = (height, width)
else:
shape = (width, height)
return np.random.random(shape).astype(dtype)
def get_data(self, dtype, alpha, beta,
transpose_A=False, transpose_B=False, transpose_C=False,
slice_A=False, slice_B=False, slice_C=False):
A = self.generate_value(dtype, self.M, self.N, transpose_A, slice_A)
B = self.generate_value(dtype, self.N, self.K, transpose_B, slice_B)
C = self.generate_value(dtype, self.M, self.K, transpose_C, slice_C)
return (alpha, A, B, beta, C)
def get_value(self, V, to_transpose, to_slice):
if to_transpose:
V = V.T
if to_slice:
V = V[::self.slice_step]
return V
def compute_ref(self, alpha, A, B, beta, C,
transpose_A, transpose_B, transpose_C,
slice_A, slice_B, slice_C):
A = self.get_value(A, transpose_A, slice_A)
B = self.get_value(B, transpose_B, slice_B)
C = self.get_value(C, transpose_C, slice_C)
return alpha * np.dot(A, B) + beta * C
@theano.change_flags({'blas.ldflags': ''})
def run_gemm(self, dtype, ALPHA, BETA,
transpose_A, transpose_B, transpose_C,
slice_A, slice_B, slice_C):
f = self.get_function(dtype, transpose_A, transpose_B, transpose_C, slice_A, slice_B, slice_C)
values = self.get_data(dtype, ALPHA, BETA, transpose_A, transpose_B, transpose_C, slice_A, slice_B, slice_C)
assert any(isinstance(node.op, Gemm) for node in f.maker.fgraph.apply_nodes)
z_val = f(*values)
assert z_val.dtype == dtype
assert tuple(z_val.shape) == (self.M, self.K)
ref_val = self.compute_ref(*(values + (transpose_A, transpose_B, transpose_C, slice_A, slice_B, slice_C)))
unittest_tools.assert_allclose(ref_val, z_val)
def test_gemm(self):
from itertools import product
dtypes = ('float32', 'float64')
scalars = (0, 1, -2)
booleans = (False, True)
# dtype, alpha, beta, transA, transB, transC, sliceA, sliceB, sliceC
iterables = [dtypes] + ([scalars] * 2) + ([booleans] * 6)
for dtype, alpha, beta, tA, tB, tC, sA, sB, sC in product(*iterables):
yield (self.run_gemm, dtype, alpha, beta, tA, tB, tC, sA, sB, sC)
def test_res_is_a():
X, Y, Z, a, b = XYZab()
......
......@@ -344,7 +344,7 @@ class TestCGemvNoFlags(object):
A_2 = A_1
x_2 = x
y_2 = y
return theano.function([alpha, A, x, beta, y], self.gemv(y_2, alpha, A_2, x_2, beta))
return theano.function([alpha, A, x, beta, y], self.gemv(y_2, alpha, A_2, x_2, beta), mode=self.mode)
def get_data(self, dtype, alpha, beta, transpose_A=False, slice_tensors=False):
if slice_tensors:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论