Commit e98da52c authored by Frederic Bastien

Make dot, dot22, dot22scalar, gemm on GPU work with dimensions of 0.

Parent commit: c6fce90d
......@@ -2720,10 +2720,10 @@ int CudaNdarray_gemm(float alpha, const CudaNdarray * A, const CudaNdarray * B,
if (B->nd != 2) { PyErr_SetString(PyExc_ValueError, "non-matrix arg to gemm"); return -1; }
if (C->nd != 2) { PyErr_SetString(PyExc_ValueError, "non-matrix arg to gemm"); return -1; }
// We must allow dimensions to be zeros.
if ((CudaNdarray_HOST_DIMS(A)[1] != CudaNdarray_HOST_DIMS(B)[0])
|| (CudaNdarray_HOST_DIMS(A)[0] != CudaNdarray_HOST_DIMS(C)[0])
|| (CudaNdarray_HOST_DIMS(B)[1] != CudaNdarray_HOST_DIMS(C)[1])
|| (CudaNdarray_HOST_DIMS(A)[1] == 0))
|| (CudaNdarray_HOST_DIMS(B)[1] != CudaNdarray_HOST_DIMS(C)[1]))
{
PyErr_Format(PyExc_ValueError, "dimension mismatch in args to gemm (%i,%i)x(%i,%i)->(%i,%i)",
CudaNdarray_HOST_DIMS(A)[0],
......@@ -2814,6 +2814,9 @@ int CudaNdarray_gemm(float alpha, const CudaNdarray * A, const CudaNdarray * B,
//TODO: recognize the negative stride and make a copy of the offending argument,
//rather than aborting
#define CHK_STRIDE_SGEMM(T0, T1, D0, D1, D2, a, x, sx, y, sy, b, z, sz) \
if (sx == 0){sx = 1;}\
if (sy == 0){sy = 1;}\
if (sz == 0){sz = 1;}\
if ((sx > 0) && (sy > 0) && (sz > 0)) { \
cublasSgemm(T0, T1, D0, D1, D2, a, x, sx, y, sy, b, z, sz); \
} else { \
......
......@@ -28,29 +28,33 @@ def my_rand(*shape):
return theano._asarray(numpy.random.rand(*shape),dtype='float32')
def test_dot22():
a = tcn.shared_constructor(my_rand(4,4), 'a')
def cmp(a_shp, b_shp):
a = tcn.shared_constructor(my_rand(*a_shp), 'a')
b = tensor.fmatrix()
f = pfunc([b], [], updates=[(a, tensor.dot(a,b))], mode=mode_with_gpu)
a0 = a.get_value() * 1.0
print a0
for i, node in enumerate(f.maker.env.toposort()):
print i, node
bval = my_rand(4,4)
bval = my_rand(*b_shp)
f(bval)
print a.get_value()
assert numpy.allclose(numpy.dot(a0, bval), a.get_value())
cmp((3,4),(4,5))
cmp((0,4),(4,5))
cmp((3,4),(4,0))
cmp((3,0),(0,5))
cmp((0,4),(4,0))
cmp((0,0),(0,0))
def test_dot22scalar():
def cmp(a_shp, b_shp):
a = tensor.fmatrix()
b = tensor.fmatrix()
scalar = tensor.fscalar()
av = my_rand(4,4)
bv = my_rand(4,4)
av = my_rand(*a_shp)
bv = my_rand(*b_shp)
f = theano.function([a,b], tensor.dot(a,b)*numpy.asarray(4, 'float32'), mode=mode_with_gpu)
f2 = theano.function([a,b], tensor.dot(a,b)*numpy.asarray(4, 'float32'))
......@@ -72,9 +76,16 @@ def test_dot22scalar():
assert isinstance(t[3].op,tcn.HostFromGpu)
assert numpy.allclose(f(av,bv,0.5),f2(av,bv,0.5))
def test_gemm():
cmp((3,4),(4,5))
cmp((0,4),(4,5))
cmp((3,4),(4,0))
cmp((3,0),(0,5))
cmp((0,4),(4,0))
cmp((0,0),(0,0))
a = tcn.shared_constructor(my_rand(4,4), 'a')
def test_gemm():
def cmp(a_shp, b_shp):
a = tcn.shared_constructor(my_rand(*a_shp), 'a')
b = tensor.fmatrix('b')
c = tensor.fmatrix('c')
......@@ -83,20 +94,23 @@ def test_gemm():
assert any([node.op == tcn.blas.gpu_gemm_inplace for node in f.maker.env.toposort()])
a0 = a.get_value() * 1.0
print a0
for i, node in enumerate(f.maker.env.toposort()):
print i, node
bval = my_rand(4,4)
cval = my_rand(4,4)
bval = my_rand(*b_shp)
cval = my_rand(a_shp[0],b_shp[1])
f(bval,cval)
print a.get_value()
assert numpy.allclose(numpy.dot(a0, bval)+numpy.exp(cval), a.get_value())
cmp((3,4),(4,5))
cmp((0,4),(4,5))
cmp((3,4),(4,0))
cmp((3,0),(0,5))
cmp((0,4),(4,0))
cmp((0,0),(0,0))
def test_gemm_no_inplace():
a = tcn.shared_constructor(my_rand(4,4), 'a')
cval = my_rand(4,4)
def cmp(a_shp, b_shp):
a = tcn.shared_constructor(my_rand(*a_shp), 'a')
cval = my_rand(a_shp[0], b_shp[1])
c = tcn.shared_constructor(cval.copy(), 'c')
b = tcn.fmatrix('b')
......@@ -105,18 +119,21 @@ def test_gemm_no_inplace():
f = pfunc([b,b2], [tensor.dot(a,b2) + c], updates=[(a, tensor.dot(a,b) + c)], mode=mode_with_gpu)
a0 = a.get_value() * 1.0
#print a0
for i, node in enumerate(f.maker.env.toposort()):
print i, node
assert any([node.op == tcn.blas.gpu_gemm_no_inplace for node in f.maker.env.toposort()])
bval = my_rand(4,4)
bval2 = my_rand(4,4)
bval = my_rand(*b_shp)
bval2 = my_rand(*b_shp)
rval = f(bval,bval2)
#print a.get_value()
assert numpy.allclose(numpy.dot(a0, bval)+cval, a.get_value())
assert numpy.allclose(numpy.dot(a0, bval2)+cval, rval)
cmp((3,4),(4,5))
cmp((0,4),(4,5))
cmp((3,4),(4,0))
cmp((3,0),(0,5))
cmp((0,4),(4,0))
cmp((0,0),(0,0))
def test_outer():
x = tcn.shared_constructor(my_rand(8,), 'x')
y = tcn.shared_constructor(my_rand(6,), 'y')
......
Markdown format
0%
You added 0 people to this discussion. Please proceed with caution.
Please finish editing this comment first!
Register or sign in to post a comment