提交 2c447bb9 authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Fix GpuGemv with negative strides

上级 554a55d9
...@@ -305,7 +305,7 @@ class GpuGemv(Op): ...@@ -305,7 +305,7 @@ class GpuGemv(Op):
return Apply(self, [z, a, x, y, b], [z.type()]) return Apply(self, [z, a, x, y, b], [z.type()])
def c_code_cache_version(self): def c_code_cache_version(self):
return (1,) return (2,)
def c_code(self, node, name, inputs, outputs, sub): def c_code(self, node, name, inputs, outputs, sub):
#z_out = alpha * dot(x,y) + beta * z_in #z_out = alpha * dot(x,y) + beta * z_in
...@@ -313,44 +313,46 @@ class GpuGemv(Op): ...@@ -313,44 +313,46 @@ class GpuGemv(Op):
#not inplace version, we copy z_in to z_out. #not inplace version, we copy z_in to z_out.
z_in, a, x, y, b = inputs z_in, a, x, y, b = inputs
z_out, = outputs z_out, = outputs
inplace = int(self.inplace)
fail = sub['fail'] fail = sub['fail']
sio = StringIO.StringIO() sio = StringIO.StringIO()
print >> sio, """ print >> sio, """
float %(name)s_alpha = ((dtype_%(a)s*)(%(a)s->data))[0]; float %(name)s_alpha = ((dtype_%(a)s*)(%(a)s->data))[0];
float %(name)s_beta = ((dtype_%(b)s*)(%(b)s->data))[0]; float %(name)s_beta = ((dtype_%(b)s*)(%(b)s->data))[0];
"""
if self.inplace: if (%(inplace)s
print >> sio, """ && ((CudaNdarray_HOST_STRIDES(%(z_in)s)[0] > 0)
|| ((CudaNdarray_HOST_STRIDES(%(z_in)s)[0] == 0)
&& (CudaNdarray_HOST_DIMS(%(z_in)s)[0] == 1))))
{
// Work inplace on the input
Py_XDECREF(%(z_out)s); Py_XDECREF(%(z_out)s);
%(z_out)s = %(z_in)s; %(z_out)s = %(z_in)s;
Py_INCREF(%(z_out)s); Py_INCREF(%(z_out)s);
""" }
else: else if (%(z_out)s
print >> sio, """ && ((CudaNdarray_HOST_STRIDES(%(z_out)s)[0] > 0)
if (!%(z_out)s || ((CudaNdarray_HOST_STRIDES(%(z_out)s)[0] == 0)
|| (%(z_out)s->nd != 1) && (CudaNdarray_HOST_DIMS(%(z_out)s)[0] == 1))))
|| (CudaNdarray_HOST_DIMS(%(z_out)s)[0] != CudaNdarray_HOST_DIMS(%(z_in)s)[0]) {
) // Work on the output
if (CudaNdarray_CopyFromCudaNdarray(%(z_out)s, %(z_in)s))
{ {
Py_XDECREF(%(z_out)s); %(fail)s;
%(z_out)s = (CudaNdarray*)CudaNdarray_Copy(%(z_in)s);
if (!%(z_out)s)
{
%(fail)s;
}
} }
else }
else
{
// Copy
Py_XDECREF(%(z_out)s);
%(z_out)s = (CudaNdarray*)CudaNdarray_Copy(%(z_in)s);
if (!%(z_out)s)
{ {
if (CudaNdarray_CopyFromCudaNdarray(%(z_out)s, %(z_in)s)) %(fail)s;
{
%(fail)s;
}
} }
""" }
print >> sio, """
if (CudaNdarray_sgemv(%(name)s_alpha, %(x)s, %(y)s, %(name)s_beta, %(z_out)s)) if (CudaNdarray_sgemv(%(name)s_alpha, %(x)s, %(y)s, %(name)s_beta, %(z_out)s))
{ {
%(fail)s; %(fail)s;
......
...@@ -3029,8 +3029,7 @@ int CudaNdarray_gemm(float alpha, const CudaNdarray * A, const CudaNdarray * B, ...@@ -3029,8 +3029,7 @@ int CudaNdarray_gemm(float alpha, const CudaNdarray * A, const CudaNdarray * B,
char N = 'N'; char N = 'N';
char T = 'T'; char T = 'T';
//std::cerr << (unit/256) MOD 16 << (unit / 16) MOD 16 << unit MOD 16<< '\\n'; //std::cerr << (unit/256) MOD 16 << (unit / 16) MOD 16 << unit MOD 16<< '\\n';
//TODO: recognize the negative stride and make a copy of the offending argument, // There should be no negative stride at that point
//rather than aborting
#define CHK_STRIDE_SGEMM(T0, T1, D0, D1, D2, a, x, sx, y, sy, b, z, sz) \ #define CHK_STRIDE_SGEMM(T0, T1, D0, D1, D2, a, x, sx, y, sy, b, z, sz) \
if (sx == 0){sx = 1;}\ if (sx == 0){sx = 1;}\
if (sy == 0){sy = 1;}\ if (sy == 0){sy = 1;}\
...@@ -3038,7 +3037,7 @@ int CudaNdarray_gemm(float alpha, const CudaNdarray * A, const CudaNdarray * B, ...@@ -3038,7 +3037,7 @@ int CudaNdarray_gemm(float alpha, const CudaNdarray * A, const CudaNdarray * B,
if ((sx > 0) && (sy > 0) && (sz > 0)) { \ if ((sx > 0) && (sy > 0) && (sz > 0)) { \
cublasSgemm(T0, T1, D0, D1, D2, a, x, sx, y, sy, b, z, sz); \ cublasSgemm(T0, T1, D0, D1, D2, a, x, sx, y, sy, b, z, sz); \
} else { \ } else { \
PyErr_SetString(PyExc_NotImplementedError, "negative stride to sGemm");\ PyErr_SetString(PyExc_AssertionError, "negative stride to sGemm");\
Py_XDECREF(A);\ Py_XDECREF(A);\
Py_XDECREF(B);\ Py_XDECREF(B);\
return -1; \ return -1; \
...@@ -3093,22 +3092,53 @@ int CudaNdarray_sgemv(float alpha, const CudaNdarray * A, const CudaNdarray * B, ...@@ -3093,22 +3092,53 @@ int CudaNdarray_sgemv(float alpha, const CudaNdarray * A, const CudaNdarray * B,
return -1; return -1;
} }
// a matrix has non-unit size and non-unit stride in both directions, we can't operate in-place // If matrix A has non-unit size and non-unit stride in both
// TODO: make a copy instead of returning in error // dimensions, or negative strides, we cannot operate, but we can
if (((CudaNdarray_HOST_DIMS(A)[0] > 1) && (CudaNdarray_HOST_STRIDES(A)[0] != 1)) && ((CudaNdarray_HOST_DIMS(A)[1] > 1) && (CudaNdarray_HOST_STRIDES(A)[1] != 1))) // make a copy.
{ PyErr_SetString(PyExc_NotImplementedError, "non-unit stride in gemv arg"); return -1; } if (((CudaNdarray_HOST_DIMS(A)[0] > 1)
&& (CudaNdarray_HOST_STRIDES(A)[0] != 1)
&& (CudaNdarray_HOST_DIMS(A)[1] > 1)
&& (CudaNdarray_HOST_STRIDES(A)[1] != 1))
|| (CudaNdarray_HOST_STRIDES(A)[0] < 0)
|| (CudaNdarray_HOST_STRIDES(A)[1] < 0))
{
const CudaNdarray* A_new = (CudaNdarray*) CudaNdarray_Copy(A);
if (!A_new)
return -1;
A = A_new;
}
else
{
// Incref A, so we can decref it at the end in all cases
Py_INCREF(A);
}
// If vector B has a negative stride, we also have to make a copy.
if (CudaNdarray_HOST_STRIDES(B)[0] < 0)
{
const CudaNdarray* B_new = (CudaNdarray*) CudaNdarray_Copy(B);
if (!B_new)
{
Py_XDECREF(A);
return -1;
}
B = B_new;
}
else
{
// Incref B, so we can decref it at the end in all cases
Py_INCREF(B);
}
// I don't know if cudablas handles negative strides // I don't know if cudablas handles negative strides
if ( (CudaNdarray_HOST_STRIDES(A)[0] < 0) if ( (CudaNdarray_HOST_STRIDES(A)[0] < 0)
|| (CudaNdarray_HOST_STRIDES(A)[1] < 0) || (CudaNdarray_HOST_STRIDES(A)[1] < 0))
|| (CudaNdarray_HOST_STRIDES(B)[0] < 0)
|| (CudaNdarray_HOST_STRIDES(C)[0] < 0))
{ {
PyErr_Format(PyExc_ValueError, "illegal strides in args to gemv (%i,%i)x(%i)->(%i)", PyErr_Format(PyExc_ValueError, "illegal strides in args to gemv (%i,%i)",
CudaNdarray_HOST_STRIDES(A)[0], CudaNdarray_HOST_STRIDES(A)[0],
CudaNdarray_HOST_STRIDES(A)[1], CudaNdarray_HOST_STRIDES(A)[1]);
CudaNdarray_HOST_STRIDES(B)[0], Py_XDECREF(A);
CudaNdarray_HOST_STRIDES(C)[0]); Py_XDECREF(B);
return -1; return -1;
} }
...@@ -3120,32 +3150,46 @@ int CudaNdarray_sgemv(float alpha, const CudaNdarray * A, const CudaNdarray * B, ...@@ -3120,32 +3150,46 @@ int CudaNdarray_sgemv(float alpha, const CudaNdarray * A, const CudaNdarray * B,
int sb_0 = (CudaNdarray_HOST_DIMS(B)[0] > 1) ? CudaNdarray_HOST_STRIDES(B)[0] : 1; int sb_0 = (CudaNdarray_HOST_DIMS(B)[0] > 1) ? CudaNdarray_HOST_STRIDES(B)[0] : 1;
int sc_0 = (CudaNdarray_HOST_DIMS(C)[0] > 1) ? CudaNdarray_HOST_STRIDES(C)[0] : 1; int sc_0 = (CudaNdarray_HOST_DIMS(C)[0] > 1) ? CudaNdarray_HOST_STRIDES(C)[0] : 1;
if (sa_0 == 1) if (CudaNdarray_SIZE(C)) {
{ if ((sa_0 == 1) || (sa_0 == 0))
cublasSgemv('N', {
CudaNdarray_HOST_DIMS(A)[0], CudaNdarray_HOST_DIMS(A)[1], cublasSgemv('N',
alpha, CudaNdarray_HOST_DIMS(A)[0], CudaNdarray_HOST_DIMS(A)[1],
CudaNdarray_DEV_DATA(A), sa_1, alpha,
CudaNdarray_DEV_DATA(B), sb_0, CudaNdarray_DEV_DATA(A), sa_1,
beta, CudaNdarray_DEV_DATA(B), sb_0,
CudaNdarray_DEV_DATA(C), sc_0); beta,
} CudaNdarray_DEV_DATA(C), sc_0);
else if (sa_1 == 1) }
{ else if ((sa_1 == 1) || (sa_1 == 0))
cublasSgemv('T', {
CudaNdarray_HOST_DIMS(A)[1], CudaNdarray_HOST_DIMS(A)[0], cublasSgemv('T',
alpha, CudaNdarray_HOST_DIMS(A)[1], CudaNdarray_HOST_DIMS(A)[0],
CudaNdarray_DEV_DATA(A), sa_0, alpha,
CudaNdarray_DEV_DATA(B), sb_0, CudaNdarray_DEV_DATA(A), sa_0,
beta, CudaNdarray_DEV_DATA(B), sb_0,
CudaNdarray_DEV_DATA(C), sc_0); beta,
} CudaNdarray_DEV_DATA(C), sc_0);
else }
{ else
PyErr_SetString(PyExc_NotImplementedError, "too many strides strides in sgemv"); {
return -1; PyErr_Format(PyExc_AssertionError,
"Unexpected stride pattern in gemv: (%i, %i) x %i -> %i",
sa_0, sa_1, sb_0, sc_0);
printf("shapes are: (%i, %i) x %i -> %i\n",
CudaNdarray_HOST_DIMS(A)[0],
CudaNdarray_HOST_DIMS(A)[1],
CudaNdarray_HOST_DIMS(B)[0],
CudaNdarray_HOST_DIMS(C)[0]);
Py_XDECREF(A);
Py_XDECREF(B);
return -1;
}
} }
CNDA_THREAD_SYNC; CNDA_THREAD_SYNC;
Py_XDECREF(A);
Py_XDECREF(B);
cudaError_t err = cudaGetLastError(); cudaError_t err = cudaGetLastError();
if (CUBLAS_STATUS_SUCCESS != err) if (CUBLAS_STATUS_SUCCESS != err)
{ {
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论