Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
613ccaf3
提交
613ccaf3
authored
11月 09, 2023
作者:
Ricardo Vieira
提交者:
Ricardo Vieira
11月 10, 2023
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Support Blockwise in JAX backend
上级
a5d54c8e
隐藏空白字符变更
内嵌
并排
正在显示
3 个修改的文件
包含
72 行增加
和
0 行删除
+72
-0
__init__.py
pytensor/link/jax/dispatch/__init__.py
+1
-0
blockwise.py
pytensor/link/jax/dispatch/blockwise.py
+29
-0
test_blockwise.py
tests/link/jax/test_blockwise.py
+42
-0
没有找到文件。
pytensor/link/jax/dispatch/__init__.py
浏览文件 @
613ccaf3
...
@@ -13,5 +13,6 @@ import pytensor.link.jax.dispatch.random
...
@@ -13,5 +13,6 @@ import pytensor.link.jax.dispatch.random
import
pytensor.link.jax.dispatch.elemwise
import
pytensor.link.jax.dispatch.elemwise
import
pytensor.link.jax.dispatch.scan
import
pytensor.link.jax.dispatch.scan
import
pytensor.link.jax.dispatch.sparse
import
pytensor.link.jax.dispatch.sparse
import
pytensor.link.jax.dispatch.blockwise
# isort: on
# isort: on
pytensor/link/jax/dispatch/blockwise.py
0 → 100644
浏览文件 @
613ccaf3
import
jax.numpy
as
jnp
from
pytensor.graph
import
FunctionGraph
from
pytensor.link.jax.dispatch
import
jax_funcify
from
pytensor.tensor.blockwise
import
Blockwise
@jax_funcify.register
(
Blockwise
)
def
funcify_Blockwise
(
op
:
Blockwise
,
node
,
*
args
,
**
kwargs
):
signature
=
op
.
signature
core_node
=
op
.
_create_dummy_core_node
(
node
.
inputs
)
core_fgraph
=
FunctionGraph
(
inputs
=
core_node
.
inputs
,
outputs
=
core_node
.
outputs
)
tuple_core_fn
=
jax_funcify
(
core_fgraph
)
if
len
(
node
.
outputs
)
==
1
:
def
core_fn
(
*
inputs
):
return
tuple_core_fn
(
*
inputs
)[
0
]
else
:
core_fn
=
tuple_core_fn
vect_fn
=
jnp
.
vectorize
(
core_fn
,
signature
=
signature
)
def
blockwise_fn
(
*
inputs
):
op
.
_check_runtime_broadcast
(
node
,
inputs
)
return
vect_fn
(
*
inputs
)
return
blockwise_fn
tests/link/jax/test_blockwise.py
0 → 100644
浏览文件 @
613ccaf3
import
numpy
as
np
import
pytest
from
pytensor
import
config
from
pytensor.graph
import
FunctionGraph
from
pytensor.tensor
import
tensor
from
pytensor.tensor.blockwise
import
Blockwise
from
pytensor.tensor.math
import
Dot
,
matmul
from
tests.link.jax.test_basic
import
compare_jax_and_py
from
tests.tensor.test_blockwise
import
check_blockwise_runtime_broadcasting
jax
=
pytest
.
importorskip
(
"jax"
)
def
test_runtime_broadcasting
():
check_blockwise_runtime_broadcasting
(
"JAX"
)
# Equivalent blockwise to matmul but with dumb signature
odd_matmul
=
Blockwise
(
Dot
(),
signature
=
"(i00,i01),(i10,i11)->(o00,o01)"
)
@pytest.mark.parametrize
(
"matmul_op"
,
(
matmul
,
odd_matmul
))
def
test_matmul
(
matmul_op
):
rng
=
np
.
random
.
default_rng
(
14
)
a
=
tensor
(
"a"
,
shape
=
(
2
,
3
,
5
))
b
=
tensor
(
"b"
,
shape
=
(
2
,
5
,
3
))
test_values
=
[
rng
.
normal
(
size
=
(
inp
.
type
.
shape
))
.
astype
(
config
.
floatX
)
for
inp
in
(
a
,
b
)
]
out
=
matmul_op
(
a
,
b
)
assert
isinstance
(
out
.
owner
.
op
,
Blockwise
)
fg
=
FunctionGraph
([
a
,
b
],
[
out
])
fn
,
_
=
compare_jax_and_py
(
fg
,
test_values
)
# Check we are not adding any unnecessary stuff
jaxpr
=
str
(
jax
.
make_jaxpr
(
fn
.
vm
.
jit_fn
)(
*
test_values
))
jaxpr
=
jaxpr
.
replace
(
"name=jax_funcified_fgraph"
,
"name=matmul"
)
expected_jaxpr
=
str
(
jax
.
make_jaxpr
(
jax
.
jit
(
jax
.
numpy
.
matmul
))(
*
test_values
))
assert
jaxpr
==
expected_jaxpr
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论