Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
37474223
提交
37474223
authored
12月 08, 2022
作者:
Rémi Louf
提交者:
Ricardo Vieira
2月 07, 2023
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Simplify the `IncSubtensor` dispatcher
上级
b3f12b26
隐藏空白字符变更
内嵌
并排
正在显示
2 个修改的文件
包含
29 行增加
和
42 行删除
+29
-42
subtensor.py
pytensor/link/jax/dispatch/subtensor.py
+12
-26
test_subtensor.py
tests/link/jax/test_subtensor.py
+17
-16
没有找到文件。
pytensor/link/jax/dispatch/subtensor.py
浏览文件 @
37474223
import
jax
from
pytensor.link.jax.dispatch.basic
import
jax_funcify
from
pytensor.tensor.subtensor
import
(
AdvancedIncSubtensor
,
...
...
@@ -33,7 +31,7 @@ slice length.
"""
def
assert_indices_jax_compatible
(
node
,
idx_list
):
def
subtensor_
assert_indices_jax_compatible
(
node
,
idx_list
):
from
pytensor.graph.basic
import
Constant
from
pytensor.tensor.var
import
TensorVariable
...
...
@@ -55,7 +53,7 @@ def assert_indices_jax_compatible(node, idx_list):
def
jax_funcify_Subtensor
(
op
,
node
,
**
kwargs
):
idx_list
=
getattr
(
op
,
"idx_list"
,
None
)
assert_indices_jax_compatible
(
node
,
idx_list
)
subtensor_
assert_indices_jax_compatible
(
node
,
idx_list
)
def
subtensor_constant
(
x
,
*
ilists
):
indices
=
indices_from_subtensor
(
ilists
,
idx_list
)
...
...
@@ -69,25 +67,19 @@ def jax_funcify_Subtensor(op, node, **kwargs):
@jax_funcify.register
(
IncSubtensor
)
@jax_funcify.register
(
AdvancedIncSubtensor1
)
def
jax_funcify_IncSubtensor
(
op
,
**
kwargs
):
def
jax_funcify_IncSubtensor
(
op
,
node
,
**
kwargs
):
idx_list
=
getattr
(
op
,
"idx_list"
,
None
)
if
getattr
(
op
,
"set_instead_of_inc"
,
False
):
jax_fn
=
getattr
(
jax
.
ops
,
"index_update"
,
None
)
if
jax_fn
is
None
:
def
jax_fn
(
x
,
indices
,
y
):
return
x
.
at
[
indices
]
.
set
(
y
)
def
jax_fn
(
x
,
indices
,
y
):
return
x
.
at
[
indices
]
.
set
(
y
)
else
:
jax_fn
=
getattr
(
jax
.
ops
,
"index_add"
,
None
)
if
jax_fn
is
None
:
def
jax_fn
(
x
,
indices
,
y
):
return
x
.
at
[
indices
]
.
add
(
y
)
def
jax_fn
(
x
,
indices
,
y
):
return
x
.
at
[
indices
]
.
add
(
y
)
def
incsubtensor
(
x
,
y
,
*
ilist
,
jax_fn
=
jax_fn
,
idx_list
=
idx_list
):
indices
=
indices_from_subtensor
(
ilist
,
idx_list
)
...
...
@@ -100,23 +92,17 @@ def jax_funcify_IncSubtensor(op, **kwargs):
@jax_funcify.register
(
AdvancedIncSubtensor
)
def
jax_funcify_AdvancedIncSubtensor
(
op
,
**
kwargs
):
def
jax_funcify_AdvancedIncSubtensor
(
op
,
node
,
**
kwargs
):
if
getattr
(
op
,
"set_instead_of_inc"
,
False
):
jax_fn
=
getattr
(
jax
.
ops
,
"index_update"
,
None
)
if
jax_fn
is
None
:
def
jax_fn
(
x
,
indices
,
y
):
return
x
.
at
[
indices
]
.
set
(
y
)
def
jax_fn
(
x
,
indices
,
y
):
return
x
.
at
[
indices
]
.
set
(
y
)
else
:
jax_fn
=
getattr
(
jax
.
ops
,
"index_add"
,
None
)
if
jax_fn
is
None
:
def
jax_fn
(
x
,
indices
,
y
):
return
x
.
at
[
indices
]
.
add
(
y
)
def
jax_fn
(
x
,
indices
,
y
):
return
x
.
at
[
indices
]
.
add
(
y
)
def
advancedincsubtensor
(
x
,
y
,
*
ilist
,
jax_fn
=
jax_fn
):
return
jax_fn
(
x
,
ilist
,
y
)
...
...
tests/link/jax/test_subtensor.py
浏览文件 @
37474223
import
numpy
as
np
import
pytest
from
jax._src.errors
import
NonConcreteBooleanIndexError
import
pytensor.tensor
as
at
from
pytensor.configdefaults
import
config
...
...
@@ -179,7 +178,11 @@ def test_jax_IncSubtensor():
compare_jax_and_py
(
out_fg
,
[])
def
test_jax_IncSubtensors_unsupported
():
@pytest.mark.xfail
(
reason
=
"Re-expressible boolean logic. We need a rewrite PyTensor-side to remove the DimShuffle."
)
def
test_jax_IncSubtensor_boolean_mask_reexpressible
():
"""Some boolean logic can be re-expressed and JIT-compiled"""
rng
=
np
.
random
.
default_rng
(
213234
)
x_np
=
rng
.
uniform
(
-
1
,
1
,
size
=
(
3
,
4
,
5
))
.
astype
(
config
.
floatX
)
x_at
=
at
.
constant
(
np
.
arange
(
3
*
4
*
5
)
.
reshape
((
3
,
4
,
5
))
.
astype
(
config
.
floatX
))
...
...
@@ -188,30 +191,28 @@ def test_jax_IncSubtensors_unsupported():
out_at
=
at_subtensor
.
set_subtensor
(
x_at
[
mask_at
],
0.0
)
assert
isinstance
(
out_at
.
owner
.
op
,
at_subtensor
.
AdvancedIncSubtensor
)
out_fg
=
FunctionGraph
([],
[
out_at
])
with
pytest
.
raises
(
NonConcreteBooleanIndexError
,
match
=
"Array boolean indices must be concrete"
):
compare_jax_and_py
(
out_fg
,
[])
compare_jax_and_py
(
out_fg
,
[])
mask_at
=
at
.
as_tensor
_variable
(
x_np
)
>
0
out_at
=
at_subtensor
.
set
_subtensor
(
x_at
[
mask_at
],
1.0
)
mask_at
=
at
.
as_tensor
(
x_np
)
>
0
out_at
=
at_subtensor
.
inc
_subtensor
(
x_at
[
mask_at
],
1.0
)
assert
isinstance
(
out_at
.
owner
.
op
,
at_subtensor
.
AdvancedIncSubtensor
)
out_fg
=
FunctionGraph
([],
[
out_at
])
with
pytest
.
raises
(
NonConcreteBooleanIndexError
,
match
=
"Array boolean indices must be concrete"
):
compare_jax_and_py
(
out_fg
,
[])
compare_jax_and_py
(
out_fg
,
[])
def
test_jax_IncSubtensors_unsupported
():
rng
=
np
.
random
.
default_rng
(
213234
)
x_np
=
rng
.
uniform
(
-
1
,
1
,
size
=
(
3
,
4
,
5
))
.
astype
(
config
.
floatX
)
x_at
=
at
.
constant
(
np
.
arange
(
3
*
4
*
5
)
.
reshape
((
3
,
4
,
5
))
.
astype
(
config
.
floatX
))
st_at
=
at
.
as_tensor_variable
(
x_np
[[
0
,
2
],
0
,
:
3
])
out_at
=
at_subtensor
.
set_subtensor
(
x_at
[[
0
,
2
],
0
,
:
3
],
st_at
)
assert
isinstance
(
out_at
.
owner
.
op
,
at_subtensor
.
AdvancedIncSubtensor
)
out_fg
=
FunctionGraph
([],
[
out_at
])
with
pytest
.
raises
(
IndexError
,
match
=
"Array slice indices must have static"
):
compare_jax_and_py
(
out_fg
,
[])
compare_jax_and_py
(
out_fg
,
[])
st_at
=
at
.
as_tensor_variable
(
x_np
[[
0
,
2
],
0
,
:
3
])
out_at
=
at_subtensor
.
inc_subtensor
(
x_at
[[
0
,
2
],
0
,
:
3
],
st_at
)
assert
isinstance
(
out_at
.
owner
.
op
,
at_subtensor
.
AdvancedIncSubtensor
)
out_fg
=
FunctionGraph
([],
[
out_at
])
with
pytest
.
raises
(
IndexError
,
match
=
"Array slice indices must have static"
):
compare_jax_and_py
(
out_fg
,
[])
compare_jax_and_py
(
out_fg
,
[])
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论