Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
462fb359
提交
462fb359
authored
10月 01, 2009
作者:
James Bergstra
浏览文件
操作
浏览文件
下载
差异文件
merge + several things that can't be committed partially.
Mainly: - improved GpuSum.__str__ - fixed problem in make_node of IncSubtensor and Subtensor - added PyErr_Format()s to failure-handling code of GpuSum
上级
9e6c2c03
3e5f365f
隐藏空白字符变更
内嵌
并排
正在显示
3 个修改的文件
包含
499 行增加
和
94 行删除
+499
-94
basic_ops.py
basic_ops.py
+463
-93
elemwise.py
elemwise.py
+1
-1
opt.py
opt.py
+35
-0
没有找到文件。
basic_ops.py
浏览文件 @
462fb359
...
@@ -117,7 +117,8 @@ class GpuElemwise(Op):
...
@@ -117,7 +117,8 @@ class GpuElemwise(Op):
items
=
self
.
inplace_pattern
.
items
()
items
=
self
.
inplace_pattern
.
items
()
items
.
sort
()
items
.
sort
()
return
"GpuElemwise{
%
s}
%
s"
%
(
self
.
scalar_op
.
__class__
.
__name__
,
str
(
items
))
return
"GpuElemwise{
%
s}
%
s"
%
(
self
.
scalar_op
.
__class__
.
__name__
,
str
(
items
))
return
"GpuElemwise{
%
s}"
%
(
self
.
scalar_op
.
__class__
.
__name__
)
#return "GpuElemwise{%s}" % (self.scalar_op.__class__.__name__)
return
"GpuElemwise{
%
s}"
%
(
self
.
scalar_op
)
def
make_node
(
self
,
*
inputs
):
def
make_node
(
self
,
*
inputs
):
_inputs
=
[
as_cuda_ndarray_variable
(
i
)
for
i
in
inputs
]
_inputs
=
[
as_cuda_ndarray_variable
(
i
)
for
i
in
inputs
]
...
@@ -367,7 +368,7 @@ class GpuSum(Op):
...
@@ -367,7 +368,7 @@ class GpuSum(Op):
return
hash
(
type
(
self
))
^
hash
(
self
.
reduce_mask
)
return
hash
(
type
(
self
))
^
hash
(
self
.
reduce_mask
)
def
__str__
(
self
):
def
__str__
(
self
):
return
"GpuSum{
%
s}"
%
str
(
self
.
reduce_mask
)
return
"GpuSum{
%
s}"
%
','
.
join
(
str
(
i
)
for
i
in
self
.
reduce_mask
)
def
make_node
(
self
,
x
):
def
make_node
(
self
,
x
):
if
(
x
.
type
.
ndim
!=
len
(
self
.
reduce_mask
)):
if
(
x
.
type
.
ndim
!=
len
(
self
.
reduce_mask
)):
...
@@ -431,6 +432,7 @@ class GpuSum(Op):
...
@@ -431,6 +432,7 @@ class GpuSum(Op):
cnda_
%(z)
s = (CudaNdarray*) CudaNdarray_NewDims(
%(nd_out)
s, new_dims);
cnda_
%(z)
s = (CudaNdarray*) CudaNdarray_NewDims(
%(nd_out)
s, new_dims);
if (NULL == cnda_
%(z)
s)
if (NULL == cnda_
%(z)
s)
{
{
PyErr_Format(PyExc_RuntimeError, "Failed to allocate output");
%(fail)
s;
%(fail)
s;
}
}
}
}
...
@@ -439,20 +441,172 @@ class GpuSum(Op):
...
@@ -439,20 +441,172 @@ class GpuSum(Op):
#
#
# Now perform the reduction
# Now perform the reduction
#
#
if
self
.
reduce_mask
==
(
1
,):
getattr
(
self
,
'c_code_reduce_
%
s'
%
(
''
.
join
(
str
(
i
)
for
i
in
self
.
reduce_mask
)))(
sio
,
node
,
name
,
x
,
z
,
fail
)
self
.
c_code_reduce_1
(
sio
,
node
,
name
,
x
,
z
,
fail
)
elif
self
.
reduce_mask
==
(
1
,
1
):
return
sio
.
getvalue
()
self
.
c_code_reduce_11
(
sio
,
node
,
name
,
x
,
z
,
fail
)
elif
self
.
reduce_mask
==
(
1
,
0
):
def
_makecall
(
self
,
node
,
name
,
x
,
z
,
fail
):
self
.
c_code_reduce_10
(
sio
,
node
,
name
,
x
,
z
,
fail
)
"""Return a string for making a kernel call.
elif
self
.
reduce_mask
==
(
1
,
0
,
1
,
1
):
self
.
c_code_reduce_1011
(
sio
,
node
,
name
,
x
,
z
,
fail
)
The return value looks something like:
else
:
print
'UNWRITTEN REDUCE MASK'
,
self
.
reduce_mask
.. code-block:: c
assert
0
if (verbose) printf("running kernel_reduce_sum_10_
%(name)
s
\\
n");
int n_shared = sizeof(float) * n_threads.x;
kernel_reduce_sum_10_
%(name)
s<<<n_blocks, n_threads, n_shared>>>(
CudaNdarray_HOST_DIMS(cnda_
%(x)
s)[0],
CudaNdarray_HOST_DIMS(cnda_
%(x)
s)[1],
CudaNdarray_DEV_DATA(cnda_
%(x)
s),
CudaNdarray_HOST_STRIDES(cnda_
%(x)
s)[0],
CudaNdarray_HOST_STRIDES(cnda_
%(x)
s)[1],
CudaNdarray_DEV_DATA(cnda_
%(z)
s),
CudaNdarray_HOST_STRIDES(cnda_
%(z)
s)[0]
);
CNDA_THREAD_SYNC;
if (cudaSuccess != cudaGetLastError())
{
PyErr_Format(PyExc_RuntimeError, "Cuda error: ... );
%(fail)
s;
}
"""
sio
=
StringIO
.
StringIO
()
pattern
=
''
.
join
(
str
(
c
)
for
c
in
self
.
reduce_mask
)
ndim
=
len
(
pattern
)
nd_out
=
ndim
-
sum
(
self
.
reduce_mask
)
print
>>
sio
,
"""
if (verbose) printf("running kernel_reduce_sum_
%(pattern)
s_
%(name)
s
\\
n");
int n_shared = sizeof(float) * n_threads.x * n_threads.y * n_threads.z;
kernel_reduce_sum_
%(pattern)
s_
%(name)
s<<<n_blocks, n_threads, n_shared>>>(
"""
%
locals
()
for
i
in
xrange
(
ndim
):
print
>>
sio
,
"""
CudaNdarray_HOST_DIMS(cnda_
%(x)
s)[
%(i)
s],
"""
%
locals
()
print
>>
sio
,
"""
CudaNdarray_DEV_DATA(cnda_
%(x)
s)
"""
%
locals
()
for
i
in
xrange
(
ndim
):
print
>>
sio
,
"""
,CudaNdarray_HOST_STRIDES(cnda_
%(x)
s)[
%(i)
s]
"""
%
locals
()
print
>>
sio
,
"""
,CudaNdarray_DEV_DATA(cnda_
%(z)
s)
"""
%
locals
()
for
i
in
xrange
(
nd_out
):
print
>>
sio
,
"""
,CudaNdarray_HOST_STRIDES(cnda_
%(z)
s)[
%(i)
s]
"""
%
locals
()
print
>>
sio
,
"""
);
CNDA_THREAD_SYNC;
cudaError_t sts = cudaGetLastError();
if (cudaSuccess != sts)
{
PyErr_Format(PyExc_RuntimeError, "Cuda error:
%%
s:
%%
s. (grid:
%%
i x
%%
i; block:
%%
i x
%%
i x
%%
i)
\\
n",
"kernel_reduce_sum_
%(pattern)
s_
%(name)
s",
cudaGetErrorString(sts),
n_blocks.x,
n_blocks.y,
n_threads.x,
n_threads.y,
n_threads.z);
%(fail)
s;
}
"""
%
locals
()
return
sio
.
getvalue
()
def
_k_decl
(
self
,
node
,
nodename
):
"""Return a string to declare a kernel function
.. code-block:: c
static __global__ void kernel_reduce_sum_110_
%(nodename)
s(
const int d0,
const int d1,
const int d2,
const float *A,
const int sA0,
const int sA1,
const int sA2,
float * Z,
const int sZ0)
"""
%
locals
()
pattern
=
''
.
join
(
str
(
i
)
for
i
in
self
.
reduce_mask
)
sio
=
StringIO
.
StringIO
()
print
>>
sio
,
"""
static __global__ void kernel_reduce_sum_
%(pattern)
s_
%(nodename)
s(
"""
%
locals
()
for
i
in
xrange
(
len
(
self
.
reduce_mask
)):
print
>>
sio
,
"""
const int d
%(i)
s,
"""
%
locals
()
print
>>
sio
,
"""
const float *A,
"""
%
locals
()
for
i
in
xrange
(
len
(
self
.
reduce_mask
)):
print
>>
sio
,
"""
const int sA
%(i)
s,
"""
%
locals
()
print
>>
sio
,
"""
float * Z
"""
%
locals
()
for
i
in
xrange
(
len
(
self
.
reduce_mask
)
-
sum
(
self
.
reduce_mask
)):
print
>>
sio
,
"""
, const int sZ
%(i)
s
"""
%
locals
()
print
>>
sio
,
")"
return
sio
.
getvalue
()
return
sio
.
getvalue
()
def
_k_init
(
self
,
*
args
):
return
"""
const int threadCount = blockDim.x * blockDim.y * blockDim.y;
const int threadNum = threadIdx.z * blockDim.x * blockDim.y + threadIdx.y * blockDim.x + threadIdx.x;
extern __shared__ float buf[];
float mysum = 0.0f;
if (warpSize != 32)
{
//TODO: set error code
Z[0] = -666;
return;
}
"""
def
_k_reduce_buf
(
self
,
z_pos
):
return
"""
buf[threadNum] = mysum;
__syncthreads();
// rest of function is handled by one warp
if (threadNum < warpSize)
{
//round up all the partial sums into the first `warpSize` elements
for (int i = threadNum + warpSize; i < threadCount; i += warpSize)
{
mysum += buf[i];
}
buf[threadNum] = mysum;
// no sync because only one warp is running
if (threadNum < 16)
{
//reduce so that threadNum 0 has the sum of everything
if(threadNum + 16 < threadCount) buf[threadNum] += buf[threadNum+16];
if(threadNum + 8 < threadCount) buf[threadNum] += buf[threadNum+8];
if(threadNum + 4 < threadCount) buf[threadNum] += buf[threadNum+4];
if(threadNum + 2 < threadCount) buf[threadNum] += buf[threadNum+2];
if(threadNum + 1 < threadCount) buf[threadNum] += buf[threadNum+1];
if (threadNum == 0)
{
%(z_pos)
s = buf[0];
}
}
}
"""
%
locals
()
def
c_code_reduce_1
(
self
,
sio
,
node
,
name
,
x
,
z
,
fail
):
def
c_code_reduce_1
(
self
,
sio
,
node
,
name
,
x
,
z
,
fail
):
print
>>
sio
,
"""
print
>>
sio
,
"""
{
{
...
@@ -469,8 +623,17 @@ class GpuSum(Op):
...
@@ -469,8 +623,17 @@ class GpuSum(Op):
CudaNdarray_HOST_STRIDES(cnda_
%(x)
s)[0],
CudaNdarray_HOST_STRIDES(cnda_
%(x)
s)[0],
CudaNdarray_DEV_DATA(cnda_
%(z)
s));
CudaNdarray_DEV_DATA(cnda_
%(z)
s));
CNDA_THREAD_SYNC;
CNDA_THREAD_SYNC;
if (cudaSuccess != cudaGetLastError())
cudaError_t sts = cudaGetLastError();
if (cudaSuccess != sts)
{
{
PyErr_Format(PyExc_RuntimeError, "Cuda error:
%%
s:
%%
s. (grid:
%%
i x
%%
i; block:
%%
i x
%%
i x
%%
i)
\\
n",
"kernel_reduce_sum_1_
%(name)
s",
cudaGetErrorString(sts),
n_blocks.x,
n_blocks.y,
n_threads.x,
n_threads.y,
n_threads.z);
%(fail)
s;
%(fail)
s;
}
}
}
}
...
@@ -483,13 +646,14 @@ class GpuSum(Op):
...
@@ -483,13 +646,14 @@ class GpuSum(Op):
dim3 n_threads(
dim3 n_threads(
std::min(CudaNdarray_HOST_DIMS(cnda_
%(x)
s)[1],
std::min(CudaNdarray_HOST_DIMS(cnda_
%(x)
s)[1],
NUM_VECTOR_OP_THREADS_PER_BLOCK));
NUM_VECTOR_OP_THREADS_PER_BLOCK));
while (n_threads.y * n_threads.x < NUM_VECTOR_OP_THREADS_PER_BLOCK) ++n_threads.y;
while (n_threads.y * n_threads.x <
=
NUM_VECTOR_OP_THREADS_PER_BLOCK) ++n_threads.y;
n_threads.y -= 1;
n_threads.y -= 1;
if (n_threads.y > CudaNdarray_HOST_DIMS(cnda_
%(x)
s)[0])
if (n_threads.y > CudaNdarray_HOST_DIMS(cnda_
%(x)
s)[0])
n_threads.y = CudaNdarray_HOST_DIMS(cnda_
%(x)
s)[0];
n_threads.y = CudaNdarray_HOST_DIMS(cnda_
%(x)
s)[0];
dim3 n_blocks(1);
dim3 n_blocks(1);
if (verbose) printf("running kernel_reduce_sum_11_
%(name)
s
\\
n");
if (verbose) fprintf(stdout, "running kernel_reduce_sum_11_
%(name)
s
\\
n");
if (verbose) fprint_CudaNdarray(stdout, cnda_
%(x)
s);
int n_shared = sizeof(float) * n_threads.x * n_threads.y * n_threads.z;
int n_shared = sizeof(float) * n_threads.x * n_threads.y * n_threads.z;
kernel_reduce_sum_11_
%(name)
s<<<n_blocks, n_threads, n_shared>>>(
kernel_reduce_sum_11_
%(name)
s<<<n_blocks, n_threads, n_shared>>>(
CudaNdarray_HOST_DIMS(cnda_
%(x)
s)[0],
CudaNdarray_HOST_DIMS(cnda_
%(x)
s)[0],
...
@@ -499,8 +663,17 @@ class GpuSum(Op):
...
@@ -499,8 +663,17 @@ class GpuSum(Op):
CudaNdarray_HOST_STRIDES(cnda_
%(x)
s)[1],
CudaNdarray_HOST_STRIDES(cnda_
%(x)
s)[1],
CudaNdarray_DEV_DATA(cnda_
%(z)
s));
CudaNdarray_DEV_DATA(cnda_
%(z)
s));
CNDA_THREAD_SYNC;
CNDA_THREAD_SYNC;
if (cudaSuccess != cudaGetLastError())
cudaError_t sts = cudaGetLastError();
if (cudaSuccess != sts)
{
{
PyErr_Format(PyExc_RuntimeError, "Cuda error:
%%
s:
%%
s. (grid:
%%
i x
%%
i; block:
%%
i x
%%
i x
%%
i)
\\
n",
"kernel_reduce_sum_11_
%(name)
s",
cudaGetErrorString(sts),
n_blocks.x,
n_blocks.y,
n_threads.x,
n_threads.y,
n_threads.z);
%(fail)
s;
%(fail)
s;
}
}
}
}
...
@@ -527,13 +700,120 @@ class GpuSum(Op):
...
@@ -527,13 +700,120 @@ class GpuSum(Op):
CudaNdarray_HOST_STRIDES(cnda_
%(z)
s)[0]
CudaNdarray_HOST_STRIDES(cnda_
%(z)
s)[0]
);
);
CNDA_THREAD_SYNC;
CNDA_THREAD_SYNC;
if (cudaSuccess != cudaGetLastError())
cudaError_t sts = cudaGetLastError();
if (cudaSuccess != sts)
{
{
PyErr_Format(PyExc_RuntimeError, "Cuda error:
%%
s:
%%
s. (grid:
%%
i x
%%
i; block:
%%
i x
%%
i x
%%
i)
\\
n",
"kernel_reduce_sum_10_
%(name)
s",
cudaGetErrorString(sts),
n_blocks.x,
n_blocks.y,
n_threads.x,
n_threads.y,
n_threads.z);
%(fail)
s;
%(fail)
s;
}
}
}
}
"""
%
locals
()
"""
%
locals
()
def
c_code_reduce_100
(
self
,
sio
,
node
,
name
,
x
,
z
,
fail
):
makecall
=
self
.
_makecall
(
node
,
name
,
x
,
z
,
fail
)
# use threadIdx.x for i0
# use blockIdx.x for i1
# use blockIdx.y for i2
print
>>
sio
,
"""
{
int verbose = 0;
dim3 n_threads(
std::min(CudaNdarray_HOST_DIMS(cnda_
%(x)
s)[0],
NUM_VECTOR_OP_THREADS_PER_BLOCK));
dim3 n_blocks(CudaNdarray_HOST_DIMS(cnda_
%(x)
s)[1]);
while (n_blocks.x * n_blocks.y <= NUM_VECTOR_OP_BLOCKS)
{
if (n_blocks.y > CudaNdarray_HOST_DIMS(cnda_
%(x)
s)[2])
break;
n_blocks.y += 1;
}
n_blocks.y -= 1;
%(makecall)
s
}
"""
%
locals
()
def
c_code_reduce_110
(
self
,
sio
,
node
,
name
,
x
,
z
,
fail
):
makecall
=
self
.
_makecall
(
node
,
name
,
x
,
z
,
fail
)
print
>>
sio
,
"""
{
int verbose = 0;
dim3 n_threads(
std::min(CudaNdarray_HOST_DIMS(cnda_
%(x)
s)[1],
NUM_VECTOR_OP_THREADS_PER_BLOCK));
while (n_threads.x*n_threads.y <= NUM_VECTOR_OP_THREADS_PER_BLOCK)
{
if (n_threads.y > CudaNdarray_HOST_DIMS(cnda_
%(x)
s)[0])
break;
n_threads.y += 1;
}
n_threads.y -= 1;
dim3 n_blocks(CudaNdarray_HOST_DIMS(cnda_
%(x)
s)[2]);
%(makecall)
s
}
"""
%
locals
()
def
c_code_reduce_001
(
self
,
sio
,
node
,
name
,
x
,
z
,
fail
):
makecall
=
self
.
_makecall
(
node
,
name
,
x
,
z
,
fail
)
print
>>
sio
,
"""
{
int verbose = 0;
dim3 n_threads(
std::min(CudaNdarray_HOST_DIMS(cnda_
%(x)
s)[2],
NUM_VECTOR_OP_THREADS_PER_BLOCK));
dim3 n_blocks(
std::min(CudaNdarray_HOST_DIMS(cnda_
%(x)
s)[0],
NUM_VECTOR_OP_BLOCKS));
while (n_blocks.x * n_blocks.y <= NUM_VECTOR_OP_BLOCKS)
{
if (n_blocks.y > CudaNdarray_HOST_DIMS(cnda_
%(x)
s)[1])
break;
n_blocks.y += 1;
}
n_blocks.y -= 1;
%(makecall)
s
}
"""
%
locals
()
def
c_code_reduce_111
(
self
,
sio
,
node
,
name
,
x
,
z
,
fail
):
makecall
=
self
.
_makecall
(
node
,
name
,
x
,
z
,
fail
)
print
>>
sio
,
"""
{
int verbose = 0;
dim3 n_threads(
std::min(CudaNdarray_HOST_DIMS(cnda_
%(x)
s)[2],
NUM_VECTOR_OP_THREADS_PER_BLOCK));
//get as many y threads as we can fit
while (n_threads.x * n_threads.y <= NUM_VECTOR_OP_THREADS_PER_BLOCK)
{
if (n_threads.y > CudaNdarray_HOST_DIMS(cnda_
%(x)
s)[1])
break;
n_threads.y += 1;
}
n_threads.y -= 1;
//get as many z threads as we can fit
while (n_threads.x * n_threads.y * n_threads.z <= NUM_VECTOR_OP_THREADS_PER_BLOCK)
{
if (n_threads.z > CudaNdarray_HOST_DIMS(cnda_
%(x)
s)[0])
break;
n_threads.z += 1;
}
n_threads.z -= 1;
dim3 n_blocks(1,1,1);
%(makecall)
s
}
"""
%
locals
()
def
c_code_reduce_1011
(
self
,
sio
,
node
,
name
,
x
,
z
,
fail
):
def
c_code_reduce_1011
(
self
,
sio
,
node
,
name
,
x
,
z
,
fail
):
print
>>
sio
,
"""
print
>>
sio
,
"""
{
{
...
@@ -573,15 +853,26 @@ class GpuSum(Op):
...
@@ -573,15 +853,26 @@ class GpuSum(Op):
CudaNdarray_DEV_DATA(cnda_
%(z)
s),
CudaNdarray_DEV_DATA(cnda_
%(z)
s),
CudaNdarray_HOST_STRIDES(cnda_
%(z)
s)[0]);
CudaNdarray_HOST_STRIDES(cnda_
%(z)
s)[0]);
CNDA_THREAD_SYNC;
CNDA_THREAD_SYNC;
if (cudaSuccess != cudaGetLastError())
cudaError_t sts = cudaGetLastError();
if (cudaSuccess != sts)
{
{
PyErr_Format(PyExc_RuntimeError, "Cuda error:
%%
s:
%%
s. (grid:
%%
i x
%%
i; block:
%%
i x
%%
i x
%%
i)
\\
n",
"kernel_reduce_sum_1011_
%(name)
s",
cudaGetErrorString(sts),
n_blocks.x,
n_blocks.y,
n_threads.x,
n_threads.y,
n_threads.z);
%(fail)
s;
%(fail)
s;
}
}
}
}
"""
%
locals
()
"""
%
locals
()
def
c_code_cache_version
(
self
):
def
c_code_cache_version
(
self
):
return
()
#return ()
return
(
7
,)
def
c_support_code_apply
(
self
,
node
,
nodename
):
def
c_support_code_apply
(
self
,
node
,
nodename
):
sio
=
StringIO
.
StringIO
()
sio
=
StringIO
.
StringIO
()
...
@@ -747,6 +1038,148 @@ class GpuSum(Op):
...
@@ -747,6 +1038,148 @@ class GpuSum(Op):
}
}
}
}
"""
%
locals
()
"""
%
locals
()
if
self
.
reduce_mask
==
(
1
,
1
,
0
):
# this kernel uses one block for each column,
# threads per block for each element per column.
#TODO: This kernel is pretty inefficient in terms of reading, because if A is
# c_contiguous (typical case) then each warp is accessing non-contigous
# memory (a segment of a column).
reducebuf
=
self
.
_k_reduce_buf
(
'Z[blockIdx.x * sZ0]'
)
print
>>
sio
,
"""
static __global__ void kernel_reduce_sum_110_
%(nodename)
s(
const int d0,
const int d1,
const int d2,
const float *A, const int sA0, const int sA1, const int sA2,
float * Z, const int sZ0)
{
const int threadCount = blockDim.x * blockDim.y;
const int threadNum = threadIdx.y * blockDim.x + threadIdx.x;
extern __shared__ float buf[];
float mysum = 0.0f;
if (warpSize != 32)
{
//TODO: set error code
Z[blockIdx.x * sZ0] = -666;
return;
}
for (int i0 = threadIdx.y; i0 < d0; i0 += blockDim.y)
{
for (int i1 = threadIdx.x; i1 < d1; i1 += blockDim.x)
{
float Ai = A[i0 * sA0 + i1 * sA1 + blockIdx.x * sA2];
mysum += Ai;
}
}
%(reducebuf)
s
}
"""
%
locals
()
if
self
.
reduce_mask
==
(
1
,
0
,
0
):
reducebuf
=
self
.
_k_reduce_buf
(
'Z[i1 * sZ0 + i2 * sZ1]'
)
decl
=
self
.
_k_decl
(
node
,
nodename
)
init
=
self
.
_k_init
(
node
,
nodename
)
print
>>
sio
,
"""
%(decl)
s
{
%(init)
s
for (int i2 = blockIdx.y; i2 < d2; i2 += gridDim.y)
{
for (int i1 = blockIdx.x; i1 < d1; i1 += gridDim.x)
{
mysum = 0;
for (int i0 = threadIdx.x; i0 < d0; i0 += blockDim.x)
{
mysum += A[i0 * sA0 + i1 * sA1 + i2 * sA2];
}
%(reducebuf)
s
}
}
}
"""
%
locals
()
if
self
.
reduce_mask
==
(
1
,
1
,
1
):
reducebuf
=
self
.
_k_reduce_buf
(
'Z[0]'
)
decl
=
self
.
_k_decl
(
node
,
nodename
)
init
=
self
.
_k_init
(
node
,
nodename
)
print
>>
sio
,
"""
%(decl)
s
{
%(init)
s
mysum = 0;
for (int i0 = threadIdx.z; i0 < d0; i0 += blockDim.z)
{
for (int i1 = threadIdx.y; i1 < d1; i1 += blockDim.y)
{
for (int i2 = threadIdx.x; i2 < d2; i2 += blockDim.x)
{
mysum += A[i0 * sA0 + i1 * sA1 + i2 * sA2];
}
}
}
%(reducebuf)
s
}
"""
%
locals
()
if
self
.
reduce_mask
==
(
0
,
0
,
1
):
# this kernel uses one block for each row,
# threads per block for each element per row.
print
>>
sio
,
"""
static __global__ void kernel_reduce_sum_001_
%(nodename)
s(
const int d0,
const int d1,
const int d2,
const float *A, const int sA0, const int sA1, const int sA2,
float * Z, const int sZ0, const int sZ1)
{
const int threadCount = blockDim.x;
const int threadNum = threadIdx.x;
extern __shared__ float buf[];
if (warpSize != 32)
{
return; //TODO: set error code
}
for (int i0 = blockIdx.x; i0 < d0; i0 += gridDim.x)
{
for (int i1 = blockIdx.y; i1 < d1; i1 += gridDim.y)
{
float mysum = 0.0f;
for (int i2 = threadIdx.x; i2 < d2; i2 += blockDim.x)
{
mysum += A[i0 * sA0 + i1 * sA1 + i2 * sA2];
}
buf[threadNum] = mysum;
__syncthreads();
// rest of function is handled by one warp
if (threadNum < warpSize)
{
for (int i = threadNum + warpSize; i < threadCount; i += warpSize)
{
mysum += buf[i];
}
buf[threadNum] = mysum;
if (threadNum < 16)
{
//reduce so that threadNum 0 has the sum of everything
if(threadNum + 16 < threadCount) buf[threadNum] += buf[threadNum+16];
if(threadNum + 8 < threadCount) buf[threadNum] += buf[threadNum+8];
if(threadNum + 4 < threadCount) buf[threadNum] += buf[threadNum+4];
if(threadNum + 2 < threadCount) buf[threadNum] += buf[threadNum+2];
if(threadNum + 1 < threadCount) buf[threadNum] += buf[threadNum+1];
if (threadNum == 0)
{
Z[i0 * sZ0 + i1 * sZ1] = buf[0];
}
}
}
}
}
}
"""
%
locals
()
if
self
.
reduce_mask
==
(
1
,
0
,
1
,
1
):
if
self
.
reduce_mask
==
(
1
,
0
,
1
,
1
):
print
>>
sio
,
"""
print
>>
sio
,
"""
static __global__ void kernel_reduce_sum_1011_
%(nodename)
s(
static __global__ void kernel_reduce_sum_1011_
%(nodename)
s(
...
@@ -820,10 +1253,10 @@ class GpuReshape(tensor.Reshape):
...
@@ -820,10 +1253,10 @@ class GpuReshape(tensor.Reshape):
class
GpuSubtensor
(
tensor
.
Subtensor
):
class
GpuSubtensor
(
tensor
.
Subtensor
):
# __hash__, __eq__, __str__ come from tensor.Subtensor
# __hash__, __eq__, __str__ come from tensor.Subtensor
def
make_node
(
self
,
x
,
*
inputs
):
def
make_node
(
self
,
x
,
*
inputs
):
assert
isinstance
(
x
.
type
,
CudaNdarrayType
)
rval
=
tensor
.
Subtensor
.
make_node
(
self
,
x
,
*
inputs
)
rval
=
tensor
.
Subtensor
.
make_node
(
self
,
x
,
*
inputs
)
rval
.
inputs
[
0
]
=
x
# clobber the 'astensor'
otype
=
CudaNdarrayType
(
rval
.
outputs
[
0
]
.
type
.
broadcastable
)
rval
.
outputs
[
0
]
.
type
=
CudaNdarrayType
(
rval
.
outputs
[
0
]
.
type
.
broadcastable
)
return
Apply
(
self
,
[
x
]
+
rval
.
inputs
[
1
:],
[
otype
()])
return
rval
def
perform
(
self
,
node
,
inputs
,
(
out
,
)):
def
perform
(
self
,
node
,
inputs
,
(
out
,
)):
x
=
inputs
[
0
]
x
=
inputs
[
0
]
...
@@ -844,75 +1277,12 @@ class GpuSubtensor(tensor.Subtensor):
...
@@ -844,75 +1277,12 @@ class GpuSubtensor(tensor.Subtensor):
cdata
=
cdata
[
0
]
cdata
=
cdata
[
0
]
out
[
0
]
=
x
.
__getitem__
(
cdata
)
out
[
0
]
=
x
.
__getitem__
(
cdata
)
def
old_perform
(
self
,
node
,
inputs
,
(
out
,
)):
class
GpuIncSubtensor
(
tensor
.
IncSubtensor
):
indices
=
list
(
reversed
(
inputs
[
1
:]))
def
make_node
(
self
,
x
,
y
,
*
inputs
):
assert
isinstance
(
x
.
type
,
CudaNdarrayType
)
def
convert
(
entry
):
assert
isinstance
(
y
.
type
,
CudaNdarrayType
)
if
isinstance
(
entry
,
Type
):
rval
=
tensor
.
IncSubtensor
.
make_node
(
self
,
x
,
y
,
*
inputs
)
return
indices
.
pop
()
return
Apply
(
self
,
[
x
,
y
]
+
rval
.
inputs
[
2
:],
[
x
.
type
()])
elif
isinstance
(
entry
,
slice
):
return
slice
(
convert
(
entry
.
start
),
convert
(
entry
.
stop
),
convert
(
entry
.
step
))
else
:
return
entry
x
=
inputs
[
0
]
.
view
()
out
[
0
]
=
x
#todo; when this works, put it into CudaNdarray.__getitem__
# (sequence protocol)
x_shape
=
x
.
shape
x_strides
=
x
.
_strides
offset
=
0
for
i
,
thing
in
enumerate
(
map
(
convert
,
self
.
idx_list
)):
if
isinstance
(
thing
,
int
):
#this requires reducing the rank of the
# view....
raise
NotImplementedError
()
if
isinstance
(
thing
,
slice
):
#stride
if
thing
.
step
is
None
:
stride
=
1
else
:
stride
=
thing
.
step
#start
if
thing
.
start
is
None
:
if
stride
>
0
:
start
=
0
else
:
start
=
x_shape
[
i
]
-
1
else
:
if
thing
.
start
<
0
:
start
=
x_shape
[
i
]
-
thing
.
start
else
:
start
=
thing
.
start
#stop
if
thing
.
stop
is
None
:
if
stride
>
0
:
stop
=
x_shape
[
i
]
else
:
stop
=
-
1
else
:
if
thing
.
stop
<
0
:
stop
=
x_shape
[
i
]
-
thing
.
stop
else
:
stop
=
thing
.
stop
newlen
=
(
stop
-
start
)
//
stride
offset
+=
x_strides
[
i
]
*
start
debug
(
'GpuSubtensor slice'
,
i
,
': '
,
start
,
stop
,
stride
)
debug
(
'GpuSubtensor shape'
,
i
,
': '
,
x_shape
[
i
],
newlen
)
x
.
_set_shape_i
(
i
,
newlen
)
x
.
_set_stride
(
i
,
x_strides
[
i
]
*
stride
)
#print 'perform', id(x), x.shape, i, thing
sizeof_float
=
4
x
.
_dev_data
+=
offset
*
sizeof_float
#sys.stdout.flush()
#sys.exit()
class
GpuShape
(
tensor
.
Shape
):
class
GpuShape
(
tensor
.
Shape
):
def
make_node
(
self
,
x
):
def
make_node
(
self
,
x
):
...
...
elemwise.py
浏览文件 @
462fb359
...
@@ -641,7 +641,7 @@ class NaiveAlgo(object):
...
@@ -641,7 +641,7 @@ class NaiveAlgo(object):
output_args
=
", "
.
join
(
"o
%
i_data, o
%
i_str"
%
(
ipos
,
ipos
)
output_args
=
", "
.
join
(
"o
%
i_data, o
%
i_str"
%
(
ipos
,
ipos
)
for
ipos
in
xrange
(
len
(
node
.
outputs
)))
for
ipos
in
xrange
(
len
(
node
.
outputs
)))
prod_dims
=
'*'
.
join
(
"dims[
%
i]"
%
di
for
di
in
xrange
(
nd
)
)
prod_dims
=
'*'
.
join
(
[
"dims[
%
i]"
%
di
for
di
in
xrange
(
nd
)]
+
[
'1'
]
)
scalar_op
=
self
.
scalar_op
.
__class__
.
__name__
scalar_op
=
self
.
scalar_op
.
__class__
.
__name__
...
...
opt.py
浏览文件 @
462fb359
...
@@ -183,6 +183,41 @@ def local_gpu_subtensor(node):
...
@@ -183,6 +183,41 @@ def local_gpu_subtensor(node):
return
[
host_from_gpu
(
GpuSubtensor
(
node
.
op
.
idx_list
)(
gpu_x
,
*
coords
))]
return
[
host_from_gpu
(
GpuSubtensor
(
node
.
op
.
idx_list
)(
gpu_x
,
*
coords
))]
return
False
return
False
@register_opt
()
@local_optimizer
([])
def
local_gpu_incsubtensor
(
node
):
if
node
.
op
==
gpu_from_host
:
host_output
=
node
.
inputs
[
0
]
if
host_output
.
owner
and
type
(
host_output
.
owner
.
op
)
==
tensor
.
IncSubtensor
:
incsubt
=
host_output
.
owner
.
op
x
,
y
=
host_output
.
owner
.
inputs
[
0
:
2
]
coords
=
host_output
.
owner
.
inputs
[
2
:]
return
[
GpuIncSubtensor
(
incsubt
.
idx_list
,
inplace
=
incsubt
.
inplace
)(
gpu_from_host
(
x
),
gpu_from_host
(
y
),
*
coords
)]
if
type
(
node
.
op
)
==
tensor
.
IncSubtensor
:
x
,
y
=
node
.
inputs
[
0
:
2
]
assert
isinstance
(
x
.
type
,
tensor
.
TensorType
)
assert
isinstance
(
y
.
type
,
tensor
.
TensorType
)
coords
=
node
.
inputs
[
2
:]
go_gpu
=
False
if
x
.
owner
and
x
.
owner
.
op
==
host_from_gpu
:
go_gpu
=
True
gpu_x
,
=
x
.
owner
.
inputs
else
:
gpu_x
=
gpu_from_host
(
x
)
if
y
.
owner
and
y
.
owner
.
op
==
host_from_gpu
:
go_gpu
=
True
gpu_y
,
=
y
.
owner
.
inputs
else
:
gpu_y
=
gpu_from_host
(
y
)
if
go_gpu
:
return
[
host_from_gpu
(
GpuIncSubtensor
(
node
.
op
.
idx_list
,
inplace
=
node
.
op
.
inplace
)(
gpu_x
,
gpu_y
,
*
coords
))]
return
False
@register_opt
()
@register_opt
()
@local_optimizer
([])
@local_optimizer
([])
def
local_gpu_shape
(
node
):
def
local_gpu_shape
(
node
):
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论