Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
d5122713
Unverified
提交
d5122713
authored
7月 06, 2024
作者:
Pham Nguyen Hung
提交者:
GitHub
7月 06, 2024
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Fix JAX implementation of Argmax (#809)
Co-authored-by:
Ricardo Vieira
<
28983449+ricardoV94@users.noreply.github.com
>
上级
31bf6822
显示空白字符变更
内嵌
并排
正在显示
3 个修改的文件
包含
8 行增加
和
18 行删除
+8
-18
nlinalg.py
pytensor/link/jax/dispatch/nlinalg.py
+5
-6
test_extra_ops.py
tests/link/jax/test_extra_ops.py
+2
-2
test_nlinalg.py
tests/link/jax/test_nlinalg.py
+1
-10
没有找到文件。
pytensor/link/jax/dispatch/nlinalg.py
浏览文件 @
d5122713
import
jax.numpy
as
jnp
import
numpy
as
np
from
pytensor.link.jax.dispatch
import
jax_funcify
from
pytensor.tensor.blas
import
BatchedDot
...
...
@@ -137,12 +138,10 @@ def jax_funcify_Argmax(op, **kwargs):
# NumPy does not support multiple axes for argmax; this is a
# work-around
keep_axes
=
jnp
.
array
(
[
i
for
i
in
range
(
x
.
ndim
)
if
i
not
in
axes
],
dtype
=
"int64"
)
keep_axes
=
np
.
array
([
i
for
i
in
range
(
x
.
ndim
)
if
i
not
in
axes
],
dtype
=
"int64"
)
# Not-reduced axes in front
transposed_x
=
jnp
.
transpose
(
x
,
jnp
.
concatenate
((
keep_axes
,
jnp
.
array
(
axes
,
dtype
=
"int64"
)))
x
,
tuple
(
np
.
concatenate
((
keep_axes
,
np
.
array
(
axes
,
dtype
=
"int64"
)
)))
)
kept_shape
=
transposed_x
.
shape
[:
len
(
keep_axes
)]
reduced_shape
=
transposed_x
.
shape
[
len
(
keep_axes
)
:]
...
...
@@ -151,9 +150,9 @@ def jax_funcify_Argmax(op, **kwargs):
# Otherwise reshape would complain citing float arg
new_shape
=
(
*
kept_shape
,
jnp
.
prod
(
j
np
.
array
(
reduced_shape
,
dtype
=
"int64"
),
dtype
=
"int64"
),
np
.
prod
(
np
.
array
(
reduced_shape
,
dtype
=
"int64"
),
dtype
=
"int64"
),
)
reshaped_x
=
transposed_x
.
reshape
(
new_shape
)
reshaped_x
=
transposed_x
.
reshape
(
tuple
(
new_shape
)
)
max_idx_res
=
jnp
.
argmax
(
reshaped_x
,
axis
=-
1
)
.
astype
(
"int64"
)
...
...
tests/link/jax/test_extra_ops.py
浏览文件 @
d5122713
...
...
@@ -65,9 +65,9 @@ def test_extra_ops():
@pytest.mark.xfail
(
version_parse
(
jax
.
__version__
)
>=
version_parse
(
"0.2.12"
),
reason
=
"
Omnistaging cannot be disabled
"
,
reason
=
"
JAX Numpy API does not support dynamic shapes
"
,
)
def
test_extra_ops_
omni
():
def
test_extra_ops_
dynamic_shapes
():
a
=
matrix
(
"a"
)
a
.
tag
.
test_value
=
np
.
arange
(
6
,
dtype
=
config
.
floatX
)
.
reshape
((
3
,
2
))
...
...
tests/link/jax/test_nlinalg.py
浏览文件 @
d5122713
import
numpy
as
np
import
pytest
from
packaging.version
import
parse
as
version_parse
from
pytensor.compile.function
import
function
from
pytensor.compile.mode
import
Mode
...
...
@@ -80,11 +79,7 @@ def test_jax_basic_multiout():
compare_jax_and_py
(
out_fg
,
[
X
.
astype
(
config
.
floatX
)],
assert_fn
=
assert_fn
)
@pytest.mark.xfail
(
version_parse
(
jax
.
__version__
)
>=
version_parse
(
"0.2.12"
),
reason
=
"Omnistaging cannot be disabled"
,
)
def
test_jax_basic_multiout_omni
():
def
test_jax_max_and_argmax
():
# Test that a single output of a multi-output `Op` can be used as input to
# another `Op`
x
=
dvector
()
...
...
@@ -95,10 +90,6 @@ def test_jax_basic_multiout_omni():
compare_jax_and_py
(
out_fg
,
[
np
.
r_
[
1
,
2
]])
@pytest.mark.xfail
(
version_parse
(
jax
.
__version__
)
>=
version_parse
(
"0.2.12"
),
reason
=
"Omnistaging cannot be disabled"
,
)
def
test_tensor_basics
():
y
=
vector
(
"y"
)
y
.
tag
.
test_value
=
np
.
r_
[
1.0
,
2.0
]
.
astype
(
config
.
floatX
)
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论