Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
85447abe
提交
85447abe
authored
6月 28, 2010
作者:
James Bergstra
浏览文件
操作
浏览文件
下载
差异文件
merge
上级
13b8fb68
7ae6897c
全部展开
隐藏空白字符变更
内嵌
并排
正在显示
8 个修改的文件
包含
70 行增加
和
39 行删除
+70
-39
index.txt
doc/index.txt
+6
-0
nnet.py
theano/sandbox/cuda/nnet.py
+24
-17
test_mlp.py
theano/sandbox/cuda/tests/test_mlp.py
+0
-0
test_nnet.py
theano/sandbox/cuda/tests/test_nnet.py
+0
-0
type.py
theano/sandbox/cuda/type.py
+3
-1
multinomial.py
theano/sandbox/multinomial.py
+29
-13
rng_mrg.py
theano/sandbox/rng_mrg.py
+3
-4
test_rng_mrg.py
theano/sandbox/test_rng_mrg.py
+5
-4
没有找到文件。
doc/index.txt
浏览文件 @
85447abe
...
...
@@ -52,6 +52,10 @@ Community
* Register and post to `theano-dev`_ if you want to talk to the developers.
* Register and post to `theano-announce`_ if you want to be keep informed on important change on theano(low volume).
* Register and post to `theano-buildbot`_ if you want to receive our daily buildbot email.
* We try to stay organized with `Theano's Trac <http://trac-hg.assembla.com/theano/report/1>`__
* Come visit us in Montreal! Most of the developers are students in the LISA_ group at the `University of Montreal`_.
...
...
@@ -77,6 +81,8 @@ Community
.. _theano-dev: http://groups.google.com/group/theano-dev
.. _theano-users: http://groups.google.com/group/theano-users
.. _theano-announce: http://groups.google.com/group/theano-announce
.. _theano-buildbot: http://groups.google.com/group/theano-buildbot
.. _tickets: http://pylearn.org/theano/trac/query?status=accepted&status=assigned&status=new&status=reopened&group=milestone&max=200&col=id&col=summary&col=status&col=owner&col=type&col=priority&col=component&col=time&report=9&order=priority
.. _LISA: http://www.iro.umontreal.ca/~lisa
...
...
theano/sandbox/cuda/nnet.py
浏览文件 @
85447abe
...
...
@@ -188,7 +188,7 @@ class GpuCrossentropySoftmax1HotWithBiasDx (Op):
def
make_node
(
self
,
dy
,
sm
,
y_idx
):
return
Apply
(
self
,
[
dy
,
sm
,
y_idx
],[
sm
.
type
()])
def
c_code_cache_version
(
self
):
return
(
2
,)
return
(
3
,)
#return ()
def
c_code
(
self
,
node
,
nodename
,
(
dnll
,
sm
,
y_idx
),
(
dx
,),
sub
):
fail
=
sub
[
'fail'
]
...
...
@@ -229,7 +229,7 @@ class GpuCrossentropySoftmax1HotWithBiasDx (Op):
kCrossEntropySoftmax1HotWithBiasDx_
%(nodename)
s
<<<
CudaNdarray_HOST_DIMS(
%(dx)
s)[0],
CudaNdarray_HOST_DIMS(
%(dx)
s)[1]
std::min(CudaNdarray_HOST_DIMS(
%(dx)
s)[1],256)
>>>(
CudaNdarray_HOST_DIMS(
%(dx)
s)[0],
CudaNdarray_HOST_DIMS(
%(dx)
s)[1],
...
...
@@ -303,7 +303,7 @@ class GpuSoftmax (Op):
return
shape
def
c_code_cache_version
(
self
):
#return ()
return
(
1
,)
+
inline_softmax
.
code_version
return
(
2
,)
+
inline_softmax
.
code_version
def
c_code
(
self
,
node
,
nodename
,
(
x
,),
(
z
,),
sub
):
fail
=
sub
[
'fail'
]
return
"""
...
...
@@ -330,7 +330,7 @@ class GpuSoftmax (Op):
kSoftmax_
%(nodename)
s
<<<
// todo: cap these at the card limits, implement loops in kernel
CudaNdarray_HOST_DIMS(
%(x)
s)[0]
,
std::min(CudaNdarray_HOST_DIMS(
%(x)
s)[0],32*1024)
,
CudaNdarray_HOST_DIMS(
%(x)
s)[1],
CudaNdarray_HOST_DIMS(
%(x)
s)[1] * 2 * sizeof(float)
>>>(
...
...
@@ -362,11 +362,14 @@ class GpuSoftmax (Op):
body
=
[
"extern __shared__ float buf[]"
,
"float * buf2 = buf + N"
,
"buf[threadIdx.x] = x[blockIdx.x * sx0 + threadIdx.x * sx1]"
,
"buf2[threadIdx.x] = buf[threadIdx.x]"
,
"__syncthreads()"
,
inline_softmax
(
'N'
,
'buf'
,
'buf2'
,
'threadIdx.x'
,
'blockDim.x'
),
"sm[blockIdx.x * N + threadIdx.x] = buf[threadIdx.x]"
"for (int blockIDX = blockIdx.x; blockIDX < M; blockIDX += gridDim.x){"
,
"buf[threadIdx.x] = x[blockIDX * sx0 + threadIdx.x * sx1]"
,
"buf2[threadIdx.x] = buf[threadIdx.x]"
,
"__syncthreads()"
,
inline_softmax
(
'N'
,
'buf'
,
'buf2'
,
'threadIdx.x'
,
'blockDim.x'
),
"sm[blockIDX * N + threadIdx.x] = buf[threadIdx.x]"
,
"__syncthreads()"
,
"}"
,
])
...
...
@@ -386,7 +389,7 @@ class GpuSoftmaxWithBias (Op):
return
[
shape
[
0
]]
def
c_code_cache_version
(
self
):
#return ()
return
(
1
,)
+
inline_softmax
.
code_version
return
(
2
,)
+
inline_softmax
.
code_version
def
c_code
(
self
,
node
,
nodename
,
(
x
,
b
),
(
z
,),
sub
):
fail
=
sub
[
'fail'
]
...
...
@@ -425,7 +428,7 @@ class GpuSoftmaxWithBias (Op):
kSoftmaxWithBias_
%(nodename)
s
<<<
// todo: cap these at the card limits, implement loops in kernel
CudaNdarray_HOST_DIMS(
%(x)
s)[0]
,
std::min(CudaNdarray_HOST_DIMS(
%(x)
s)[0],32*1024)
,
CudaNdarray_HOST_DIMS(
%(x)
s)[1],
CudaNdarray_HOST_DIMS(
%(x)
s)[1] * 2 * sizeof(float)
>>>(
...
...
@@ -461,10 +464,14 @@ class GpuSoftmaxWithBias (Op):
body
=
[
"extern __shared__ float buf[]"
,
"float * buf2 = buf + N"
,
"buf[threadIdx.x] = x[blockIdx.x * sx0 + threadIdx.x * sx1]"
,
"buf[threadIdx.x] += b[threadIdx.x * sb0]"
,
"buf2[threadIdx.x] = buf[threadIdx.x]"
,
"__syncthreads()"
,
inline_softmax
(
'N'
,
'buf'
,
'buf2'
,
'threadIdx.x'
,
'blockDim.x'
),
"sm[blockIdx.x * N + threadIdx.x] = buf[threadIdx.x]"
"for (int blockIDX = blockIdx.x; blockIDX < M; blockIDX += gridDim.x){"
,
"buf[threadIdx.x] = x[blockIDX * sx0 + threadIdx.x * sx1]"
,
"buf[threadIdx.x] += b[threadIdx.x * sb0]"
,
"buf2[threadIdx.x] = buf[threadIdx.x]"
,
"__syncthreads()"
,
inline_softmax
(
'N'
,
'buf'
,
'buf2'
,
'threadIdx.x'
,
'blockDim.x'
),
"sm[blockIDX * N + threadIdx.x] = buf[threadIdx.x]"
,
"__syncthreads()"
,
"}"
,
])
#for (int i = blockIdx.x; i < N; i += gridDim.x)
theano/sandbox/cuda/tests/test_mlp.py
0 → 100644
浏览文件 @
85447abe
差异被折叠。
点击展开。
theano/sandbox/cuda/tests/test_nnet.py
浏览文件 @
85447abe
差异被折叠。
点击展开。
theano/sandbox/cuda/type.py
浏览文件 @
85447abe
...
...
@@ -254,7 +254,9 @@ class CudaNdarrayType(Type):
return
ret
def
c_libraries
(
self
):
return
[
'cudart'
]
# returning cublas because the cuda_ndarray.cuh header includes calls to SetVector and
# cublasGetError
return
[
'cudart'
,
'cublas'
]
def
c_support_code
(
cls
):
return
""
...
...
theano/sandbox/multinomial.py
浏览文件 @
85447abe
...
...
@@ -4,7 +4,7 @@ import theano.tensor as T
from
theano.tensor.opt
import
register_specialize
from
theano.gof
import
local_optimizer
from
theano.sandbox.cuda
import
cuda_available
from
theano.sandbox.cuda
import
cuda_available
,
cuda_enabled
if
cuda_available
:
from
theano.sandbox.cuda
import
CudaNdarrayType
from
theano.sandbox.cuda.basic_ops
import
host_from_gpu
,
gpu_from_host
...
...
@@ -109,12 +109,11 @@ class GpuMultinomial(Multinomial):
raise
TypeError
(
'pvals must be cudandarray'
,
pvals
)
if
not
isinstance
(
unis
.
type
,
CudaNdarrayType
):
raise
TypeError
(
'unis must be cudandarray'
,
unis
)
return
Apply
(
self
,
[
pvals
,
unis
],
[
pvals
.
type
()])
def
c_code_cache_version
(
self
):
#
return ()
return
(
super
(
GpuMultinomial
,
self
)
.
c_code_cache_version
(),
1
)
return
()
#
return (super(GpuMultinomial,self).c_code_cache_version(),1)
def
c_support_code_apply
(
self
,
node
,
nodename
):
return
"""
...
...
@@ -128,7 +127,7 @@ class GpuMultinomial(Multinomial):
float * global_outs
)
{
int n =
32
*blockIdx.x + threadIdx.x;
int n =
blockDim.x
*blockIdx.x + threadIdx.x;
if (n < nb_multi)
{
...
...
@@ -201,14 +200,31 @@ class GpuMultinomial(Multinomial):
int nb_outcomes = CudaNdarray_HOST_DIMS(
%(z)
s)[0];
int nb_multi = CudaNdarray_HOST_DIMS(
%(z)
s)[1];
int nb_block;
if (nb_multi
%% 32
== 0)
nb_block = nb_multi/32;
else
nb_block = (int)((float)nb_multi/32. + 1.);
//TODO : change this for a beautiful constant
int max_nb_blocks = 2<<15 - 1;
int nb_blocks = max_nb_blocks + 1;
int nb_threads=16; // so it really starts at 32, because of the *2
do
{
nb_threads*=2;
if (nb_multi
%%
nb_threads == 0)
nb_blocks = nb_multi/nb_threads;
else
nb_blocks = (int)((float)nb_multi/(float)nb_threads + 1.);
} while (nb_blocks > max_nb_blocks);
//printf("
\\
nN=
%%
i b=
%%
i t=
%%
i t*b=
%%
i", nb_multi, nb_blocks, nb_threads, nb_blocks*nb_threads);
// TODO : next line is a bit hardcoded...
if (nb_threads > 512)
{
PyErr_Format(PyExc_ValueError, "Mutinomial is not implemented for as many rows in the matrix (
%%
i)", nb_multi);
%(fail)
s;
}
dim3 n_blocks(nb_block,1,1);
dim3 n_threads(
32
,1,1);
dim3 n_blocks(nb_block
s
,1,1);
dim3 n_threads(
nb_threads
,1,1);
int n_shared = 0;
k_multi_warp_
%(name)
s<<<n_blocks, n_threads, n_shared>>>(
...
...
@@ -244,6 +260,6 @@ gpu_multinomial = GpuMultinomial()
def
use_gpu_multinomial
(
node
):
if
node
.
op
==
multinomial
:
return
[
host_from_gpu
(
gpu_multinomial
(
*
[
gpu_from_host
(
i
)
for
i
in
node
.
inputs
]))]
if
theano
.
config
.
device
.
startswith
(
'gpu'
):
if
cuda_enabled
:
#
theano.config.device.startswith('gpu'):
register_specialize
(
use_gpu_multinomial
)
theano/sandbox/rng_mrg.py
浏览文件 @
85447abe
...
...
@@ -685,7 +685,7 @@ class MRG_RandomStreams(object):
else
:
raise
NotImplementedError
(
"MRG_RandomStreams.binomial with n > 1"
)
def
multinomial
(
self
,
size
=
None
,
n
=
1
,
pvals
=
[[
.
5
,
.
5
]]
,
ndim
=
None
,
dtype
=
'int64'
):
def
multinomial
(
self
,
size
=
None
,
n
=
1
,
pvals
=
None
,
ndim
=
None
,
dtype
=
'int64'
):
"""
Sample `n` (currently `n` needs to be 1) times from a multinomial distribution defined by
probabilities pvals.
...
...
@@ -696,13 +696,12 @@ class MRG_RandomStreams(object):
`size` and `ndim` are only there keep the same signature as other uniform, binomial, normal, etc.
todo : adapt multinomial to take that into account
"""
if
pvals
is
None
:
raise
TypeError
(
"You have to specify pvals"
)
pvals
=
as_tensor_variable
(
pvals
)
if
n
==
1
and
pvals
.
ndim
==
2
:
pvals
=
as_tensor_variable
(
pvals
)
unis
=
self
.
uniform
(
size
=
pvals
.
shape
[
0
:
1
],
ndim
=
1
)
return
cast
(
multinomial
(
pvals
.
T
,
unis
)
.
T
,
dtype
)
else
:
raise
NotImplementedError
(
"MRG_RandomStreams.multinomial only implemented with n == 1 and pvals.ndim = 2"
)
...
...
theano/sandbox/test_rng_mrg.py
浏览文件 @
85447abe
...
...
@@ -345,7 +345,7 @@ def test_uniform():
#print 'random?[-1,-10:]\n', out[-1,-10:]
basictest
(
f
,
steps
,
sample_size
,
prefix
=
'mrg cpu'
,
inputs
=
input
)
if
mode
!=
'FAST_COMPILE'
:
if
mode
!=
'FAST_COMPILE'
and
cuda_available
:
print
''
print
'ON GPU with size=(
%
s):'
%
str
(
size
)
R
=
MRG_RandomStreams
(
234
,
use_cuda
=
True
)
...
...
@@ -403,7 +403,7 @@ def test_binomial():
print
'random?[-1,-10:]
\n
'
,
out
[
-
1
,
-
10
:]
basictest
(
f
,
steps
,
sample_size
,
prefix
=
'mrg cpu'
,
inputs
=
input
,
allow_01
=
True
,
target_avg
=
mean
)
if
mode
!=
'FAST_COMPILE'
:
if
mode
!=
'FAST_COMPILE'
and
cuda_available
:
print
''
print
'ON GPU with size=(
%
s) and mean(
%
d):'
%
(
str
(
size
),
mean
)
R
=
MRG_RandomStreams
(
234
,
use_cuda
=
True
)
...
...
@@ -450,7 +450,7 @@ def test_normal0():
# now with odd number of samples
sample_size
=
(
sample_size
[
0
],
sample_size
[
1
]
-
1
)
if
mode
!=
'FAST_COMPILE'
:
if
mode
!=
'FAST_COMPILE'
and
cuda_available
:
print
''
print
'ON GPU:'
R
=
MRG_RandomStreams
(
234
,
use_cuda
=
True
)
...
...
@@ -465,7 +465,7 @@ def test_normal0():
print
'random?[:10]
\n
'
,
numpy
.
asarray
(
f
())[
0
,
0
:
10
]
print
'----'
sys
.
stdout
.
flush
()
basictest
(
f
,
steps
,
sample_size
_odd
,
target_avg
=-
5.0
,
target_std
=
2.0
,
prefix
=
'gpu mrg '
,
allow_01
=
True
)
basictest
(
f
,
steps
,
sample_size
,
target_avg
=-
5.0
,
target_std
=
2.0
,
prefix
=
'gpu mrg '
,
allow_01
=
True
)
print
''
...
...
@@ -528,6 +528,7 @@ def test_multinomial():
print
''
print
'ON GPU:'
R
=
MRG_RandomStreams
(
234
,
use_cuda
=
True
)
pvals
=
numpy
.
asarray
(
pvals
,
dtype
=
'float32'
)
n
=
R
.
multinomial
(
pvals
=
pvals
,
dtype
=
'float32'
)
assert
n
.
dtype
==
'float32'
#well, it's really that this test w GPU doesn't make sense otw
f
=
theano
.
function
([],
theano
.
Out
(
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论