提交 6cd840c7 authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Add test for gemm with non-contiguous arrays

上级 e32e92b4
......@@ -255,6 +255,60 @@ class t_gemm(TestCase):
return
self.fail()
def test_non_contiguous(self):
# Like test_transposes but with matrices without any
# continuous dimension
A = self.rand(4,4,3)
B = self.rand(4,4,3)
C = self.rand(4,4,3)
def t(z, x, y, a=1.0, b=0.0, l='c|py', dt='float64'):
z, a, x, y, b = [theano._asarray(p, dtype=dt) for p in z, a, x, y, b]
z_orig = z.copy()
z_after = numpy.zeros_like(z_orig)
for i in xrange(3):
z_after[:,:,i] = self._gemm(z[:,:,i], a, x[:,:,i], y[:,:,i], b)
tz, ta, tx, ty, tb = [shared(p) for p in z, a, x, y, b]
for i in xrange(3):
f_i = inplace_func([],
gemm_inplace(tz[:,:,i], ta, tx[:,:,i], ty[:,:,i], tb),
mode=compile.Mode(optimizer=None, linker=l))
for j in xrange(3):
# tz will not _always_ be overwritten,
# and adding update={...} in the call to function()
# will create cycles, so we update by hand.
z_i = f_i()
z = tz.get_value(borrow=True, return_internal_type=True)
z[:,:,i] = z_i
self.assertTrue(
_approx_eq(z_after[:,:,i],
tz.get_value(borrow=True)[:,:,i]),
(z_orig[:,:,i], z_after[:,:,i],
z[:,:,i], z_after[:,:,i] - z[:,:,i]))
tz_i = gemm_no_inplace(tz[:,:,i], ta, tx[:,:,i], ty[:,:,i], tb)
g_i = theano.function([], tz_i,
updates={tz:T.set_subtensor(tz[:,:,i], tz_i)},
mode=compile.Mode(optimizer=None, linker=l))
for j in xrange(3):
g_i()
self.assertTrue(
_approx_eq(z_after[:,:,i],
tz.get_value(borrow=True)[:,:,i]),
(z_orig[:,:,i], z_after[:,:,i],
z[:,:,i], z_after[:,:,i] - z[:,:,i]))
t(C, A, B)
t(C.transpose((1,0,2)), A, B)
t(C, A.transpose((1,0,2)), B, dt='float32')
t(C, A, B.transpose((1,0,2)))
t(C.transpose((1,0,2)), A.transpose((1,0,2)), B)
t(C, A.transpose((1,0,2)), B.transpose((1,0,2)), dt='float32')
t(C.transpose((1,0,2)), A, B.transpose((1,0,2)))
t(C.transpose((1,0,2)), A.transpose((1,0,2)), B.transpose((1,0,2)), dt='float32')
def test_res_is_a():
X,Y,Z,a,b = XYZab()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论