GpuJoin.c_code works now for axis==0 and axis==1

上级 71a47454
...@@ -2942,17 +2942,17 @@ class GpuJoin(tensor.Join, GpuOp): ...@@ -2942,17 +2942,17 @@ class GpuJoin(tensor.Join, GpuOp):
out[0] = rval out[0] = rval
def c_code(self, node, name, inputs, out_, sub): def c_code(self, node, name, inputs, out_, sub):
if node.inputs[0].data not in [0, 1]:
raise NotImplementedError()
# only works for the first two axis
if len(inputs) != 3: if len(inputs) != 3:
# only works for two arrays # only works for two arrays
raise NotImplementedError() raise NotImplementedError()
if any([i.ndim != 2 for i in node.inputs[1:]]): if any([i.ndim != 2 for i in node.inputs[1:]]):
# only works for type T.matrix # only works for type T.matrix
raise NotImplementedError() raise NotImplementedError()
if node.inputs[0].data !=0:
# only works for axis==0
print inputs[0]
raise NotImplementedError()
axis = inputs[0]
input_1 = inputs[1] input_1 = inputs[1]
input_2 = inputs[2] input_2 = inputs[2]
axis = inputs[0] axis = inputs[0]
...@@ -2960,7 +2960,9 @@ class GpuJoin(tensor.Join, GpuOp): ...@@ -2960,7 +2960,9 @@ class GpuJoin(tensor.Join, GpuOp):
out = out_[0] out = out_[0]
str = """ str = """
int axis = PyInt_AsLong((PyObject*)%(axis)s);
int nd = CudaNdarray_NDIM(%(input_1)s); int nd = CudaNdarray_NDIM(%(input_1)s);
int dims_array1[nd]; int dims_array1[nd];
int errorcode; int errorcode;
for(int i = 0; i<nd; i+=1){ for(int i = 0; i<nd; i+=1){
...@@ -2974,55 +2976,145 @@ class GpuJoin(tensor.Join, GpuOp): ...@@ -2974,55 +2976,145 @@ class GpuJoin(tensor.Join, GpuOp):
} }
int dims_out[nd]; int dims_out[nd];
dims_out[0] = dims_array1[0] + dims_array2[0]; if(axis==0)
dims_out[1] = dims_array1[1]; {
dims_out[0] = dims_array1[0] + dims_array2[0];
dims_out[1] = dims_array1[1];
}
if(axis==1)
{
dims_out[0] = dims_array1[0];
dims_out[1] = dims_array1[1] + dims_array2[1];
}
if (CudaNdarray_prep_output(& %(out)s, 2, dims_out)) if (CudaNdarray_prep_output(& %(out)s, 2, dims_out))
{ {
%(fail)s; %(fail)s;
} }
PyObject *slice; PyObject *slice;
PyObject *out_sub; PyObject *out_sub;
PyObject *start, *end, *step; PyObject *start, *stop, *step;
step = NULL;
start = PyInt_FromLong(0);
end = PyInt_FromLong(dims_array1[0]); if(axis==0)
step = PyInt_FromLong(1); {
slice = PySlice_New(start, end, step); start = PyInt_FromLong(0);
out_sub = CudaNdarray_Subscript((PyObject*)%(out)s, slice); stop = PyInt_FromLong(dims_array1[0]);
errorcode = CudaNdarray_CopyFromCudaNdarray((CudaNdarray*)out_sub, %(input_1)s); slice = PySlice_New(start, stop, step);
if((slice == NULL) || (out_sub == NULL) || (errorcode != 0)){ out_sub = CudaNdarray_Subscript((PyObject*)%(out)s, slice);
errorcode = CudaNdarray_CopyFromCudaNdarray((CudaNdarray*)out_sub, %(input_1)s);
if((slice == NULL) || (out_sub == NULL) || (errorcode != 0))
{
Py_XDECREF(slice);
Py_XDECREF(out_sub);
Py_XDECREF(start);
Py_XDECREF(stop);
Py_XDECREF(step);
Py_XDECREF(%(out)s);
%(fail)s;
}
Py_XDECREF(start);
Py_XDECREF(slice); Py_XDECREF(slice);
Py_XDECREF(out_sub); Py_XDECREF(out_sub);
Py_XDECREF(start);
Py_XDECREF(end);
Py_XDECREF(step);
Py_XDECREF(%(out)s);
%(fail)s;
}
start = end; start = stop;
end = PyInt_FromLong(PyInt_AsLong(start) + dims_array2[0]); stop = PyInt_FromLong(PyInt_AsLong(start) + dims_array2[0]);
step = PyInt_FromLong(1); slice = PySlice_New(start, stop, step);
slice = PySlice_New(start, end, step); out_sub = CudaNdarray_Subscript((PyObject*)%(out)s, slice);
out_sub = CudaNdarray_Subscript((PyObject*)%(out)s, slice); errorcode = CudaNdarray_CopyFromCudaNdarray((CudaNdarray*)out_sub, %(input_2)s);
errorcode = CudaNdarray_CopyFromCudaNdarray((CudaNdarray*)out_sub, %(input_2)s); if((slice == NULL) || (out_sub == NULL) || (errorcode != 0))
if((slice == NULL) || (out_sub == NULL) || (errorcode != 0)){ {
Py_XDECREF(slice);
Py_XDECREF(out_sub);
Py_XDECREF(start);
Py_XDECREF(stop);
Py_XDECREF(step);
Py_XDECREF(%(out)s);
%(fail)s;
}
Py_XDECREF(slice); Py_XDECREF(slice);
Py_XDECREF(out_sub); Py_XDECREF(out_sub);
Py_XDECREF(start); Py_XDECREF(start);
Py_XDECREF(end); Py_XDECREF(stop);
Py_XDECREF(step); Py_XDECREF(step);
Py_XDECREF(%(out)s);
%(fail)s;
} }
Py_XDECREF(slice); if(axis==1)
Py_XDECREF(out_sub); {
Py_XDECREF(start); PyObject *slice_tuple;
Py_XDECREF(end); PyObject *full_slice;
Py_XDECREF(step); PyObject *section_slice;
PyObject *start_axis2, *stop_axis2;
start = PyInt_FromLong(0);
stop = PyInt_FromLong(dims_out[0]);
stop_axis2 = PyInt_FromLong(dims_array1[1]);
slice_tuple = PyTuple_New(2);
full_slice = PySlice_New(start, stop, step);
section_slice = PySlice_New(start, stop_axis2, step);
PyTuple_SetItem(slice_tuple, 0, full_slice);
PyTuple_SetItem(slice_tuple, 1, section_slice);
out_sub = CudaNdarray_Subscript((PyObject*)%(out)s, slice_tuple);
errorcode = CudaNdarray_CopyFromCudaNdarray((CudaNdarray*)out_sub, %(input_1)s);
if((full_slice == NULL) || (section_slice == NULL) || (out_sub == NULL) || (errorcode != 0))
{
Py_XDECREF(full_slice);
Py_XDECREF(section_slice);
Py_XDECREF(slice_tuple);
Py_XDECREF(out_sub);
Py_XDECREF(start);
Py_XDECREF(stop);
Py_XDECREF(step);
Py_XDECREF(start_axis2);
Py_XDECREF(stop_axis2);
Py_XDECREF(%(out)s);
%(fail)s;
}
Py_XDECREF(stop);
Py_XDECREF(full_slice);
Py_XDECREF(section_slice);
Py_XDECREF(out_sub);
Py_XDECREF(slice_tuple);
start_axis2 = stop_axis2;
stop = PyInt_FromLong(dims_out[0]);
stop_axis2 = PyInt_FromLong(dims_array2[1] + dims_array1[1]);
slice_tuple = PyTuple_New(2);
full_slice = PySlice_New(start, stop, step);
section_slice = PySlice_New(start_axis2, stop_axis2, step);
PyTuple_SetItem(slice_tuple, 0, full_slice);
PyTuple_SetItem(slice_tuple, 1, section_slice);
out_sub = CudaNdarray_Subscript((PyObject*)%(out)s, slice_tuple);
errorcode = CudaNdarray_CopyFromCudaNdarray((CudaNdarray*)out_sub, %(input_2)s);
if((full_slice == NULL) || (section_slice == NULL) || (out_sub == NULL) || (errorcode != 0))
{
Py_XDECREF(full_slice);
Py_XDECREF(section_slice);
Py_XDECREF(slice_tuple);
Py_XDECREF(out_sub);
Py_XDECREF(start);
Py_XDECREF(stop);
Py_XDECREF(step);
Py_XDECREF(start_axis2);
Py_XDECREF(stop_axis2);
Py_XDECREF(%(out)s);
%(fail)s;
}
Py_XDECREF(full_slice);
Py_XDECREF(section_slice);
Py_XDECREF(slice_tuple);
Py_XDECREF(out_sub);
Py_XDECREF(start);
Py_XDECREF(stop);
Py_XDECREF(step);
Py_XDECREF(start_axis2);
Py_XDECREF(stop_axis2);
}
"""% locals() """% locals()
return str return str
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论