testgroup / pytensor · Commits

Commit 219f113a
authored June 22, 2011 by James Bergstra
misc changes as part of code review of CudaNdarray.__idiv__
Parent 94974850

Showing 1 changed file with 274 additions and 167 deletions.

theano/sandbox/cuda/cuda_ndarray.cu  +274 −167
...
...
@@ -932,95 +932,151 @@ __global__ void name(const int d0, const int d1, const int d2, const int d3,\
template<typename T> __device__ T binary_iadd(T a, T b) { return a+b; }
template<typename T> __device__ T binary_idiv(T a, T b) { return a/b; }
decl_k_elemwise_binary_inplace_rowmajor_3(k_iAdd_3, binary_iadd<float>)
decl_k_elemwise_binary_inplace_rowmajor_4(k_iAdd_4, binary_iadd<float>)
decl_k_elemwise_binary_inplace_rowmajor_3(k_iDiv_3, binary_idiv<float>)
decl_k_elemwise_binary_inplace_rowmajor_4(k_iDiv_4, binary_idiv<float>)
decl_k_elemwise_binary_inplace_rowmajor_3(k_iadd_3, binary_iadd<float>)
decl_k_elemwise_binary_inplace_rowmajor_4(k_iadd_4, binary_iadd<float>)
decl_k_elemwise_binary_inplace_rowmajor_3(k_idiv_3, binary_idiv<float>)
decl_k_elemwise_binary_inplace_rowmajor_4(k_idiv_4, binary_idiv<float>)
*/
__global__ void k_iAdd_3(const int d0, const int d1, const int d2,\
float* a, const int sA0, const int sA1, const int sA2,\
const float* b, const int sB0, const int sB1, const int sB2){\
for (int i0 = blockIdx.x; i0 < d0; i0 += gridDim.x){\
for (int i1 = blockIdx.y; i1 < d1; i1 += gridDim.y){\
for (int i2 = threadIdx.x; i2 < d2; i2 += blockDim.x){\
a[i0*sA0 + i1*sA1 + i2*sA2]+= b[i0*sB0 + i1*sB1 + i2*sB2]; \
}\
}\
}\
}
__global__ void k_iAdd_4(const int d0, const int d1, const int d2, const int d3,\
float* a, const int sA0, const int sA1,\
const int sA2, const int sA3,\
const float* b, const int sB0, const int sB1,\
const int sB2, const int sB3){\
for (int i0 = blockIdx.x; i0 < d0; i0 += gridDim.x){\
for (int i1 = blockIdx.y; i1 < d1; i1 += gridDim.y){\
for (int i2 = threadIdx.x; i2 < d2; i2 += blockDim.x){\
for (int i3 = threadIdx.y; i3 < d3; i3 += blockDim.y){\
a[i0*sA0 + i1*sA1 + i2*sA2 + i3*sA3] += b[i0*sB0 + i1*sB1 + i2*sB2 + i3*sB3]; \
}\
}\
}\
}\
}
enum operator_t
{
IADD=0,
IDIV,
CPY,
N_ELEMWISE_OPS
};
__global__ void k_iDiv_3(const int d0, const int d1, const int d2,\
float* a, const int sA0, const int sA1, const int sA2,\
const float* b, const int sB0, const int sB1, const int sB2){\
for (int i0 = blockIdx.x; i0 < d0; i0 += gridDim.x){\
for (int i1 = blockIdx.y; i1 < d1; i1 += gridDim.y){\
for (int i2 = threadIdx.x; i2 < d2; i2 += blockDim.x){\
a[i0*sA0 + i1*sA1 + i2*sA2]/= b[i0*sB0 + i1*sB1 + i2*sB2]; \
}\
}\
}\
}
template <int operator_num>
__global__ void k_ielem_3(const int d0, const int d1, const int d2,
float* a, const int sA0, const int sA1, const int sA2,
const float* b, const int sB0, const int sB1, const int sB2){
for (int i0 = blockIdx.x; i0 < d0; i0 += gridDim.x){
for (int i1 = blockIdx.y; i1 < d1; i1 += gridDim.y){
for (int i2 = threadIdx.x; i2 < d2; i2 += blockDim.x){
switch (operator_num)
{
case IADD:
a[i0*sA0 + i1*sA1 + i2*sA2] += b[i0*sB0 + i1*sB1 + i2*sB2];
break;
case IDIV:
a[i0*sA0 + i1*sA1 + i2*sA2] /= b[i0*sB0 + i1*sB1 + i2*sB2];
break;
case CPY:
a[i0*sA0 + i1*sA1 + i2*sA2] = b[i0*sB0 + i1*sB1 + i2*sB2];
break;
}
}
}
}
}
__global__ void k_iDiv_4(const int d0, const int d1, const int d2, const int d3,\
float* a, const int sA0, const int sA1,\
const int sA2, const int sA3,\
const float* b, const int sB0, const int sB1,\
const int sB2, const int sB3){\
for (int i0 = blockIdx.x; i0 < d0; i0 += gridDim.x){\
for (int i1 = blockIdx.y; i1 < d1; i1 += gridDim.y){\
for (int i2 = threadIdx.x; i2 < d2; i2 += blockDim.x){\
for (int i3 = threadIdx.y; i3 < d3; i3 += blockDim.y){\
a[i0*sA0 + i1*sA1 + i2*sA2 + i3*sA3] /= b[i0*sB0 + i1*sB1 + i2*sB2 + i3*sB3]; \
}\
}\
}\
}\
}
template <int operator_num>
__global__ void k_ielem_4(const int d0, const int d1, const int d2, const int d3,
float* a, const int sA0, const int sA1,
const int sA2, const int sA3,
const float* b, const int sB0, const int sB1,
const int sB2, const int sB3){
for (int i0 = blockIdx.x; i0 < d0; i0 += gridDim.x){
for (int i1 = blockIdx.y; i1 < d1; i1 += gridDim.y){
for (int i2 = threadIdx.x; i2 < d2; i2 += blockDim.x){
for (int i3 = threadIdx.y; i3 < d3; i3 += blockDim.y){
switch (operator_num) {
case IADD:
a[i0*sA0 + i1*sA1 + i2*sA2 + i3*sA3]
+= b[i0*sB0 + i1*sB1 + i2*sB2 + i3*sB3];
break;
case IDIV:
a[i0*sA0 + i1*sA1 + i2*sA2 + i3*sA3]
/= b[i0*sB0 + i1*sB1 + i2*sB2 + i3*sB3];
break;
case CPY:
a[i0*sA0 + i1*sA1 + i2*sA2 + i3*sA3]
= b[i0*sB0 + i1*sB1 + i2*sB2 + i3*sB3];
break;
}
}
}
}
}
}
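A note on the k_ielem_3 / k_ielem_4 kernels above: because operator_num is a template parameter rather than a runtime argument, the switch is resolved at compile time and each instantiation keeps only its own branch, so the dispatch costs nothing on the device. A minimal self-contained sketch of the same pattern (hypothetical names; editor's illustration, not part of the commit):

// demo_ielem.cu -- compile-time operator dispatch, as in k_ielem_3/4 above
enum demo_op_t { DEMO_IADD = 0, DEMO_IDIV };

template <int operator_num>
__global__ void demo_ielem_1(const int d0,
                             float* a, const int sA0,
                             const float* b, const int sB0)
{
    // grid-stride loop: the grid may be smaller than d0
    for (int i0 = blockIdx.x * blockDim.x + threadIdx.x;
         i0 < d0;
         i0 += gridDim.x * blockDim.x)
    {
        switch (operator_num) // constant per instantiation: branch compiles away
        {
            case DEMO_IADD: a[i0*sA0] += b[i0*sB0]; break;
            case DEMO_IDIV: a[i0*sA0] /= b[i0*sB0]; break;
        }
    }
}

// Host-side launchers: each operator gets its own fully specialized kernel.
void demo_launch_iadd(int n, float* a, const float* b)
{
    demo_ielem_1<DEMO_IADD><<<32, 128>>>(n, a, 1, b, 1);
}
void demo_launch_idiv(int n, float* a, const float* b)
{
    demo_ielem_1<DEMO_IDIV><<<32, 128>>>(n, a, 1, b, 1);
}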
static PyObject *
CudaNdarray_inplace_add_div(PyObject* py_self, PyObject * py_other, int fct_nb)
/*
CudaNdarray_inplace_elemwise
Compute A = A + B, A = A / B, or A = B elementwise, working in place on A.
py_self - the CudaNdarray that we'll modify (A)
py_other - the other argument (B)
fct_nb - which operation to perform (operator_t)
Returns 0 on success.
Returns -1 on failure, and sets a Python exception.
*/
int
CudaNdarray_inplace_elemwise(PyObject* py_self, PyObject * py_other, operator_t fct_nb)
{
int verbose = 0;
void (*k3)(const int, const int, const int,
float*, const int, const int, const int,
const float*, const int, const int, const int);
void (*k4)(const int, const int, const int, const int,
float*, const int, const int,
const int, const int,
const float*, const int, const int,
const int, const int);
switch (fct_nb)
{
case IADD:
k3 = k_ielem_3<IADD>;
k4 = k_ielem_4<IADD>;
break;
case IDIV:
k3 = k_ielem_3<IDIV>;
k4 = k_ielem_4<IDIV>;
break;
case CPY:
k3 = k_ielem_3<CPY>;
k4 = k_ielem_4<CPY>;
break;
default:
assert (0);
PyErr_Format(
PyExc_TypeError,
"CudaNdarray_inplace_elemwise invalid fct_nb (%i).",
(int)fct_nb);
return -1;
}
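The k3/k4 pointers above work because a __global__ template instantiation is an ordinary function symbol on the host, so its address can be stored in a plain function pointer and the kernel launched through the pointer later. A standalone sketch of that launch-through-pointer pattern (hypothetical names; editor's illustration, not code from the commit):

#include <cstdio>

template <int op>
__global__ void demo_kernel(float* x)
{
    // op == 0 adds, anything else divides; fixed per instantiation
    if (op == 0) x[0] += 1.0f;
    else         x[0] /= 2.0f;
}

int main()
{
    // store the address of one instantiation, as the code above does with k3/k4
    void (*k)(float*) = demo_kernel<0>;

    float* d = 0;
    cudaMalloc((void**)&d, sizeof(float));
    cudaMemset(d, 0, sizeof(float)); // all-zero bytes == 0.0f
    k<<<1, 1>>>(d);                  // launch through the function pointer
    cudaDeviceSynchronize();

    float h = -1.0f;
    cudaMemcpy(&h, d, sizeof(float), cudaMemcpyDeviceToHost);
    std::printf("result = %g\n", h); // expected: 1
    cudaFree(d);
    return 0;
}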
if (! CudaNdarray_Check(py_self)) {
PyErr_SetString(PyExc_TypeError, "CudaNdarray_inplace_add_div need a CudaNdarray on left");
return NULL;
PyErr_SetString(
PyExc_TypeError,
"CudaNdarray_inplace_elemwise need a CudaNdarray on left");
return -1;
}
if (! CudaNdarray_Check(py_other)) {
PyErr_SetString(PyExc_TypeError, "CudaNdarray_inplace_add_div need a CudaNdarray on right");
return NULL;
}
if (fct_nb<0 || fct_nb>1){
PyErr_SetString(PyExc_TypeError, "CudaNdarray_inplace_add_div fct_nb param supported are only 0 and 1.");
return NULL;
PyErr_SetString(
PyExc_TypeError,
"CudaNdarray_inplace_elemwise need a CudaNdarray on right");
return -1;
}
CudaNdarray * self = (CudaNdarray *)py_self;
CudaNdarray * other = (CudaNdarray *)py_other;
if (verbose) fprintf(stderr, "INPLACE ADD/DIV for self->nd=%d other->nd=%d\n",
self->nd, other->nd);
if (verbose)
{
fprintf(stderr,
"INPLACE ADD/DIV for self->nd=%d other->nd=%d\n",
self->nd, other->nd);
}
//standard elemwise size checks
if (self->nd != other->nd)
{
PyErr_Format(PyExc_TypeError, "CudaNdarray_inplace_add_div: need same number of dims. Got %d and %d", self->nd, other->nd);
return NULL;
PyErr_Format(
PyExc_TypeError,
"CudaNdarray_inplace_elemwise: need same number of dims. Got %d and %d",
self->nd, other->nd);
return -1;
}
//standard elemwise dim checks
unsigned int size = 1;
...
...
@@ -1029,30 +1085,27 @@ CudaNdarray_inplace_add_div(PyObject* py_self, PyObject * py_other, int fct_nb)
if ((CudaNdarray_HOST_DIMS(self)[i] != CudaNdarray_HOST_DIMS(other)[i])
&& (CudaNdarray_HOST_DIMS(other)[i] != 1))
{
PyErr_SetString(PyExc_TypeError, "need same dimensions (or broadcastable dimension)");
return NULL;
PyErr_SetString(
PyExc_ValueError,
"need same dimensions (or broadcastable dimension)");
return -1;
}
// if we're broadcasting other, then make sure it has stride 0
assert ((CudaNdarray_HOST_DIMS(self)[i] == CudaNdarray_HOST_DIMS(other)[i])
|| (CudaNdarray_HOST_STRIDES(other)[i] == 0));
size *= (unsigned int) CudaNdarray_HOST_DIMS(self)[i];
}
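The assertion above is what makes broadcasting work: a size-1 dimension of `other` is given stride 0, so the kernels' index arithmetic reads the same element of `b` for every index along that axis, with no data copied. A host-only sketch of the arithmetic (hypothetical shapes; editor's illustration):

#include <cstdio>

int main()
{
    // a: shape (2, 3), row-major; b: logical shape (1, 3), broadcast over rows
    float a[6] = {1, 2, 3, 4, 5, 6};
    const float b[3] = {10, 20, 30};
    const int sA0 = 3, sA1 = 1;
    const int sB0 = 0, sB1 = 1; // stride 0 on the broadcast dimension

    for (int i0 = 0; i0 < 2; ++i0)
        for (int i1 = 0; i1 < 3; ++i1)
            // i0*sB0 == 0 always: every row of a reads the same row of b
            a[i0*sA0 + i1*sA1] += b[i0*sB0 + i1*sB1];

    std::printf("%g %g %g  %g %g %g\n", a[0], a[1], a[2], a[3], a[4], a[5]);
    // prints: 11 22 33  14 25 36
    return 0;
}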
if(CudaNdarray_SIZE((CudaNdarray *)py_self)==0 && CudaNdarray_SIZE((CudaNdarray *)py_other)==0){
Py_INCREF(py_self);
return py_self;
}
void (*k_iop_3)(const int, const int, const int,
float*, const int, const int, const int,
const float*, const int, const int, const int);
void (*k_iop_4)(const int, const int, const int, const int,
float*, const int, const int,
const int, const int,
const float*, const int, const int,
const int, const int);
if(fct_nb == 0){
k_iop_3 = k_iAdd_3;
k_iop_4 = k_iAdd_4;
}else if(fct_nb == 1){
k_iop_3 = k_iDiv_3;
k_iop_4 = k_iDiv_4;
if (size==0)
{
if (CudaNdarray_SIZE((CudaNdarray *)py_other))
{
PyErr_SetString(
PyExc_ValueError,
"cannot work inplace on an un-initialized array");
return -1;
}
return 0;
}
switch(self->nd)
...
...
@@ -1061,63 +1114,77 @@ CudaNdarray_inplace_add_div(PyObject* py_self, PyObject * py_other, int fct_nb)
{
dim3 n_blocks(1, 1, 1);
dim3 n_threads(1);
k_iop_3<<<n_blocks, n_threads>>>(1,
1, //CudaNdarray_HOST_DIMS(self)[0],
1, //CudaNdarray_HOST_DIMS(self)[0],
k3<<<n_blocks, n_threads>>>(
1, //d0
1, //d1
1, //d2
CudaNdarray_DEV_DATA(self),
1, //strides
1,
1,
1, //CudaNdarray_HOST_STRIDES(self)[0],
CudaNdarray_HOST_STRIDES(self)[0],
CudaNdarray_DEV_DATA(other),
1, //strides
1,
1, //CudaNdarray_HOST_STRIDES(other)[0],
CudaNdarray_HOST_STRIDES(other)[0]);
1);
CNDA_THREAD_SYNC;
cudaError_t err = cudaGetLastError();
if (cudaSuccess != err)
{
PyErr_Format(PyExc_RuntimeError, "Cuda error: %s: %s.\n", "k_iop_3", cudaGetErrorString(err));
return NULL;
PyErr_Format(
PyExc_RuntimeError,
"Cuda error: %s: %s.\n",
"k3",
cudaGetErrorString(err));
return -1;
}
Py_INCREF(py_self);
return py_self;
}
break;
case 1:
{
dim3 n_blocks(1, 1, 1);
dim3 n_threads(
        std::min(CudaNdarray_HOST_DIMS(self)[0], NUM_VECTOR_OP_THREADS_PER_BLOCK)
        );
dim3 n_threads(
        std::min(
            CudaNdarray_HOST_DIMS(self)[0],
            NUM_VECTOR_OP_THREADS_PER_BLOCK));
k_iop_3<<<n_blocks, n_threads>>>(1,
        1, //CudaNdarray_HOST_DIMS(self)[0],
        CudaNdarray_HOST_DIMS(self)[0],
        CudaNdarray_DEV_DATA(self),
        1, //strides
        1,
        1, //CudaNdarray_HOST_STRIDES(self)[0],
        CudaNdarray_DEV_DATA(other),
        1,
        1,
        //CudaNdarray_HOST_STRIDES(other)[0],
        1);
k3<<<n_blocks, n_threads>>>(
        1, //dimensions
        1,
        CudaNdarray_HOST_DIMS(self)[0],
        CudaNdarray_DEV_DATA(self),
        1, //strides
        1,
        CudaNdarray_HOST_STRIDES(self)[0],
        CudaNdarray_DEV_DATA(other),
        1,
        //strides
        1,
        CudaNdarray_HOST_STRIDES(other)[0]);
CNDA_THREAD_SYNC;
cudaError_t err = cudaGetLastError();
if (cudaSuccess != err)
{
PyErr_Format(PyExc_RuntimeError, "Cuda error: %s: %s.\n", "k_iop_3", cudaGetErrorString(err));
return NULL;
PyErr_Format(
PyExc_RuntimeError,
"Cuda error: %s: %s.\n",
"k3",
cudaGetErrorString(err));
return -1;
}
Py_INCREF(py_self);
return py_self;
}
break;
case 2:
{
//TODO: if both self and other are f-contiguous
// Then flip the block and thread dimensions
// to make contiguous reads & writes
dim3 n_blocks(1,
        std::min(CudaNdarray_HOST_DIMS(self)[0], NUM_VECTOR_OP_BLOCKS)
        );
dim3 n_blocks(1,
        std::min(
            CudaNdarray_HOST_DIMS(self)[0],
            NUM_VECTOR_OP_BLOCKS));
dim3 n_threads(
        std::min(CudaNdarray_HOST_DIMS(self)[1], NUM_VECTOR_OP_THREADS_PER_BLOCK)
        );
dim3 n_threads(
        std::min(
            CudaNdarray_HOST_DIMS(self)[1],
            NUM_VECTOR_OP_THREADS_PER_BLOCK));
k_iop_3<<<n_blocks, n_threads>>>(1,
k3<<<n_blocks, n_threads>>>(1,
        CudaNdarray_HOST_DIMS(self)[0],
        CudaNdarray_HOST_DIMS(self)[1],
        CudaNdarray_DEV_DATA(self),
...
...
@@ -1130,25 +1197,33 @@ CudaNdarray_inplace_add_div(PyObject* py_self, PyObject * py_other, int fct_nb)
CudaNdarray_HOST_STRIDES(other)[1]);
CNDA_THREAD_SYNC;
cudaError_t err = cudaGetLastError();
if (cudaSuccess != err)
{
PyErr_Format(PyExc_RuntimeError, "Cuda error: %s: %s.\n", "k_iop_3", cudaGetErrorString(err));
return NULL;
PyErr_Format(
PyExc_RuntimeError,
"Cuda error: %s: %s.\n",
"k3",
cudaGetErrorString(err));
return -1;
}
Py_INCREF(py_self);
return py_self;
}
break;
case 3:
{
//TODO: Dimshuffle so that at least one of the arrays
// has a contiguous dimension on the thread idx.
dim3 n_blocks(
        std::min(CudaNdarray_HOST_DIMS(self)[0], NUM_VECTOR_OP_BLOCKS),
        CudaNdarray_HOST_DIMS(self)[1]
        );
dim3 n_blocks(
        std::min(
            CudaNdarray_HOST_DIMS(self)[0],
            NUM_VECTOR_OP_BLOCKS),
        CudaNdarray_HOST_DIMS(self)[1]);
while (n_blocks.x * n_blocks.y > NUM_VECTOR_OP_BLOCKS) n_blocks.y /= 2;
while (n_blocks.x * n_blocks.y > NUM_VECTOR_OP_BLOCKS)
    n_blocks.y /= 2;
dim3 n_threads(
        std::min(CudaNdarray_HOST_DIMS(self)[2], NUM_VECTOR_OP_THREADS_PER_BLOCK)
        );
dim3 n_threads(
        std::min(
            CudaNdarray_HOST_DIMS(self)[2],
            NUM_VECTOR_OP_THREADS_PER_BLOCK));
k_iop_3<<<n_blocks, n_threads>>>(
k3<<<n_blocks, n_threads>>>(
        CudaNdarray_HOST_DIMS(self)[0],
        CudaNdarray_HOST_DIMS(self)[1],
        CudaNdarray_HOST_DIMS(self)[2],
...
...
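An aside on the grid-sizing idiom used in cases 2 through 5: n_blocks.x is first clamped to NUM_VECTOR_OP_BLOCKS with std::min, then n_blocks.y is halved until the total block count fits under the same cap; the grid-stride loops inside the kernels pick up whatever iterations the smaller grid no longer covers. A host-side sketch (the cap value here is hypothetical; editor's illustration):

#include <algorithm>
#include <cstdio>

int main()
{
    const unsigned NUM_VECTOR_OP_BLOCKS = 4096; // hypothetical cap
    const unsigned d0 = 3000, d1 = 9;           // example array dims

    // x is pre-clamped to the cap (as with std::min above), so the
    // halving loop always terminates with y >= 1
    unsigned x = std::min(d0, NUM_VECTOR_OP_BLOCKS);
    unsigned y = d1;
    while (x * y > NUM_VECTOR_OP_BLOCKS)
        y /= 2;

    std::printf("launch grid: %u x %u blocks\n", x, y); // 3000 x 1
    return 0;
}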
@@ -1162,25 +1237,34 @@ CudaNdarray_inplace_add_div(PyObject* py_self, PyObject * py_other, int fct_nb)
CudaNdarray_HOST_STRIDES(other)[2]);
CNDA_THREAD_SYNC;
cudaError_t err = cudaGetLastError();
if (cudaSuccess != err)
{
PyErr_Format(PyExc_RuntimeError, "Cuda error: %s: %s.\n", "k_iop_3", cudaGetErrorString(err));
return NULL;
PyErr_Format(
PyExc_RuntimeError,
"Cuda error: %s: %s.\n",
"k3",
cudaGetErrorString(err));
return -1;
}
Py_INCREF(py_self);
return py_self;
}
break;
case 4:
{
dim3 n_blocks(
        std::min(CudaNdarray_HOST_DIMS(self)[0], NUM_VECTOR_OP_BLOCKS),
        CudaNdarray_HOST_DIMS(self)[1]
        );
dim3 n_blocks(
        std::min(
            CudaNdarray_HOST_DIMS(self)[0],
            NUM_VECTOR_OP_BLOCKS),
        CudaNdarray_HOST_DIMS(self)[1]);
while (n_blocks.x * n_blocks.y > NUM_VECTOR_OP_BLOCKS) n_blocks.y /= 2;
while (n_blocks.x * n_blocks.y > NUM_VECTOR_OP_BLOCKS)
    n_blocks.y /= 2;
dim3 n_threads(
        std::min(CudaNdarray_HOST_DIMS(self)[2], NUM_VECTOR_OP_THREADS_PER_BLOCK)
        );
dim3 n_threads(
        std::min(
            CudaNdarray_HOST_DIMS(self)[2],
            NUM_VECTOR_OP_THREADS_PER_BLOCK)
        //TODO: DON"T YOU NEED OT PUT DIMS[3] in here???
        );
k_iop_4<<<n_blocks, n_threads>>>(
k4<<<n_blocks, n_threads>>>(
CudaNdarray_HOST_DIMS(self)[0],
CudaNdarray_HOST_DIMS(self)[1],
CudaNdarray_HOST_DIMS(self)[2],
...
...
@@ -1197,27 +1281,35 @@ CudaNdarray_inplace_add_div(PyObject* py_self, PyObject * py_other, int fct_nb)
CudaNdarray_HOST_STRIDES(other)[3]);
CNDA_THREAD_SYNC;
cudaError_t err = cudaGetLastError();
if (cudaSuccess != err)
{
PyErr_Format(PyExc_RuntimeError, "Cuda error: %s: %s.\n", "k_iop_4", cudaGetErrorString(err));
return NULL;
PyErr_Format(
PyExc_RuntimeError,
"Cuda error: %s: %s.\n",
"k4",
cudaGetErrorString(err));
return -1;
}
Py_INCREF(py_self);
return py_self;
}
break;
case 5:
{
dim3 n_blocks(
        std::min(CudaNdarray_HOST_DIMS(self)[1], NUM_VECTOR_OP_BLOCKS),
        CudaNdarray_HOST_DIMS(self)[2]
        );
dim3 n_blocks(
        std::min(
            CudaNdarray_HOST_DIMS(self)[1],
            NUM_VECTOR_OP_BLOCKS),
        CudaNdarray_HOST_DIMS(self)[2]);
while (n_blocks.x * n_blocks.y > NUM_VECTOR_OP_BLOCKS) n_blocks.y /= 2;
while (n_blocks.x * n_blocks.y > NUM_VECTOR_OP_BLOCKS)
    n_blocks.y /= 2;
dim3 n_threads(
        std::min(CudaNdarray_HOST_DIMS(self)[3], NUM_VECTOR_OP_THREADS_PER_BLOCK)
        );
dim3 n_threads(
        std::min(
            CudaNdarray_HOST_DIMS(self)[3],
            NUM_VECTOR_OP_THREADS_PER_BLOCK)
        //TODO: DON"T YOU NEED OT PUT DIMS[3] in here???
        );
for (int i = 0; i < CudaNdarray_HOST_DIMS(self)[0]; ++i)
{
    k_iop_4<<<n_blocks, n_threads>>>(
    k4<<<n_blocks, n_threads>>>(
CudaNdarray_HOST_DIMS(self)[1],
CudaNdarray_HOST_DIMS(self)[2],
CudaNdarray_HOST_DIMS(self)[3],
...
...
@@ -1236,17 +1328,26 @@ CudaNdarray_inplace_add_div(PyObject* py_self, PyObject * py_other, int fct_nb)
cudaError_t err = cudaGetLastError();
if( cudaSuccess != err)
{
PyErr_Format(PyExc_RuntimeError, "Cuda error: %s: %s.\n", "k_iop_4", cudaGetErrorString(err));
return NULL;
PyErr_Format(
PyExc_RuntimeError,
"Cuda error: %s: %s.\n",
"k4",
cudaGetErrorString(err));
return -1;
}
}
Py_INCREF(py_self);
return py_self;
}
break;
default:
{
PyErr_Format(
PyExc_NotImplementedError,
"inplace_elemwise w nd=%i\n",
self->nd);
return -1;
}
}
PyErr_Format(PyExc_NotImplementedError, "inplace_add w nd=%i\n", self->nd);
return NULL;
return 0;
}
/*
...
...
@@ -1254,11 +1355,14 @@ CudaNdarray_inplace_add_div(PyObject* py_self, PyObject * py_other, int fct_nb)
*/
// Will be called by __iadd__ in Python
static PyObject *
CudaNdarray_inplace_add(PyObject* py_self, PyObject * py_other){
PyObject * rval = CudaNdarray_inplace_add_div(py_self, py_other, 0);
//We should not increment the refcount as we are doing an inplace operation,
//and in this syntax, there is no additional reference created!
return rval;
CudaNdarray_inplace_add(PyObject* py_self, PyObject * py_other)
{
if (CudaNdarray_inplace_elemwise(py_self, py_other, IADD))
{
return NULL;
}
Py_INCREF(py_self);
return py_self;
}
/*
...
...
@@ -1266,11 +1370,14 @@ CudaNdarray_inplace_add(PyObject* py_self, PyObject * py_other){
*/
// Will be called by __idiv__ in Python
static PyObject *
CudaNdarray_inplace_div(PyObject* py_self, PyObject * py_other){
PyObject * rval = CudaNdarray_inplace_add_div(py_self, py_other, 1);
//We should not increment the refcount as we are doing an inplace operation,
//and in this syntax, there is no additional reference created!
return rval;
CudaNdarray_inplace_div(PyObject* py_self, PyObject * py_other)
{
if (CudaNdarray_inplace_elemwise(py_self, py_other, IDIV))
{
return NULL;
}
Py_INCREF(py_self);
return py_self;
}
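For context on the Py_INCREF calls in the two wrappers: in the Python 2 C API these functions are installed in the in-place slots of the PyNumberMethods table below (its body is elided in this hunk), and a slot function must return a new reference, which is why each wrapper increfs py_self before handing it back. A hedged sketch of the registration (editor's illustration, not code from the commit):

/* Editor's sketch: how "a += b" and "a /= b" reach the wrappers above.
 * nb_inplace_divide exists only in the Python 2 PyNumberMethods struct. */
static void demo_install_inplace_slots(PyNumberMethods* m)
{
    m->nb_inplace_add = CudaNdarray_inplace_add;    /* a += b */
    m->nb_inplace_divide = CudaNdarray_inplace_div; /* a /= b */
}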
static PyNumberMethods CudaNdarrayNumberMethods =
...
...