Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
0087e562
提交
0087e562
authored
12月 12, 2022
作者:
Rémi Louf
提交者:
Ricardo Vieira
2月 07, 2023
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Add rewrites to re-express boolean indexing logic
上级
37474223
显示空白字符变更
内嵌
并排
正在显示
3 个修改的文件
包含
119 行增加
和
32 行删除
+119
-32
__init__.py
pytensor/tensor/rewriting/__init__.py
+3
-0
jax.py
pytensor/tensor/rewriting/jax.py
+78
-0
test_subtensor.py
tests/link/jax/test_subtensor.py
+38
-32
没有找到文件。
pytensor/tensor/rewriting/__init__.py
浏览文件 @
0087e562
import
pytensor.tensor.rewriting.basic
import
pytensor.tensor.rewriting.elemwise
import
pytensor.tensor.rewriting.extra_ops
# Register JAX specializations
import
pytensor.tensor.rewriting.jax
import
pytensor.tensor.rewriting.math
import
pytensor.tensor.rewriting.shape
import
pytensor.tensor.rewriting.special
...
...
pytensor/tensor/rewriting/jax.py
0 → 100644
浏览文件 @
0087e562
from
pytensor.compile
import
optdb
from
pytensor.graph.rewriting.basic
import
in2out
,
node_rewriter
from
pytensor.tensor.var
import
TensorVariable
import
pytensor.tensor
as
at
from
pytensor.tensor.subtensor
import
AdvancedIncSubtensor
,
AdvancedSubtensor
from
pytensor.tensor.math
import
Sum
@node_rewriter([AdvancedIncSubtensor])
def boolean_indexing_set_or_inc(fgraph, node):
    """Re-express `AdvancedIncSubtensor` with boolean indexing as a `Switch`.

    JAX cannot JIT-compile functions that use boolean indexing to set values
    in an array. A workaround is to re-express this logic using
    `jax.numpy.where`. This rewrite allows to improve upon JAX's API.

    Returns the replacement outputs, or ``None`` when the node does not use
    pure boolean indexing and must be left untouched.
    """
    op = node.op
    # `AdvancedIncSubtensor` inputs are `(x, y, *indices)`.
    x, y, *indices = node.inputs

    # Only rewrite pure boolean indexing: exactly one index, and it must be a
    # boolean tensor. With any additional indices the `where` expression
    # below would not select the same elements as the subtensor op.
    if len(indices) != 1:
        return None
    cond = indices[0]
    if not isinstance(cond, TensorVariable):
        return None
    if cond.type.dtype != "bool":
        return None

    # NOTE(review): `where` pairs each masked position with the matching
    # element of `y`, while `x[cond] = y` fills positions from a flat `y`.
    # The two agree when `y` broadcasts elementwise against `x` (e.g. a
    # scalar) — confirm callers never pass a flat value vector.
    if op.set_instead_of_inc:
        out = at.where(cond, y, x)
    else:
        out = at.where(cond, x + y, x)
    return out.owner.outputs
# Register the rewrite under the "jax" tag so it is applied when that tag is
# requested; NOTE(review): position=100 places it in optdb's ordering —
# confirm against the other optdb registrations.
optdb.register(
    "jax_boolean_indexing_set_or_inc",
    in2out(boolean_indexing_set_or_inc),
    "jax",
    position=100,
)
@node_rewriter([Sum])
def boolean_indexing_sum(fgraph, node):
    """Replace the sum of an `AdvancedSubtensor` with boolean indexing.

    JAX cannot JIT-compile functions that use boolean indexing, but can
    compile those expressions that can be re-expressed using
    `jax.numpy.where`. This rewrite re-expresses the model on behalf of the
    user and thus allows to improve upon JAX's API.

    Returns the replacement outputs, or ``None`` when the summed operand is
    not a pure boolean subtensor and must be left untouched.
    """
    operand = node.inputs[0]
    if not isinstance(operand, TensorVariable):
        return None
    if operand.owner is None:
        return None
    if not isinstance(operand.owner.op, AdvancedSubtensor):
        return None

    # Only rewrite pure boolean indexing: the subtensor's inputs must be
    # exactly `(x, mask)`. With extra indices `where(cond, x, 0)` would not
    # select the same elements as the subtensor being summed.
    if len(operand.owner.inputs) != 2:
        return None

    x = operand.owner.inputs[0]
    cond = operand.owner.inputs[1]
    if not isinstance(cond, TensorVariable):
        return None
    if cond.type.dtype != "bool":
        return None

    # Boolean indexing consumes dimensions from the left, while `where`
    # broadcasts `cond` against `x` from the right; require matching ranks
    # so both operations select the same elements.
    if cond.type.ndim != x.type.ndim:
        return None

    # `x[cond]` flattens the masked values, so the replacement is only valid
    # for a full (scalar) reduction, not an axis-wise one.
    if node.outputs[0].type.ndim != 0:
        return None

    out = at.sum(at.where(cond, x, 0))
    return out.owner.outputs
# Register the rewrite under the "jax" tag so it is applied when that tag is
# requested; NOTE(review): position=100 places it in optdb's ordering —
# confirm against the other optdb registrations.
optdb.register(
    "jax_boolean_indexing_sum",
    in2out(boolean_indexing_sum),
    "jax",
    position=100,
)
tests/link/jax/test_subtensor.py
浏览文件 @
0087e562
...
...
@@ -80,15 +80,21 @@ def test_jax_Subtensor_boolean_mask():
compare_jax_and_py
(
out_fg
,
[])
def test_jax_Subtensor_boolean_mask_reexpressible():
    """Summing values with boolean indexing.

    This test ensures that the sum of an `AdvancedSubtensor` `Op`s with
    boolean indexing is replaced with the sum of an equivalent `Switch`
    `Op`, using the `jax_boolean_indexing_sum` rewrite.

    JAX forces users to re-express this logic manually, so this is an
    improvement over its user interface.
    """
    x_at = at.vector("x")
    out_at = x_at[x_at < 0].sum()
    out_fg = FunctionGraph([x_at], [out_at])
    compare_jax_and_py(out_fg, [np.arange(-5, 5).astype(config.floatX)])
def
test_jax_IncSubtensor
():
...
...
@@ -177,42 +183,42 @@ def test_jax_IncSubtensor():
out_fg
=
FunctionGraph
([],
[
out_at
])
compare_jax_and_py
(
out_fg
,
[])
@pytest.mark.xfail
(
reason
=
"Re-expressible boolean logic. We need a rewrite PyTensor-side to remove the DimShuffle."
)
def
test_jax_IncSubtensor_boolean_mask_reexpressible
():
"""Some boolean logic can be re-expressed and JIT-compiled"""
rng
=
np
.
random
.
default_rng
(
213234
)
x_np
=
rng
.
uniform
(
-
1
,
1
,
size
=
(
3
,
4
,
5
))
.
astype
(
config
.
floatX
)
x_at
=
at
.
constant
(
np
.
arange
(
3
*
4
*
5
)
.
reshape
((
3
,
4
,
5
))
.
astype
(
config
.
floatX
))
mask_at
=
at
.
as_tensor
(
x_np
)
>
0
out_at
=
at_subtensor
.
set_subtensor
(
x_at
[
mask_at
],
0.0
)
st_at
=
at
.
as_tensor_variable
(
x_np
[[
0
,
2
],
0
,
:
3
])
out_at
=
at_subtensor
.
set_subtensor
(
x_at
[[
0
,
2
],
0
,
:
3
],
st_at
)
assert
isinstance
(
out_at
.
owner
.
op
,
at_subtensor
.
AdvancedIncSubtensor
)
out_fg
=
FunctionGraph
([],
[
out_at
])
compare_jax_and_py
(
out_fg
,
[])
mask_at
=
at
.
as_tensor
(
x_np
)
>
0
out_at
=
at_subtensor
.
inc_subtensor
(
x_at
[
mask_at
],
1.0
)
st_at
=
at
.
as_tensor_variable
(
x_np
[[
0
,
2
],
0
,
:
3
])
out_at
=
at_subtensor
.
inc_subtensor
(
x_at
[
[
0
,
2
],
0
,
:
3
],
st_at
)
assert
isinstance
(
out_at
.
owner
.
op
,
at_subtensor
.
AdvancedIncSubtensor
)
out_fg
=
FunctionGraph
([],
[
out_at
])
compare_jax_and_py
(
out_fg
,
[])
def
test_jax_IncSubtensors_unsupported
():
def test_jax_IncSubtensor_boolean_indexing_reexpressible():
    """Setting or incrementing values with boolean indexing.

    This test ensures that `AdvancedIncSubtensor` `Op`s with boolean
    indexing are replaced with an equivalent `Switch` `Op`, using the
    `jax_boolean_indexing_set_or_inc` rewrite.

    JAX forces users to re-express this logic manually, so this is an
    improvement over its user interface.
    """
    rng = np.random.default_rng(213234)
    x_np = rng.uniform(-1, 1, size=(4, 5)).astype(config.floatX)

    x_at = at.matrix("x")

    # `set_subtensor` with a boolean mask must build an `AdvancedIncSubtensor`.
    mask_at = at.as_tensor(x_at) > 0
    out_at = at_subtensor.set_subtensor(x_at[mask_at], 0.0)
    assert isinstance(out_at.owner.op, at_subtensor.AdvancedIncSubtensor)
    out_fg = FunctionGraph([x_at], [out_at])
    compare_jax_and_py(out_fg, [x_np])

    # Same check for the incrementing variant.
    mask_at = at.as_tensor(x_at) > 0
    out_at = at_subtensor.inc_subtensor(x_at[mask_at], 1.0)
    assert isinstance(out_at.owner.op, at_subtensor.AdvancedIncSubtensor)
    out_fg = FunctionGraph([x_at], [out_at])
    compare_jax_and_py(out_fg, [x_np])
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论