Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
fa78b21a
提交
fa78b21a
authored
2月 14, 2017
作者:
abergeron
提交者:
GitHub
2月 14, 2017
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #5451 from aalmah/multi_dtype_multinomial
multiple data type support for GPUMultinomialFromUniform output
上级
11b72eda
dcc43098
隐藏空白字符变更
内嵌
并排
正在显示
2 个修改的文件
包含
63 行增加
和
19 行删除
+63
-19
multinomial.py
theano/gpuarray/multinomial.py
+23
-16
test_multinomial.py
theano/gpuarray/tests/test_multinomial.py
+40
-3
没有找到文件。
theano/gpuarray/multinomial.py
浏览文件 @
fa78b21a
...
@@ -18,10 +18,12 @@ from .opt import register_opt, op_lifter, register_opt2
...
@@ -18,10 +18,12 @@ from .opt import register_opt, op_lifter, register_opt2
from
.type
import
GpuArrayType
from
.type
import
GpuArrayType
from
.elemwise
import
GpuDimShuffle
from
.elemwise
import
GpuDimShuffle
from
theano.scalar
import
as_scalar
from
theano.scalar
import
as_scalar
from
.fp16_help
import
write_w
,
load_w
,
work_dtype
class
GPUAMultinomialFromUniform
(
GpuKernelBase
,
Op
):
class
GPUAMultinomialFromUniform
(
GpuKernelBase
,
Op
):
__props__
=
(
"odtype"
,)
__props__
=
(
"odtype"
,)
_f16_ok
=
True
def
__init__
(
self
,
odtype
):
def
__init__
(
self
,
odtype
):
Op
.
__init__
(
self
)
Op
.
__init__
(
self
)
...
@@ -37,10 +39,9 @@ class GPUAMultinomialFromUniform(GpuKernelBase, Op):
...
@@ -37,10 +39,9 @@ class GPUAMultinomialFromUniform(GpuKernelBase, Op):
return
[
os
.
path
.
dirname
(
__file__
)]
return
[
os
.
path
.
dirname
(
__file__
)]
def
make_node
(
self
,
pvals
,
unis
):
def
make_node
(
self
,
pvals
,
unis
):
assert
pvals
.
dtype
==
'float32'
assert
unis
.
dtype
==
pvals
.
dtype
assert
unis
.
dtype
==
'float32'
assert
pvals
.
dtype
in
[
'float32'
,
'float16'
,
'float64'
]
ctx_name
=
infer_context_name
(
pvals
,
unis
)
ctx_name
=
infer_context_name
(
pvals
,
unis
)
pvals
=
as_gpuarray_variable
(
pvals
,
ctx_name
)
pvals
=
as_gpuarray_variable
(
pvals
,
ctx_name
)
unis
=
as_gpuarray_variable
(
unis
,
ctx_name
)
unis
=
as_gpuarray_variable
(
unis
,
ctx_name
)
...
@@ -60,14 +61,19 @@ class GPUAMultinomialFromUniform(GpuKernelBase, Op):
...
@@ -60,14 +61,19 @@ class GPUAMultinomialFromUniform(GpuKernelBase, Op):
return
Apply
(
self
,
[
pvals
,
unis
],
[
out
])
return
Apply
(
self
,
[
pvals
,
unis
],
[
out
])
def
gpu_kernels
(
self
,
node
,
name
):
def
gpu_kernels
(
self
,
node
,
name
):
out_ctype
=
pygpu
.
gpuarray
.
dtype_to_ctype
(
node
.
outputs
[
0
]
.
dtype
)
in_ctype
=
pygpu
.
gpuarray
.
dtype_to_ctype
(
node
.
inputs
[
0
]
.
dtype
)
work_ctype
=
pygpu
.
gpuarray
.
dtype_to_ctype
(
work_dtype
(
node
.
inputs
[
0
]
.
dtype
))
write_out_ctype
=
write_w
(
node
.
outputs
[
0
]
.
dtype
)
load_in_ctype
=
load_w
(
node
.
inputs
[
0
]
.
dtype
)
code
=
"""
code
=
"""
KERNEL void k_multi_warp_multinomial(
KERNEL void k_multi_warp_multinomial(
const ga_size nb_multi,
const ga_size nb_multi,
const ga_size nb_outcomes,
const ga_size nb_outcomes,
GLOBAL_MEM
float
* global_pvals,
GLOBAL_MEM
%(in_ctype)
s
* global_pvals,
const ga_ssize pvals_row_stride,
const ga_ssize pvals_row_stride,
const ga_ssize pvals_col_stride,
const ga_ssize pvals_col_stride,
GLOBAL_MEM
float
* global_unis,
GLOBAL_MEM
%(in_ctype)
s
* global_unis,
const ga_ssize unis_stride,
const ga_ssize unis_stride,
GLOBAL_MEM
%(out_ctype)
s * global_outs,
GLOBAL_MEM
%(out_ctype)
s * global_outs,
const ga_ssize outs_row_stride,
const ga_ssize outs_row_stride,
...
@@ -78,16 +84,15 @@ KERNEL void k_multi_warp_multinomial(
...
@@ -78,16 +84,15 @@ KERNEL void k_multi_warp_multinomial(
int n = LDIM_0*GID_0 + LID_0;
int n = LDIM_0*GID_0 + LID_0;
if (n < nb_multi)
if (n < nb_multi)
{
{
float
cummul = 0.;
%(work_ctype)
s
cummul = 0.;
bool done = false;
bool done = false;
const
float unis_n = global_unis[n*unis_stride]
;
const
%(work_ctype)
s unis_n =
%(load_in_ctype)
s(global_unis[n*unis_stride])
;
for (ga_size m = 0; m < nb_outcomes; ++m)
for (ga_size m = 0; m < nb_outcomes; ++m)
{
{
%(
out
_ctype)
s current_out = 0;
%(
work
_ctype)
s current_out = 0;
if (!done)
if (!done)
{
{
cummul += global_pvals[m * pvals_col_stride +
cummul +=
%(load_in_ctype)
s(global_pvals[m * pvals_col_stride + n * pvals_row_stride]);
n * pvals_row_stride];
if (unis_n < cummul)
if (unis_n < cummul)
{
{
current_out = 1;
current_out = 1;
...
@@ -96,11 +101,12 @@ KERNEL void k_multi_warp_multinomial(
...
@@ -96,11 +101,12 @@ KERNEL void k_multi_warp_multinomial(
}
}
//write out transposed for speed.
//write out transposed for speed.
global_outs[n * outs_col_stride +
global_outs[n * outs_col_stride +
m * outs_row_stride] =
current_out
;
m * outs_row_stride] =
%(write_out_ctype)
s(current_out)
;
}
}
}
}
}
}
"""
%
dict
(
out_ctype
=
pygpu
.
gpuarray
.
dtype_to_ctype
(
node
.
outputs
[
0
]
.
dtype
))
"""
%
dict
(
out_ctype
=
out_ctype
,
write_out_ctype
=
write_out_ctype
,
work_ctype
=
work_ctype
,
in_ctype
=
in_ctype
,
load_in_ctype
=
load_in_ctype
)
return
[
Kernel
(
return
[
Kernel
(
code
=
code
,
name
=
"k_multi_warp_multinomial"
,
code
=
code
,
name
=
"k_multi_warp_multinomial"
,
params
=
[
pygpu
.
gpuarray
.
SIZE
,
params
=
[
pygpu
.
gpuarray
.
SIZE
,
...
@@ -124,6 +130,7 @@ KERNEL void k_multi_warp_multinomial(
...
@@ -124,6 +130,7 @@ KERNEL void k_multi_warp_multinomial(
sync
=
bool
(
config
.
gpuarray
.
sync
)
sync
=
bool
(
config
.
gpuarray
.
sync
)
kname
=
self
.
gpu_kernels
(
node
,
name
)[
0
]
.
objvar
kname
=
self
.
gpu_kernels
(
node
,
name
)[
0
]
.
objvar
out_typecode
=
pygpu
.
gpuarray
.
dtype_to_typecode
(
node
.
outputs
[
0
]
.
dtype
)
out_typecode
=
pygpu
.
gpuarray
.
dtype_to_typecode
(
node
.
outputs
[
0
]
.
dtype
)
in_typecode
=
pygpu
.
gpuarray
.
dtype_to_typecode
(
node
.
inputs
[
0
]
.
dtype
)
s
=
"""
s
=
"""
PyGpuArrayObject * pvals =
%(pvals)
s;
PyGpuArrayObject * pvals =
%(pvals)
s;
PyGpuArrayObject * unis =
%(unis)
s;
PyGpuArrayObject * unis =
%(unis)
s;
...
@@ -187,9 +194,9 @@ KERNEL void k_multi_warp_multinomial(
...
@@ -187,9 +194,9 @@ KERNEL void k_multi_warp_multinomial(
void *args[10];
void *args[10];
ssize_t strides[5] = {
ssize_t strides[5] = {
PyGpuArray_STRIDES(pvals)[0]/
sizeof(float
),
PyGpuArray_STRIDES(pvals)[0]/
gpuarray_get_elsize(
%(in_typecode)
s
),
PyGpuArray_STRIDES(pvals)[1]/
sizeof(float
),
PyGpuArray_STRIDES(pvals)[1]/
gpuarray_get_elsize(
%(in_typecode)
s
),
PyGpuArray_STRIDES(unis)[0]/
sizeof(float
),
PyGpuArray_STRIDES(unis)[0]/
gpuarray_get_elsize(
%(in_typecode)
s
),
PyGpuArray_STRIDES(out)[0]/gpuarray_get_elsize(
%(out_typecode)
s),
PyGpuArray_STRIDES(out)[0]/gpuarray_get_elsize(
%(out_typecode)
s),
PyGpuArray_STRIDES(out)[1]/gpuarray_get_elsize(
%(out_typecode)
s)
PyGpuArray_STRIDES(out)[1]/gpuarray_get_elsize(
%(out_typecode)
s)
};
};
...
@@ -222,7 +229,7 @@ KERNEL void k_multi_warp_multinomial(
...
@@ -222,7 +229,7 @@ KERNEL void k_multi_warp_multinomial(
return
s
return
s
def
c_code_cache_version
(
self
):
def
c_code_cache_version
(
self
):
return
(
2
,)
return
(
3
,)
class
GPUAMultinomialWOReplacementFromUniform
(
GpuKernelBase
,
Op
):
class
GPUAMultinomialWOReplacementFromUniform
(
GpuKernelBase
,
Op
):
...
...
theano/gpuarray/tests/test_multinomial.py
浏览文件 @
fa78b21a
...
@@ -16,15 +16,14 @@ from ..multinomial import (GPUAMultinomialFromUniform,
...
@@ -16,15 +16,14 @@ from ..multinomial import (GPUAMultinomialFromUniform,
GPUAMultinomialWOReplacementFromUniform
)
GPUAMultinomialWOReplacementFromUniform
)
def
test_multinomial_
0
():
def
test_multinomial_
output_dtype
():
# This tests the MultinomialFromUniform Op directly, not going through the
# This tests the MultinomialFromUniform Op directly, not going through the
# multinomial() call in GPU random generation.
# multinomial() call in GPU random generation.
p
=
tensor
.
fmatrix
()
p
=
tensor
.
fmatrix
()
u
=
tensor
.
fvector
()
u
=
tensor
.
fvector
()
for
dtype
in
[
'int64'
,
'float32'
,
'auto'
]:
for
dtype
in
[
'int64'
,
'float32'
,
'float16'
,
'float64'
,
'int32'
,
'auto'
]:
m
=
theano
.
sandbox
.
multinomial
.
MultinomialFromUniform
(
dtype
)(
p
,
u
)
m
=
theano
.
sandbox
.
multinomial
.
MultinomialFromUniform
(
dtype
)(
p
,
u
)
# the m*2 allows the multinomial to reuse output
# the m*2 allows the multinomial to reuse output
...
@@ -52,6 +51,44 @@ def test_multinomial_0():
...
@@ -52,6 +51,44 @@ def test_multinomial_0():
utt
.
assert_allclose
(
r
,
[[
0
,
2
]])
utt
.
assert_allclose
(
r
,
[[
0
,
2
]])
def
test_multinomial_input_dtype
():
# This tests the MultinomialFromUniform Op directly, not going through the
# multinomial() call in GPU random generation.
for
idtype
in
[
'float32'
,
'float16'
,
'float64'
]:
for
odtype
in
[
'float32'
,
'float16'
,
'float64'
,
'int32'
]:
p
=
tensor
.
matrix
(
'p'
,
idtype
)
u
=
tensor
.
vector
(
'u'
,
idtype
)
# p = tensor.dmatrix('p')
# u = tensor.dvector('u')
m
=
theano
.
sandbox
.
multinomial
.
MultinomialFromUniform
(
odtype
)(
p
,
u
)
# the m*2 allows the multinomial to reuse output
f
=
function
([
p
,
u
],
m
*
2
,
allow_input_downcast
=
True
,
mode
=
mode_with_gpu
)
assert
any
([
type
(
node
.
op
)
is
GPUAMultinomialFromUniform
for
node
in
f
.
maker
.
fgraph
.
toposort
()])
# test that both first and second samples can be drawn
utt
.
assert_allclose
(
f
([[
1
,
0
],
[
0
,
1
]],
[
.
1
,
.
1
]),
[[
2
,
0
],
[
0
,
2
]])
# test that both second labels can be drawn
r
=
f
([[
.
2
,
.
8
],
[
.
3
,
.
7
]],
[
.
31
,
.
31
])
utt
.
assert_allclose
(
r
,
[[
0
,
2
],
[
0
,
2
]])
# test that both first labels can be drawn
r
=
f
([[
.
2
,
.
8
],
[
.
3
,
.
7
]],
[
.
21
,
.
21
])
utt
.
assert_allclose
(
r
,
[[
0
,
2
],
[
2
,
0
]])
# change the size to make sure output gets reallocated ok
# and also make sure that the GPU version doesn't screw up the
# transposed-ness
r
=
f
([[
.
2
,
.
8
]],
[
.
25
])
utt
.
assert_allclose
(
r
,
[[
0
,
2
]])
# TODO: check a bigger example (make sure blocking on GPU is handled correctly)
# TODO: check a bigger example (make sure blocking on GPU is handled correctly)
def
test_multinomial_large
():
def
test_multinomial_large
():
# DEBUG_MODE will test this on GPU
# DEBUG_MODE will test this on GPU
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论