Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
bb028ae2
提交
bb028ae2
authored
5月 17, 2024
作者:
Ricardo Vieira
提交者:
Ricardo Vieira
5月 22, 2024
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Inline static size inputs in JAX implementation of RandomVariables
This gets around some limitations in JAX jitting system
上级
863efc01
隐藏空白字符变更
内嵌
并排
正在显示
2 个修改的文件
包含
42 行增加
和
6 行删除
+42
-6
random.py
pytensor/link/jax/dispatch/random.py
+18
-4
test_random.py
tests/link/jax/test_random.py
+24
-2
没有找到文件。
pytensor/link/jax/dispatch/random.py
浏览文件 @
bb028ae2
...
...
@@ -8,6 +8,7 @@ from numpy.random.bit_generator import ( # type: ignore[attr-defined]
)
import
pytensor.tensor.random.basic
as
ptr
from
pytensor.graph
import
Constant
from
pytensor.link.jax.dispatch.basic
import
jax_funcify
,
jax_typify
from
pytensor.link.jax.dispatch.shape
import
JAXShapeTuple
from
pytensor.tensor.shape
import
Shape
,
Shape_i
...
...
@@ -91,15 +92,26 @@ def jax_funcify_RandomVariable(op, node, **kwargs):
"""JAX implementation of random variables."""
rv
=
node
.
outputs
[
1
]
out_dtype
=
rv
.
type
.
dtype
out_siz
e
=
rv
.
type
.
shape
static_shap
e
=
rv
.
type
.
shape
batch_ndim
=
op
.
batch_ndim
(
node
)
out_size
=
node
.
default_output
()
.
type
.
shape
[:
batch_ndim
]
# Try to pass static size directly to JAX
static_size
=
static_shape
[:
batch_ndim
]
if
None
in
static_size
:
# Sometimes size can be constant folded during rewrites,
# without the RandomVariable node being updated with new static types
size_param
=
node
.
inputs
[
1
]
if
isinstance
(
size_param
,
Constant
):
size_tuple
=
tuple
(
size_param
.
data
)
# PyTensor uses empty size to represent size = None
if
len
(
size_tuple
):
static_size
=
tuple
(
size_param
.
data
)
# If one dimension has unknown size, either the size is determined
# by a `Shape` operator in which case JAX will compile, or it is
# not and we fail gracefully.
if
None
in
out
_size
:
if
None
in
static
_size
:
assert_size_argument_jax_compatible
(
node
)
def
sample_fn
(
rng
,
size
,
dtype
,
*
parameters
):
...
...
@@ -111,7 +123,9 @@ def jax_funcify_RandomVariable(op, node, **kwargs):
else
:
def
sample_fn
(
rng
,
size
,
dtype
,
*
parameters
):
return
jax_sample_fn
(
op
,
node
=
node
)(
rng
,
out_size
,
out_dtype
,
*
parameters
)
return
jax_sample_fn
(
op
,
node
=
node
)(
rng
,
static_size
,
out_dtype
,
*
parameters
)
return
sample_fn
...
...
tests/link/jax/test_random.py
浏览文件 @
bb028ae2
...
...
@@ -5,6 +5,7 @@ import scipy.stats as stats
import
pytensor
import
pytensor.tensor
as
pt
import
pytensor.tensor.random.basic
as
ptr
from
pytensor
import
clone_replace
from
pytensor.compile.function
import
function
from
pytensor.compile.sharedvalue
import
SharedVariable
,
shared
from
pytensor.graph.basic
import
Constant
...
...
@@ -26,11 +27,11 @@ jax = pytest.importorskip("jax")
from
pytensor.link.jax.dispatch.random
import
numpyro_available
# noqa: E402
def
compile_random_function
(
*
args
,
**
kwargs
):
def
compile_random_function
(
*
args
,
mode
=
"JAX"
,
**
kwargs
):
with
pytest
.
warns
(
UserWarning
,
match
=
r"The RandomType SharedVariables \[.+\] will not be used"
):
return
function
(
*
args
,
**
kwargs
)
return
function
(
*
args
,
mode
=
mode
,
**
kwargs
)
def
test_random_RandomStream
():
...
...
@@ -896,3 +897,24 @@ def test_random_concrete_shape_graph_input():
out
=
pt
.
random
.
normal
(
0
,
1
,
size
=
size_pt
,
rng
=
rng
)
jax_fn
=
compile_random_function
([
size_pt
],
out
,
mode
=
jax_mode
)
assert
jax_fn
(
10
)
.
shape
==
(
10
,)
def
test_constant_shape_after_graph_rewriting
():
size
=
pt
.
vector
(
"size"
,
shape
=
(
2
,),
dtype
=
int
)
x
=
pt
.
random
.
normal
(
size
=
size
)
assert
x
.
type
.
shape
==
(
None
,
None
)
with
pytest
.
raises
(
TypeError
):
compile_random_function
([
size
],
x
)([
2
,
5
])
# Rebuild with strict=False so output type is not updated
# This reflects cases where size is constant folded during rewrites but the RV node is not recreated
new_x
=
clone_replace
(
x
,
{
size
:
pt
.
constant
([
2
,
5
])},
rebuild_strict
=
True
)
assert
new_x
.
type
.
shape
==
(
None
,
None
)
assert
compile_random_function
([],
new_x
)()
.
shape
==
(
2
,
5
)
# Rebuild with strict=True, so output type is updated
# This uses a different path in the dispatch implementation
new_x
=
clone_replace
(
x
,
{
size
:
pt
.
constant
([
2
,
5
])},
rebuild_strict
=
False
)
assert
new_x
.
type
.
shape
==
(
2
,
5
)
assert
compile_random_function
([],
new_x
)()
.
shape
==
(
2
,
5
)
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论