Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
cee77d13
提交
cee77d13
authored
8月 28, 2009
作者:
James Bergstra
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
added maxpooling op GpuDownsampleMax
上级
3272fee0
显示空白字符变更
内嵌
并排
正在显示
3 个修改的文件
包含
303 行增加
和
39 行删除
+303
-39
blas.py
blas.py
+235
-0
opt.py
opt.py
+59
-36
test_nnet.py
tests/test_nnet.py
+9
-3
没有找到文件。
blas.py
浏览文件 @
cee77d13
...
...
@@ -187,3 +187,238 @@ class GpuConv(Op):
kern_align
=
self
.
logical_kern_align_top
,
verbose
=
0
)
from
theano.sandbox.downsample
import
DownsampleFactorMax
class
GpuDownsampleFactorMax
(
DownsampleFactorMax
):
# inherit __eq__, __hash__, __str__
def
make_node
(
self
,
x
):
return
Apply
(
self
,
[
x
],
[
x
.
type
()])
def
perform
(
self
,
node
,
input_storage
,
output_storage
):
raise
NotImplementedError
(
'only C is implemented'
)
def
c_code_cache_version
(
self
):
return
()
def
c_code
(
self
,
node
,
nodename
,
(
x
,),
(
z
,),
sub
):
fail
=
sub
[
'fail'
]
ds0
,
ds1
=
self
.
ds
ignore_border
=
int
(
self
.
ignore_border
)
return
"""
int dims[4], xdim2, xdim3;
if (cnda_
%(x)
s->nd != 4)
{
PyErr_SetString(PyExc_ValueError, "rank error");
%(fail)
s;
}
xdim2 = CudaNdarray_HOST_DIMS(cnda_
%(x)
s)[2];
xdim3 = CudaNdarray_HOST_DIMS(cnda_
%(x)
s)[3];
dims[0] = CudaNdarray_HOST_DIMS(cnda_
%(x)
s)[0];
dims[1] = CudaNdarray_HOST_DIMS(cnda_
%(x)
s)[1];
dims[2] = xdim2 /
%(ds0)
s;
dims[3] = xdim3 /
%(ds1)
s;
if (!
%(ignore_border)
s)
{
dims[2] += (xdim2
%%
(
%(ds0)
s)?1:0);
dims[3] += (xdim3
%%
(
%(ds1)
s)?1:0);
}
if ((NULL == cnda_
%(z)
s)
|| (CudaNdarray_HOST_DIMS(cnda_
%(z)
s)[0] != dims[0])
|| (CudaNdarray_HOST_DIMS(cnda_
%(z)
s)[1] != dims[1])
|| (CudaNdarray_HOST_DIMS(cnda_
%(z)
s)[2] != dims[2])
|| (CudaNdarray_HOST_DIMS(cnda_
%(z)
s)[3] != dims[3]))
{
Py_XDECREF(cnda_
%(z)
s);
cnda_
%(z)
s = (CudaNdarray*)CudaNdarray_new_null();
if ((NULL == cnda_
%(z)
s)
|| CudaNdarray_alloc_contiguous(cnda_
%(z)
s, 4, dims))
{
Py_XDECREF(cnda_
%(z)
s);
cnda_
%(z)
s = NULL;
%(fail)
s;
}
}
{
dim3 grid(dims[0] * dims[1], dims[2]);
//dim3 block(std::min(dims[3], 512)); //TODO: implement this by supporting more
//outputs than threads
dim3 block(dims[3]);
kMaxPool_
%(nodename)
s<
%(ds0)
s,
%(ds1)
s> <<<grid, block, xdim3>>>(
dims[0], dims[1], dims[2], dims[3], xdim2, xdim3,
CudaNdarray_DEV_DATA(cnda_
%(x)
s),
CudaNdarray_HOST_STRIDES(cnda_
%(x)
s)[0],
CudaNdarray_HOST_STRIDES(cnda_
%(x)
s)[1],
CudaNdarray_HOST_STRIDES(cnda_
%(x)
s)[2],
CudaNdarray_HOST_STRIDES(cnda_
%(x)
s)[3],
CudaNdarray_DEV_DATA(cnda_
%(z)
s));
CNDA_THREAD_SYNC;
cudaError_t err = cudaGetLastError();
if( cudaSuccess != err)
{
PyErr_Format(PyExc_RuntimeError, "Cuda error:
%%
s:
%%
s.
\\
n", "kMaxPool_
%(nodename)
s", cudaGetErrorString(err));
%(fail)
s;
}
}
"""
%
locals
()
def
c_support_code_apply
(
self
,
node
,
nodename
):
ignore_border
=
int
(
self
.
ignore_border
)
return
"""
template<int pf2, int pf3>
__global__ void kMaxPool_
%(nodename)
s(
int D0, int D1, int D2, int D3, int xD2, int xD3,
const float * x, int xS0, int xS1, int xS2, int xS3,
float *z)
{
float cur_max, cur_x;
int i0 = blockIdx.x / D0;
int i1 = blockIdx.x
%%
D0;
int i2 = blockIdx.y;
extern __shared__ float xbuf[]; //size [xD3]
for (int r2 = 0; (r2 < pf2) && (
%(ignore_border)
s || (r2 + i2*pf2 < xD2)); ++r2)
{
__syncthreads();
// load the current row of the image into shared memory
for (int i3 = threadIdx.x; i3 < xD3; i3 += blockDim.x)
{
xbuf[i3] = x[i0*xS0 + i1*xS1 + (i2*pf2+r2)*xS2 + i3*xS3];
}
__syncthreads();
// initialize our max if this is the first row we're loading
cur_max = (r2 == 0) ? xbuf[threadIdx.x*pf3] : cur_max;
// do a mini-reduction over the pf3 relevant elements in the current row
for (int k = 0; k < pf3; ++k)
{
cur_x = xbuf[threadIdx.x*pf3+k];
cur_max = (cur_x < cur_max) ? cur_x : cur_max;
}
}
//store the result to global memory
z[i0 * D1*D2*D3 + i1*D2*D3 + i2*D3 + threadIdx.x] = cur_max;
}
"""
%
locals
()
from
theano.sandbox.downsample
import
DownsampleFactorMaxGrad
class
GpuDownsampleFactorMaxGrad
(
DownsampleFactorMaxGrad
):
# inherit __eq__, __hash__, __str__
def
make_node
(
self
,
x
,
z
,
gz
):
return
Apply
(
self
,
[
x
,
z
,
gz
],
[
x
.
type
()])
def
perform
(
self
,
node
,
input_storage
,
output_storage
):
raise
NotImplementedError
(
'only C is implemented'
)
def
c_code_cache_version
(
self
):
return
()
def
c_code
(
self
,
node
,
nodename
,
(
x
,
z
,
gz
),
(
gx
,),
sub
):
fail
=
sub
[
'fail'
]
ds0
,
ds1
=
self
.
ds
ignore_border
=
int
(
self
.
ignore_border
)
return
"""
if (cnda_
%(x)
s->nd != 4
|| cnda_
%(z)
s->nd != 4
|| cnda_
%(gz)
s->nd != 4)
{
PyErr_SetString(PyExc_ValueError, "rank error");
%(fail)
s;
}
if ((NULL == cnda_
%(gx)
s)
|| (CudaNdarray_HOST_DIMS(cnda_
%(gx)
s)[0] != CudaNdarray_HOST_DIMS(cnda_
%(x)
s)[0])
|| (CudaNdarray_HOST_DIMS(cnda_
%(gx)
s)[1] != CudaNdarray_HOST_DIMS(cnda_
%(x)
s)[1])
|| (CudaNdarray_HOST_DIMS(cnda_
%(gx)
s)[2] != CudaNdarray_HOST_DIMS(cnda_
%(x)
s)[2])
|| (CudaNdarray_HOST_DIMS(cnda_
%(gx)
s)[3] != CudaNdarray_HOST_DIMS(cnda_
%(x)
s)[3]))
{
Py_XDECREF(cnda_
%(gx)
s);
cnda_
%(gx)
s = (CudaNdarray*)CudaNdarray_new_null();
if ((NULL == cnda_
%(gx)
s)
|| CudaNdarray_alloc_contiguous(cnda_
%(gx)
s, 4, CudaNdarray_HOST_DIMS(cnda_
%(x)
s)))
{
Py_XDECREF(cnda_
%(gx)
s);
cnda_
%(gx)
s = NULL;
%(fail)
s;
}
}
{
dim3 grid(CudaNdarray_HOST_DIMS(cnda_
%(x)
s)[0], CudaNdarray_HOST_DIMS(cnda_
%(x)
s)[2]);
//TODO: implement this by supporting more
//outputs than threads
dim3 block(CudaNdarray_HOST_DIMS(cnda_
%(x)
s)[3]);
kDownsampleMaxGrad_
%(nodename)
s<
%(ds0)
s,
%(ds1)
s> <<<grid, block>>>(
CudaNdarray_HOST_DIMS(cnda_
%(z)
s)[0],
CudaNdarray_HOST_DIMS(cnda_
%(z)
s)[1],
CudaNdarray_HOST_DIMS(cnda_
%(z)
s)[2],
CudaNdarray_HOST_DIMS(cnda_
%(z)
s)[3],
CudaNdarray_HOST_DIMS(cnda_
%(x)
s)[2],
CudaNdarray_HOST_DIMS(cnda_
%(x)
s)[3],
CudaNdarray_DEV_DATA(cnda_
%(x)
s),
CudaNdarray_HOST_STRIDES(cnda_
%(x)
s)[0],
CudaNdarray_HOST_STRIDES(cnda_
%(x)
s)[1],
CudaNdarray_HOST_STRIDES(cnda_
%(x)
s)[2],
CudaNdarray_HOST_STRIDES(cnda_
%(x)
s)[3],
CudaNdarray_DEV_DATA(cnda_
%(z)
s),
CudaNdarray_HOST_STRIDES(cnda_
%(z)
s)[0],
CudaNdarray_HOST_STRIDES(cnda_
%(z)
s)[1],
CudaNdarray_HOST_STRIDES(cnda_
%(z)
s)[2],
CudaNdarray_HOST_STRIDES(cnda_
%(z)
s)[3],
CudaNdarray_DEV_DATA(cnda_
%(gz)
s),
CudaNdarray_HOST_STRIDES(cnda_
%(gz)
s)[0],
CudaNdarray_HOST_STRIDES(cnda_
%(gz)
s)[1],
CudaNdarray_HOST_STRIDES(cnda_
%(gz)
s)[2],
CudaNdarray_HOST_STRIDES(cnda_
%(gz)
s)[3],
CudaNdarray_DEV_DATA(cnda_
%(gx)
s));
CNDA_THREAD_SYNC;
cudaError_t err = cudaGetLastError();
if( cudaSuccess != err)
{
PyErr_Format(PyExc_RuntimeError, "Cuda error:
%%
s:
%%
s. (grid:
%%
i x
%%
i; block:
%%
i x
%%
i x
%%
i)
\\
n",
"kDownsampleMaxGrad_
%(nodename)
s",
cudaGetErrorString(err),
grid.x,
grid.y,
block.x,
block.y,
block.z);
%(fail)
s;
}
}
"""
%
locals
()
def
c_support_code_apply
(
self
,
node
,
nodename
):
ignore_border
=
int
(
self
.
ignore_border
)
return
"""
template<int ds0, int ds1>
__global__ void kDownsampleMaxGrad_
%(nodename)
s(
int D0, int D1, int D2, int D3, int xD2, int xD3,
const float * x, int xS0, int xS1, int xS2, int xS3,
const float * z, int zS0, int zS1, int zS2, int zS3,
const float * gz, int gzS0, int gzS1, int gzS2, int gzS3,
float *gx)
{
float cur_max, cur_x, my_z, my_gz;
int i0 = blockIdx.x;
int i1 = 0;
int i2 = blockIdx.y; // row wrt z and/or gz
int x_col = threadIdx.x;
// The algorithm here is that every thread writes one output pixel per line
for (i1 = 0; i1 < D1; ++i1)
{
if (
%(ignore_border)
s && (x_col >= ds1 * D3))
{
my_gz = 0;
}
else
{
my_gz = gz[i0 * gzS0 + i1 * gzS1 + i2 * gzS2 + (x_col/ds1)*gzS3];
my_z = z[i0 * zS0 + i1 * zS1 + i2 * zS2 + (x_col/ds1)* zS3];
}
for (int x_row = i2*ds0; (x_row < i2*ds0+ds0) && (
%(ignore_border)
s || (x_row < xD2)); ++x_row)
{
gx[i0 * D1*xD2*xD3 + i1*xD2*xD3 + x_row*xD3 + x_col]
= (my_z == x[i0*xS0 + i1*xS1 + x_row*xS2 + x_col]) ? my_gz : 0;
}
}
}
"""
%
locals
()
opt.py
浏览文件 @
cee77d13
...
...
@@ -4,6 +4,7 @@ from theano.gof import local_optimizer, EquilibriumDB, SequenceDB
from
theano_cuda_ndarray.basic_ops
import
*
from
theano_cuda_ndarray.blas
import
gpu_dot22
,
gpu_gemm
,
GpuConv
from
theano_cuda_ndarray.blas
import
GpuDownsampleFactorMax
,
GpuDownsampleFactorMaxGrad
from
theano_cuda_ndarray.nnet
import
(
GpuCrossentropySoftmaxArgmax1HotWithBias
,
GpuCrossentropySoftmax1HotWithBiasDx
)
...
...
@@ -148,42 +149,6 @@ def local_gpu_sum(node):
return
[
host_from_gpu
(
GpuSum
(
reduce_mask
)(
gpu_from_host
(
x
)))]
return
False
import
theano.sandbox.conv
@register_opt
()
@local_optimizer
([])
def
local_gpu_conv
(
node
):
"""
gpu_from_host(conv) -> gpu_conv(gpu_from_host)
conv(host_from_gpu) -> host_from_gpu(conv)
"""
def
GpuConvOp_from_ConvOp
(
op
):
ret
=
GpuConv
(
border_mode
=
op
.
out_mode
,
subsample
=
(
op
.
dx
,
op
.
dy
),
logical_img_hw
=
op
.
imshp_logical
[
1
:
3
],
logical_kern_hw
=
op
.
kshp_logical
,
logical_kern_align_top
=
op
.
kshp_logical_top_aligned
)
#HACK to print the number of MFlops in the profiler output.
if
hasattr
(
op
,
'flops'
):
ret
.
flops
=
op
.
flops
return
ret
if
node
.
op
==
gpu_from_host
:
host_input
=
node
.
inputs
[
0
]
if
host_input
.
owner
and
isinstance
(
host_input
.
owner
.
op
,
theano
.
sandbox
.
conv
.
ConvOp
):
gpu_conv
=
GpuConvOp_from_ConvOp
(
host_input
.
owner
.
op
)
img
,
kern
=
host_input
.
owner
.
inputs
return
[
gpu_conv
(
gpu_from_host
(
img
),
gpu_from_host
(
kern
))]
if
isinstance
(
node
.
op
,
theano
.
sandbox
.
conv
.
ConvOp
):
img
,
kern
=
node
.
inputs
img_on_gpu
=
(
img
.
owner
and
img
.
owner
.
op
==
host_from_gpu
)
kern_on_gpu
=
(
kern
.
owner
and
kern
.
owner
.
op
==
host_from_gpu
)
if
img_on_gpu
or
kern_on_gpu
:
gpu_conv
=
GpuConvOp_from_ConvOp
(
node
.
op
)
return
[
host_from_gpu
(
gpu_conv
(
gpu_from_host
(
img
),
gpu_from_host
(
kern
)))]
@register_opt
()
@local_optimizer
([])
def
local_gpu_reshape
(
node
):
...
...
@@ -265,3 +230,61 @@ def local_gpu_crossentorpy_softmax_1hot_with_bias_dx(node):
gpu_from_host
(
cast
(
yidx
,
'float32'
)))
return
[
host_from_gpu
(
gpu_dx
)]
return
False
#### Convolution, maxpooling
import
theano.sandbox.conv
@register_opt
()
@local_optimizer
([])
def
local_gpu_conv
(
node
):
"""
gpu_from_host(conv) -> gpu_conv(gpu_from_host)
conv(host_from_gpu) -> host_from_gpu(conv)
"""
def
GpuConvOp_from_ConvOp
(
op
):
ret
=
GpuConv
(
border_mode
=
op
.
out_mode
,
subsample
=
(
op
.
dx
,
op
.
dy
),
logical_img_hw
=
op
.
imshp_logical
[
1
:
3
],
logical_kern_hw
=
op
.
kshp_logical
,
logical_kern_align_top
=
op
.
kshp_logical_top_aligned
)
#HACK to print the number of MFlops in the profiler output.
if
hasattr
(
op
,
'flops'
):
ret
.
flops
=
op
.
flops
return
ret
if
node
.
op
==
gpu_from_host
:
host_input
=
node
.
inputs
[
0
]
if
host_input
.
owner
and
isinstance
(
host_input
.
owner
.
op
,
theano
.
sandbox
.
conv
.
ConvOp
):
gpu_conv
=
GpuConvOp_from_ConvOp
(
host_input
.
owner
.
op
)
img
,
kern
=
host_input
.
owner
.
inputs
return
[
gpu_conv
(
gpu_from_host
(
img
),
gpu_from_host
(
kern
))]
if
isinstance
(
node
.
op
,
theano
.
sandbox
.
conv
.
ConvOp
):
img
,
kern
=
node
.
inputs
img_on_gpu
=
(
img
.
owner
and
img
.
owner
.
op
==
host_from_gpu
)
kern_on_gpu
=
(
kern
.
owner
and
kern
.
owner
.
op
==
host_from_gpu
)
if
img_on_gpu
or
kern_on_gpu
:
gpu_conv
=
GpuConvOp_from_ConvOp
(
node
.
op
)
return
[
host_from_gpu
(
gpu_conv
(
gpu_from_host
(
img
),
gpu_from_host
(
kern
)))]
import
theano.sandbox.downsample
@register_opt
()
@local_optimizer
([])
def
local_gpu_downsample_factor_max
(
node
):
if
isinstance
(
node
.
op
,
theano
.
sandbox
.
downsample
.
DownsampleFactorMax
):
x
,
=
node
.
inputs
if
(
x
.
owner
and
x
.
owner
.
op
==
host_from_gpu
):
gpu_ds
=
GpuDownsampleFactorMax
(
node
.
op
.
ds
,
node
.
op
.
ignore_border
)
return
[
host_from_gpu
(
gpu_ds
(
x
.
owner
.
inputs
[
0
]))]
@register_opt
()
@local_optimizer
([])
def
local_gpu_downsample_factor_max_grad
(
node
):
if
isinstance
(
node
.
op
,
theano
.
sandbox
.
downsample
.
DownsampleFactorMaxGrad
):
x
,
z
,
gz
=
node
.
inputs
if
(
x
.
owner
and
x
.
owner
.
op
==
host_from_gpu
):
gpu_ds_grad
=
GpuDownsampleFactorMaxGrad
(
node
.
op
.
ds
,
node
.
op
.
ignore_border
)
return
[
host_from_gpu
(
gpu_ds_grad
(
x
.
owner
.
inputs
[
0
],
gpu_from_host
(
z
),
gpu_from_host
(
gz
)))]
tests/test_nnet.py
浏览文件 @
cee77d13
import
sys
,
time
import
theano
,
theano
.
sandbox
.
conv
import
theano
from
theano.compile.sandbox.sharedvalue
import
shared
from
theano.compile.sandbox.pfunc
import
pfunc
from
theano
import
tensor
import
theano.tensor.nnet
import
theano.sandbox.conv
import
theano.sandbox.downsample
import
numpy
import
theano_cuda_ndarray
as
tcn
...
...
@@ -251,8 +255,10 @@ def run_conv_nnet2_classif(shared_fn, isize, ksize, n_batch=60, n_iter=25):
conv_op
=
theano
.
sandbox
.
conv
.
ConvOp
(
shape_img
[
2
:],
shape_kern
[
2
:],
n_kern
,
n_batch
,
1
,
1
)
conv_op1
=
theano
.
sandbox
.
conv
.
ConvOp
((
n_kern
,
logical_hid_shape
[
0
]
/
2
,
logical_hid_shape
[
1
]
/
2
),
shape_kern1
[
2
:],
n_kern1
,
n_batch
,
1
,
1
)
hid
=
tensor
.
tanh
(
conv_op
(
x
,
w0
)
+
b0
)
hid1
=
tensor
.
tanh
(
conv_op1
(
hid
[:,:,::
2
,::
2
],
w1
)
+
b1
)
ds_op
=
theano
.
sandbox
.
downsample
.
DownsampleFactorMax
((
2
,
2
),
ignore_border
=
False
)
hid
=
tensor
.
tanh
(
ds_op
(
conv_op
(
x
,
w0
)
+
b0
))
hid1
=
tensor
.
tanh
(
conv_op1
(
hid
,
w1
)
+
b1
)
hid_flat
=
hid1
.
reshape
((
n_batch
,
n_hid
))
out
=
tensor
.
nnet
.
softmax
(
tensor
.
dot
(
hid_flat
,
v
)
+
c
)
loss
=
tensor
.
sum
(
tensor
.
nnet
.
crossentropy_categorical_1hot
(
out
,
tensor
.
argmax
(
y
,
axis
=
1
))
*
lr
)
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论