Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
8a7356ce
Unverified
提交
8a7356ce
authored
3月 25, 2025
作者:
Etienne Duchesne
提交者:
GitHub
3月 25, 2025
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Implement faster Multinomial JAX dispatch (#1316)
上级
2e9d502f
隐藏空白字符变更
内嵌
并排
正在显示
2 个修改的文件
包含
48 行增加
和
13 行删除
+48
-13
random.py
pytensor/link/jax/dispatch/random.py
+29
-9
test_random.py
tests/link/jax/test_random.py
+19
-4
没有找到文件。
pytensor/link/jax/dispatch/random.py
浏览文件 @
8a7356ce
from
functools
import
singledispatch
from
functools
import
singledispatch
import
jax
import
jax
import
jax.numpy
as
jnp
import
numpy
as
np
import
numpy
as
np
from
numpy.random
import
Generator
from
numpy.random
import
Generator
from
numpy.random.bit_generator
import
(
# type: ignore[attr-defined]
from
numpy.random.bit_generator
import
(
# type: ignore[attr-defined]
...
@@ -394,16 +395,35 @@ def jax_sample_fn_binomial(op, node):
...
@@ -394,16 +395,35 @@ def jax_sample_fn_binomial(op, node):
@jax_sample_fn.register
(
ptr
.
MultinomialRV
)
@jax_sample_fn.register
(
ptr
.
MultinomialRV
)
def
jax_sample_fn_multinomial
(
op
,
node
):
def
jax_sample_fn_multinomial
(
op
,
node
):
if
not
numpyro_available
:
raise
NotImplementedError
(
f
"No JAX implementation for the given distribution: {op.name}. "
"Implementation is available if NumPyro is installed."
)
from
numpyro.distributions.util
import
multinomial
def
sample_fn
(
rng_key
,
size
,
dtype
,
n
,
p
):
def
sample_fn
(
rng_key
,
size
,
dtype
,
n
,
p
):
sample
=
multinomial
(
key
=
rng_key
,
n
=
n
,
p
=
p
,
shape
=
size
)
if
size
is
not
None
:
n
=
jnp
.
broadcast_to
(
n
,
size
)
p
=
jnp
.
broadcast_to
(
p
,
size
+
jnp
.
shape
(
p
)[
-
1
:])
else
:
broadcast_shape
=
jax
.
lax
.
broadcast_shapes
(
jnp
.
shape
(
n
),
jnp
.
shape
(
p
)[:
-
1
])
n
=
jnp
.
broadcast_to
(
n
,
broadcast_shape
)
p
=
jnp
.
broadcast_to
(
p
,
broadcast_shape
+
jnp
.
shape
(
p
)[
-
1
:])
binom_p
=
jnp
.
moveaxis
(
p
,
-
1
,
0
)[:
-
1
,
...
]
sampling_rng
=
jax
.
random
.
split
(
rng_key
,
binom_p
.
shape
[
0
])
def
_binomial_sample_fn
(
carry
,
p_rng
):
s
,
rho
=
carry
p
,
rng
=
p_rng
samples
=
jax
.
random
.
binomial
(
rng
,
s
,
p
/
rho
)
s
=
s
-
samples
rho
=
rho
-
p
return
((
s
,
rho
),
samples
)
(
remain
,
_
),
samples
=
jax
.
lax
.
scan
(
_binomial_sample_fn
,
(
n
.
astype
(
np
.
float64
),
jnp
.
ones
(
binom_p
.
shape
[
1
:])),
(
binom_p
,
sampling_rng
),
)
sample
=
jnp
.
concatenate
(
[
jnp
.
moveaxis
(
samples
,
0
,
-
1
),
jnp
.
expand_dims
(
remain
,
-
1
)],
axis
=-
1
)
return
sample
return
sample
return
sample_fn
return
sample_fn
...
...
tests/link/jax/test_random.py
浏览文件 @
8a7356ce
...
@@ -703,14 +703,15 @@ def test_beta_binomial():
...
@@ -703,14 +703,15 @@ def test_beta_binomial():
)
)
@pytest.mark.skipif
(
not
numpyro_available
,
reason
=
"Multinomial dispatch requires numpyro"
)
def
test_multinomial
():
def
test_multinomial
():
rng
=
shared
(
np
.
random
.
default_rng
(
123
))
rng
=
shared
(
np
.
random
.
default_rng
(
123
))
# test with 'size' argument and n.shape == p.shape[:-1]
n
=
np
.
array
([
10
,
40
])
n
=
np
.
array
([
10
,
40
])
p
=
np
.
array
([[
0.3
,
0.7
,
0.0
],
[
0.1
,
0.4
,
0.5
]])
p
=
np
.
array
([[
0.3
,
0.7
,
0.0
],
[
0.1
,
0.4
,
0.5
]])
g
=
pt
.
random
.
multinomial
(
n
,
p
,
size
=
(
10
_000
,
2
),
rng
=
rng
)
size
=
(
10
_000
,
2
)
g
=
pt
.
random
.
multinomial
(
n
,
p
,
size
=
size
,
rng
=
rng
)
g_fn
=
compile_random_function
([],
g
,
mode
=
"JAX"
)
g_fn
=
compile_random_function
([],
g
,
mode
=
"JAX"
)
samples
=
g_fn
()
samples
=
g_fn
()
np
.
testing
.
assert_allclose
(
samples
.
mean
(
axis
=
0
),
n
[
...
,
None
]
*
p
,
rtol
=
0.1
)
np
.
testing
.
assert_allclose
(
samples
.
mean
(
axis
=
0
),
n
[
...
,
None
]
*
p
,
rtol
=
0.1
)
...
@@ -718,6 +719,20 @@ def test_multinomial():
...
@@ -718,6 +719,20 @@ def test_multinomial():
samples
.
std
(
axis
=
0
),
np
.
sqrt
(
n
[
...
,
None
]
*
p
*
(
1
-
p
)),
rtol
=
0.1
samples
.
std
(
axis
=
0
),
np
.
sqrt
(
n
[
...
,
None
]
*
p
*
(
1
-
p
)),
rtol
=
0.1
)
)
# test with no 'size' argument and no static shape
n
=
np
.
broadcast_to
(
np
.
array
([
10
,
40
]),
size
)
p
=
np
.
array
([[
0.3
,
0.7
,
0.0
],
[
0.1
,
0.4
,
0.5
]])
pt_n
=
pt
.
matrix
(
"n"
)
pt_p
=
pt
.
matrix
(
"p"
)
g
=
pt
.
random
.
multinomial
(
pt_n
,
pt_p
,
rng
=
rng
,
size
=
None
)
g_fn
=
compile_random_function
([
pt_n
,
pt_p
],
g
,
mode
=
"JAX"
)
samples
=
g_fn
(
n
,
p
)
np
.
testing
.
assert_allclose
(
samples
.
mean
(
axis
=
0
),
n
[
0
,
:,
None
]
*
p
,
rtol
=
0.1
)
np
.
testing
.
assert_allclose
(
samples
.
std
(
axis
=
0
),
np
.
sqrt
(
n
[
0
,
:,
None
]
*
p
*
(
1
-
p
)),
rtol
=
0.1
)
@pytest.mark.skipif
(
not
numpyro_available
,
reason
=
"VonMises dispatch requires numpyro"
)
@pytest.mark.skipif
(
not
numpyro_available
,
reason
=
"VonMises dispatch requires numpyro"
)
def
test_vonmises_mu_outside_circle
():
def
test_vonmises_mu_outside_circle
():
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论