提交 16cefd73 作者:Pascal Lamblin

In gpu_ger, copy matrix if not contiguous

Also allow f-contiguous mat in Gpu sger.
上级 a6e31bc4
...@@ -384,7 +384,7 @@ class GpuGer(Op): ...@@ -384,7 +384,7 @@ class GpuGer(Op):
return Apply(self, [z, a, x, y], [z.type()]) return Apply(self, [z, a, x, y], [z.type()])
def c_code_cache_version(self): def c_code_cache_version(self):
return (1,) return (2,)
def c_code(self, node, name, inputs, outputs, sub): def c_code(self, node, name, inputs, outputs, sub):
#z_out = alpha * dot(x,y) + beta * z_in #z_out = alpha * dot(x,y) + beta * z_in
...@@ -392,44 +392,57 @@ class GpuGer(Op): ...@@ -392,44 +392,57 @@ class GpuGer(Op):
#not inplace version, we copy z_in to z_out. #not inplace version, we copy z_in to z_out.
z_in, a, x, y = inputs z_in, a, x, y = inputs
z_out, = outputs z_out, = outputs
inplace = int(self.inplace)
fail = sub['fail'] fail = sub['fail']
sio = StringIO.StringIO() sio = StringIO.StringIO()
print >> sio, """ print >> sio, """
float %(name)s_alpha = ((dtype_%(a)s*)(%(a)s->data))[0]; float %(name)s_alpha = ((dtype_%(a)s*)(%(a)s->data))[0];
"""
if self.inplace: if (%(inplace)s
print >> sio, """ && (CudaNdarray_HOST_STRIDES(%(z_in)s)[0] >= 0)
&& (CudaNdarray_HOST_STRIDES(%(z_in)s)[1] >= 0)
&& ((CudaNdarray_HOST_DIMS(%(z_in)s)[0] <= 1)
|| (CudaNdarray_HOST_STRIDES(%(z_in)s)[0] == 1)
|| (CudaNdarray_HOST_DIMS(%(z_in)s)[1] <= 1)
|| (CudaNdarray_HOST_STRIDES(%(z_in)s)[1] == 1)))
{
// The input has an appropriate layout, we work inplace
Py_XDECREF(%(z_out)s); Py_XDECREF(%(z_out)s);
%(z_out)s = %(z_in)s; %(z_out)s = %(z_in)s;
Py_INCREF(%(z_out)s); Py_INCREF(%(z_out)s);
""" }
else: else if (%(z_out)s
print >> sio, """ && (%(z_out)s->nd == 2)
if (!%(z_out)s && (CudaNdarray_HOST_DIMS(%(z_out)s)[0]
|| (%(z_out)s->nd != 2) == CudaNdarray_HOST_DIMS(%(z_in)s)[0])
|| (CudaNdarray_HOST_DIMS(%(z_out)s)[0] != CudaNdarray_HOST_DIMS(%(z_in)s)[0]) && (CudaNdarray_HOST_DIMS(%(z_out)s)[1]
|| (CudaNdarray_HOST_DIMS(%(z_out)s)[1] != CudaNdarray_HOST_DIMS(%(z_in)s)[1]) == CudaNdarray_HOST_DIMS(%(z_in)s)[1])
) && (CudaNdarray_HOST_STRIDES(%(z_out)s)[0] >= 0)
&& (CudaNdarray_HOST_STRIDES(%(z_out)s)[1] >= 0)
&& ((CudaNdarray_HOST_DIMS(%(z_out)s)[0] <= 1)
|| (CudaNdarray_HOST_STRIDES(%(z_out)s)[0] == 1)
|| (CudaNdarray_HOST_DIMS(%(z_out)s)[1] <= 1)
|| (CudaNdarray_HOST_STRIDES(%(z_out)s)[1] == 1)))
{
// The existing output has an appropriate layout,
// copy the input data into it, then work inplace
if (CudaNdarray_CopyFromCudaNdarray(%(z_out)s, %(z_in)s))
{ {
Py_XDECREF(%(z_out)s); %(fail)s;
%(z_out)s = (CudaNdarray*)CudaNdarray_Copy(%(z_in)s);
if (!%(z_out)s)
{
%(fail)s;
}
} }
else }
else
{
// Copy the input, use the copy as output
Py_XDECREF(%(z_out)s);
%(z_out)s = (CudaNdarray*)CudaNdarray_Copy(%(z_in)s);
if (!%(z_out)s)
{ {
if (CudaNdarray_CopyFromCudaNdarray(%(z_out)s, %(z_in)s)) %(fail)s;
{
%(fail)s;
}
} }
""" }
print >> sio, """
if (CudaNdarray_sger(%(name)s_alpha, %(x)s, %(y)s, %(z_out)s)) if (CudaNdarray_sger(%(name)s_alpha, %(x)s, %(y)s, %(z_out)s))
{ {
%(fail)s; %(fail)s;
......
...@@ -3115,12 +3115,6 @@ int CudaNdarray_sger(float alpha, const CudaNdarray * x, const CudaNdarray * y, ...@@ -3115,12 +3115,6 @@ int CudaNdarray_sger(float alpha, const CudaNdarray * x, const CudaNdarray * y,
return -1; return -1;
} }
// Maybe this could work, but be safe for now
if (!CudaNdarray_is_c_contiguous(A)) {
PyErr_SetString(PyExc_NotImplementedError, "non-c continugous A in sger");
return -1;
}
// Since Sger expects A in col-major, we invert x and y to fake this.
int x_strides = CudaNdarray_HOST_STRIDES(x)[0]; int x_strides = CudaNdarray_HOST_STRIDES(x)[0];
const CudaNdarray * x_ = x; const CudaNdarray * x_ = x;
if(x_strides == 0){ if(x_strides == 0){
...@@ -3131,7 +3125,7 @@ int CudaNdarray_sger(float alpha, const CudaNdarray * x, const CudaNdarray * y, ...@@ -3131,7 +3125,7 @@ int CudaNdarray_sger(float alpha, const CudaNdarray * x, const CudaNdarray * y,
" that have more then 1 elements!"); " that have more then 1 elements!");
return -1; return -1;
} }
x_strides = 4; x_strides = 1;
} else if(x_strides < 0){ } else if(x_strides < 0){
x_ = (CudaNdarray*)CudaNdarray_Copy(x); x_ = (CudaNdarray*)CudaNdarray_Copy(x);
x_strides = CudaNdarray_HOST_STRIDES(x_)[0]; x_strides = CudaNdarray_HOST_STRIDES(x_)[0];
...@@ -3147,17 +3141,40 @@ int CudaNdarray_sger(float alpha, const CudaNdarray * x, const CudaNdarray * y, ...@@ -3147,17 +3141,40 @@ int CudaNdarray_sger(float alpha, const CudaNdarray * x, const CudaNdarray * y,
" that have more then 1 elements!"); " that have more then 1 elements!");
return -1; return -1;
} }
y_strides = 4; y_strides = 1;
} else if(y_strides < 0){ } else if(y_strides < 0){
y_ = (CudaNdarray*)CudaNdarray_Copy(y); y_ = (CudaNdarray*)CudaNdarray_Copy(y);
y_strides = CudaNdarray_HOST_STRIDES(y_)[0]; y_strides = CudaNdarray_HOST_STRIDES(y_)[0];
} }
if(CudaNdarray_SIZE(A)){ if(CudaNdarray_SIZE(A)){
cublasSger(CudaNdarray_HOST_DIMS(y)[0], CudaNdarray_HOST_DIMS(x)[0], alpha, // If A is in col-major
CudaNdarray_DEV_DATA(y_), y_strides, if ((CudaNdarray_HOST_DIMS(A)[0] <= 1)
CudaNdarray_DEV_DATA(x_), x_strides, || ((CudaNdarray_HOST_STRIDES(A)[0] == 1)
CudaNdarray_DEV_DATA(A), CudaNdarray_HOST_DIMS(A)[1]); && (CudaNdarray_HOST_STRIDES(A)[1] > 0)))
{
cublasSger(CudaNdarray_HOST_DIMS(x)[0], CudaNdarray_HOST_DIMS(y)[0], alpha,
CudaNdarray_DEV_DATA(x_), x_strides,
CudaNdarray_DEV_DATA(y_), y_strides,
CudaNdarray_DEV_DATA(A), CudaNdarray_HOST_STRIDES(A)[1]);
}
// Since Sger expects A in col-major, we invert x and y to fake this.
else if ((CudaNdarray_HOST_DIMS(A)[1] <= 1)
|| ((CudaNdarray_HOST_STRIDES(A)[1] == 1)
&& (CudaNdarray_HOST_STRIDES(A)[0] > 0)))
{
cublasSger(CudaNdarray_HOST_DIMS(y)[0], CudaNdarray_HOST_DIMS(x)[0], alpha,
CudaNdarray_DEV_DATA(y_), y_strides,
CudaNdarray_DEV_DATA(x_), x_strides,
CudaNdarray_DEV_DATA(A), CudaNdarray_HOST_STRIDES(A)[0]);
}
// A has to be either c- or f-contiguous, with no negative strides
else
{
PyErr_SetString(PyExc_NotImplementedError,
"non-contiguous A, or negative strides, in sger");
return -1;
}
} }
CNDA_THREAD_SYNC; CNDA_THREAD_SYNC;
if(x_ != x) if(x_ != x)
......
Markdown 格式
0%
您即将添加 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
请 注册 或者 登录 后发表评论