Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
c690c8a2
提交
c690c8a2
authored
3月 16, 2011
作者:
James Bergstra
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
sandbox/multinomial - changes,fixes,extensions to CPU and GPU code, and new tests
上级
400dc78b
隐藏空白字符变更
内嵌
并排
正在显示
2 个修改的文件
包含
83 行增加
和
5 行删除
+83
-5
rng_mrg.py
theano/sandbox/rng_mrg.py
+9
-5
test_multinomial.py
theano/sandbox/test_multinomial.py
+74
-0
没有找到文件。
theano/sandbox/rng_mrg.py
浏览文件 @
c690c8a2
...
@@ -727,8 +727,7 @@ class MRG_RandomStreams(object):
...
@@ -727,8 +727,7 @@ class MRG_RandomStreams(object):
else
:
else
:
raise
NotImplementedError
(
"MRG_RandomStreams.binomial with n > 1"
)
raise
NotImplementedError
(
"MRG_RandomStreams.binomial with n > 1"
)
def
multinomial
(
self
,
size
=
None
,
n
=
1
,
pvals
=
None
,
ndim
=
None
,
dtype
=
'int64'
,
def
multinomial
(
self
,
size
=
None
,
n
=
1
,
pvals
=
None
,
ndim
=
None
,
dtype
=
'int64'
):
n_unis
=
None
):
"""
"""
Sample `n` (currently `n` needs to be 1) times from a multinomial
Sample `n` (currently `n` needs to be 1) times from a multinomial
distribution defined by probabilities pvals.
distribution defined by probabilities pvals.
...
@@ -745,10 +744,15 @@ class MRG_RandomStreams(object):
...
@@ -745,10 +744,15 @@ class MRG_RandomStreams(object):
raise
TypeError
(
"You have to specify pvals"
)
raise
TypeError
(
"You have to specify pvals"
)
pvals
=
as_tensor_variable
(
pvals
)
pvals
=
as_tensor_variable
(
pvals
)
if
n
==
1
and
pvals
.
ndim
==
2
:
if
n
==
1
and
pvals
.
ndim
==
2
:
unis
=
self
.
uniform
(
size
=
pvals
.
shape
[
0
:
1
],
ndim
=
1
)
ndim
,
size
,
bcast
=
raw_random
.
_infer_ndim_bcast
(
return
cast
(
multinomial
(
pvals
.
T
,
unis
)
.
T
,
dtype
)
ndim
,
size
,
n
,
pvals
[:,
0
])
bcast
=
bcast
+
(
pvals
.
type
.
broadcastable
[
-
1
],)
unis
=
self
.
uniform
(
size
=
size
,
ndim
=
1
)
op
=
multinomial
.
Multinomial
(
dtype
)
return
op
(
pvals
,
unis
)
else
:
else
:
raise
NotImplementedError
(
"MRG_RandomStreams.multinomial only implemented with n == 1 and pvals.ndim = 2"
)
raise
NotImplementedError
((
"MRG_RandomStreams.multinomial only"
" implemented with n == 1 and pvals.ndim = 2"
))
def
normal
(
self
,
size
=
None
,
avg
=
0.0
,
std
=
1.0
,
ndim
=
None
,
dtype
=
config
.
floatX
):
def
normal
(
self
,
size
=
None
,
avg
=
0.0
,
std
=
1.0
,
ndim
=
None
,
dtype
=
config
.
floatX
):
"""
"""
...
...
theano/sandbox/test_multinomial.py
0 → 100644
浏览文件 @
c690c8a2
import
numpy
from
theano
import
tensor
,
shared
,
function
import
multinomial
def
test_multimomial_0
():
# This tests the multinomial Op directly, not going through the
# multinomial() call in GPU random generation.
p
=
tensor
.
matrix
()
u
=
tensor
.
vector
()
m
=
multinomial
.
Multinomial
(
'auto'
)(
p
,
u
)
#the m*2 allows the multinomial to reuse output
f
=
function
([
p
,
u
],
m
*
2
,
allow_input_downcast
=
True
)
# test that both first and second samples can be drawn
assert
numpy
.
allclose
(
f
([[
1
,
0
],
[
0
,
1
]],
[
.
1
,
.
1
]),
[[
2
,
0
],
[
0
,
2
]])
# test that both second labels can be drawn
r
=
f
([[
.
2
,
.
8
],
[
.
3
,
.
7
]],
[
.
31
,
.
31
])
assert
numpy
.
allclose
(
r
,
[[
0
,
2
],
[
0
,
2
]]),
r
# test that both first labels can be drawn
r
=
f
([[
.
2
,
.
8
],
[
.
3
,
.
7
]],
[
.
21
,
.
21
])
assert
numpy
.
allclose
(
r
,
[[
0
,
2
],
[
2
,
0
]]),
r
#change the size to make sure output gets reallocated ok
# and also make sure that the GPU version doesn't screw up the
# transposed-ness
r
=
f
([[
.
2
,
.
8
]
],
[
.
25
])
assert
numpy
.
allclose
(
r
,
[[
0
,
2
]]),
r
#TODO: check a bigger example (make sure blocking on GPU is handled correctly)
def
test_multinomial_large
():
# DEBUG_MODE will test this on GPU
p
=
tensor
.
fmatrix
()
u
=
tensor
.
fvector
()
m
=
multinomial
.
Multinomial
(
'auto'
)(
p
,
u
)
f
=
function
([
p
,
u
],
m
*
2
,
allow_input_downcast
=
True
)
pval
=
numpy
.
arange
(
10000
*
4
,
dtype
=
'float32'
)
.
reshape
((
10000
,
4
))
+
0.1
pval
=
pval
/
pval
.
sum
(
axis
=
1
)[:,
None
]
uval
=
numpy
.
ones_like
(
pval
[:,
0
])
*
0.5
mval
=
f
(
pval
,
uval
)
assert
mval
.
shape
==
pval
.
shape
assert
mval
.
dtype
==
pval
.
dtype
assert
numpy
.
allclose
(
mval
.
sum
(
axis
=
1
),
2
)
asdf
=
numpy
.
asarray
([
0
,
0
,
2
,
0
])
+
0
*
pval
assert
numpy
.
allclose
(
mval
,
asdf
)
#broadcast over all rows
def
test_multinomial_dtypes
():
p
=
tensor
.
dmatrix
()
u
=
tensor
.
dvector
()
m
=
multinomial
.
Multinomial
(
'auto'
)(
p
,
u
)
assert
m
.
dtype
==
'float64'
,
m
.
dtype
p
=
tensor
.
fmatrix
()
u
=
tensor
.
fvector
()
m
=
multinomial
.
Multinomial
(
'auto'
)(
p
,
u
)
assert
m
.
dtype
==
'float32'
,
m
.
dtype
p
=
tensor
.
fmatrix
()
u
=
tensor
.
fvector
()
m
=
multinomial
.
Multinomial
(
'float64'
)(
p
,
u
)
assert
m
.
dtype
==
'float64'
,
m
.
dtype
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论