Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
92eef5ed
提交
92eef5ed
authored
10月 11, 2024
作者:
Ricardo Vieira
提交者:
Ricardo Vieira
5月 30, 2025
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Allow running JAX functions with scalar inputs for RV shapes
上级
4cdd2905
显示空白字符变更
内嵌
并排
正在显示
2 个修改的文件
包含
87 行增加
和
10 行删除
+87
-10
linker.py
pytensor/link/jax/linker.py
+38
-1
test_random.py
tests/link/jax/test_random.py
+49
-9
没有找到文件。
pytensor/link/jax/linker.py
浏览文件 @
92eef5ed
...
...
@@ -9,8 +9,13 @@ from pytensor.link.basic import JITLinker
class
JAXLinker
(
JITLinker
):
"""A `Linker` that JIT-compiles NumPy-based operations using JAX."""
def
__init__
(
self
,
*
args
,
**
kwargs
):
self
.
scalar_shape_inputs
:
tuple
[
int
]
=
()
# type: ignore[annotation-unchecked]
super
()
.
__init__
(
*
args
,
**
kwargs
)
def
fgraph_convert
(
self
,
fgraph
,
input_storage
,
storage_map
,
**
kwargs
):
from
pytensor.link.jax.dispatch
import
jax_funcify
from
pytensor.link.jax.dispatch.shape
import
JAXShapeTuple
from
pytensor.tensor.random.type
import
RandomType
shared_rng_inputs
=
[
...
...
@@ -64,6 +69,23 @@ class JAXLinker(JITLinker):
fgraph
.
inputs
.
remove
(
new_inp
)
fgraph
.
inputs
.
insert
(
old_inp_fgrap_index
,
new_inp
)
fgraph_inputs
=
fgraph
.
inputs
clients
=
fgraph
.
clients
# Detect scalar shape inputs that are used only in JAXShapeTuple nodes
scalar_shape_inputs
=
[
inp
for
node
in
fgraph
.
apply_nodes
if
isinstance
(
node
.
op
,
JAXShapeTuple
)
for
inp
in
node
.
inputs
if
inp
in
fgraph_inputs
and
all
(
isinstance
(
cl_node
.
op
,
JAXShapeTuple
)
for
cl_node
,
_
in
clients
[
inp
]
)
]
self
.
scalar_shape_inputs
=
tuple
(
fgraph_inputs
.
index
(
inp
)
for
inp
in
scalar_shape_inputs
)
return
jax_funcify
(
fgraph
,
input_storage
=
input_storage
,
storage_map
=
storage_map
,
**
kwargs
)
...
...
@@ -71,7 +93,22 @@ class JAXLinker(JITLinker):
def
jit_compile
(
self
,
fn
):
import
jax
return
jax
.
jit
(
fn
)
jit_fn
=
jax
.
jit
(
fn
,
static_argnums
=
self
.
scalar_shape_inputs
)
if
not
self
.
scalar_shape_inputs
:
return
jit_fn
def
convert_scalar_shape_inputs
(
*
args
,
scalar_shape_inputs
=
set
(
self
.
scalar_shape_inputs
)
):
return
jit_fn
(
*
(
int
(
arg
)
if
i
in
scalar_shape_inputs
else
arg
for
i
,
arg
in
enumerate
(
args
)
)
)
return
convert_scalar_shape_inputs
def
create_thunk_inputs
(
self
,
storage_map
):
from
pytensor.link.jax.dispatch
import
jax_typify
...
...
tests/link/jax/test_random.py
浏览文件 @
92eef5ed
...
...
@@ -894,15 +894,55 @@ class TestRandomShapeInputs:
jax_fn
=
compile_random_function
([
x_pt
],
out
)
assert
jax_fn
(
np
.
ones
((
2
,
3
)))
.
shape
==
(
2
,)
def
test_random_scalar_shape_input
(
self
):
dim0
=
pt
.
scalar
(
"dim0"
,
dtype
=
int
)
dim1
=
pt
.
scalar
(
"dim1"
,
dtype
=
int
)
out
=
pt
.
random
.
normal
(
0
,
1
,
size
=
dim0
)
jax_fn
=
compile_random_function
([
dim0
],
out
)
assert
jax_fn
(
np
.
array
(
2
))
.
shape
==
(
2
,)
assert
jax_fn
(
np
.
array
(
3
))
.
shape
==
(
3
,)
out
=
pt
.
random
.
normal
(
0
,
1
,
size
=
[
dim0
,
dim1
])
jax_fn
=
compile_random_function
([
dim0
,
dim1
],
out
)
assert
jax_fn
(
np
.
array
(
2
),
np
.
array
(
3
))
.
shape
==
(
2
,
3
)
assert
jax_fn
(
np
.
array
(
4
),
np
.
array
(
5
))
.
shape
==
(
4
,
5
)
@pytest.mark.xfail
(
r
eason
=
"`size_pt` should be specified as a static argument"
,
strict
=
True
r
aises
=
TypeError
,
reason
=
"Cannot convert scalar input to integer"
)
def
test_random_concrete_shape_graph_input
(
self
):
rng
=
shared
(
np
.
random
.
default_rng
(
123
))
size_pt
=
pt
.
scalar
()
out
=
pt
.
random
.
normal
(
0
,
1
,
size
=
size_pt
,
rng
=
rng
)
jax_fn
=
compile_random_function
([
size_pt
],
out
)
assert
jax_fn
(
10
)
.
shape
==
(
10
,)
def
test_random_scalar_shape_input_not_supported
(
self
):
dim
=
pt
.
scalar
(
"dim"
,
dtype
=
int
)
out1
=
pt
.
random
.
normal
(
0
,
1
,
size
=
dim
)
# An operation that wouldn't work if we replaced 0d array by integer
out2
=
dim
[
...
]
.
set
(
1
)
jax_fn
=
compile_random_function
([
dim
],
[
out1
,
out2
])
res1
,
res2
=
jax_fn
(
np
.
array
(
2
))
assert
res1
.
shape
==
(
2
,)
assert
res2
==
1
@pytest.mark.xfail
(
raises
=
TypeError
,
reason
=
"Cannot convert scalar input to integer"
)
def
test_random_scalar_shape_input_not_supported2
(
self
):
dim
=
pt
.
scalar
(
"dim"
,
dtype
=
int
)
# This could theoretically be supported
# but would require knowing that * 2 is a safe operation for a python integer
out
=
pt
.
random
.
normal
(
0
,
1
,
size
=
dim
*
2
)
jax_fn
=
compile_random_function
([
dim
],
out
)
assert
jax_fn
(
np
.
array
(
2
))
.
shape
==
(
4
,)
@pytest.mark.xfail
(
raises
=
TypeError
,
reason
=
"Cannot convert tensor input to shape tuple"
)
def
test_random_vector_shape_graph_input
(
self
):
shape
=
pt
.
vector
(
"shape"
,
shape
=
(
2
,),
dtype
=
int
)
out
=
pt
.
random
.
normal
(
0
,
1
,
size
=
shape
)
jax_fn
=
compile_random_function
([
shape
],
out
)
assert
jax_fn
(
np
.
array
([
2
,
3
]))
.
shape
==
(
2
,
3
)
assert
jax_fn
(
np
.
array
([
4
,
5
]))
.
shape
==
(
4
,
5
)
def
test_constant_shape_after_graph_rewriting
(
self
):
size
=
pt
.
vector
(
"size"
,
shape
=
(
2
,),
dtype
=
int
)
...
...
@@ -912,13 +952,13 @@ class TestRandomShapeInputs:
with
pytest
.
raises
(
TypeError
):
compile_random_function
([
size
],
x
)([
2
,
5
])
# Rebuild with strict=
Fals
e so output type is not updated
# Rebuild with strict=
Tru
e so output type is not updated
# This reflects cases where size is constant folded during rewrites but the RV node is not recreated
new_x
=
clone_replace
(
x
,
{
size
:
pt
.
constant
([
2
,
5
])},
rebuild_strict
=
True
)
assert
new_x
.
type
.
shape
==
(
None
,
None
)
assert
compile_random_function
([],
new_x
)()
.
shape
==
(
2
,
5
)
# Rebuild with strict=
Tru
e, so output type is updated
# Rebuild with strict=
Fals
e, so output type is updated
# This uses a different path in the dispatch implementation
new_x
=
clone_replace
(
x
,
{
size
:
pt
.
constant
([
2
,
5
])},
rebuild_strict
=
False
)
assert
new_x
.
type
.
shape
==
(
2
,
5
)
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论