Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
b9792d8a
Unverified
提交
b9792d8a
authored
5月 18, 2023
作者:
jessegrabowski
提交者:
GitHub
5月 18, 2023
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Add JAX support for `pt.tri` (#302)
上级
6b43b433
隐藏空白字符变更
内嵌
并排
正在显示
2 个修改的文件
包含
43 行增加
和
1 行删除
+43
-1
tensor_basic.py
pytensor/link/jax/dispatch/tensor_basic.py
+16
-1
test_tensor_basic.py
tests/link/jax/test_tensor_basic.py
+27
-0
没有找到文件。
pytensor/link/jax/dispatch/tensor_basic.py
浏览文件 @
b9792d8a
...
@@ -18,6 +18,7 @@ from pytensor.tensor.basic import (
...
@@ -18,6 +18,7 @@ from pytensor.tensor.basic import (
ScalarFromTensor
,
ScalarFromTensor
,
Split
,
Split
,
TensorFromScalar
,
TensorFromScalar
,
Tri
,
get_underlying_scalar_constant_value
,
get_underlying_scalar_constant_value
,
)
)
from
pytensor.tensor.exceptions
import
NotScalarConstantError
from
pytensor.tensor.exceptions
import
NotScalarConstantError
...
@@ -26,7 +27,6 @@ from pytensor.tensor.exceptions import NotScalarConstantError
...
@@ -26,7 +27,6 @@ from pytensor.tensor.exceptions import NotScalarConstantError
ARANGE_CONCRETE_VALUE_ERROR
=
"""JAX requires the arguments of `jax.numpy.arange`
ARANGE_CONCRETE_VALUE_ERROR
=
"""JAX requires the arguments of `jax.numpy.arange`
to be constants. The graph that you defined thus cannot be JIT-compiled
to be constants. The graph that you defined thus cannot be JIT-compiled
by JAX. An example of a graph that can be compiled to JAX:
by JAX. An example of a graph that can be compiled to JAX:
>>> import pytensor.tensor basic
>>> import pytensor.tensor basic
>>> at.arange(1, 10, 2)
>>> at.arange(1, 10, 2)
"""
"""
...
@@ -193,3 +193,18 @@ def jax_funcify_ScalarFromTensor(op, **kwargs):
...
@@ -193,3 +193,18 @@ def jax_funcify_ScalarFromTensor(op, **kwargs):
return
jnp
.
array
(
x
)
.
flatten
()[
0
]
return
jnp
.
array
(
x
)
.
flatten
()[
0
]
return
scalar_from_tensor
return
scalar_from_tensor
@jax_funcify.register
(
Tri
)
def
jax_funcify_Tri
(
op
,
node
,
**
kwargs
):
# node.inputs is N, M, k
const_args
=
[
getattr
(
x
,
"data"
,
None
)
for
x
in
node
.
inputs
]
def
tri
(
*
args
):
# args is N, M, k
args
=
[
x
if
const_x
is
None
else
const_x
for
x
,
const_x
in
zip
(
args
,
const_args
)
]
return
jnp
.
tri
(
*
args
,
dtype
=
op
.
dtype
)
return
tri
tests/link/jax/test_tensor_basic.py
浏览文件 @
b9792d8a
...
@@ -191,3 +191,30 @@ def test_jax_eye():
...
@@ -191,3 +191,30 @@ def test_jax_eye():
out_fg
=
FunctionGraph
([],
[
out
])
out_fg
=
FunctionGraph
([],
[
out
])
compare_jax_and_py
(
out_fg
,
[])
compare_jax_and_py
(
out_fg
,
[])
def
test_tri
():
out
=
at
.
tri
(
10
,
10
,
0
)
fgraph
=
FunctionGraph
([],
[
out
])
compare_jax_and_py
(
fgraph
,
[])
def
test_tri_nonconcrete
():
"""JAX cannot JIT-compile `jax.numpy.tri` when arguments are not concrete values."""
m
,
n
,
k
=
(
scalar
(
"a"
,
dtype
=
"int64"
),
scalar
(
"n"
,
dtype
=
"int64"
),
scalar
(
"k"
,
dtype
=
"int64"
),
)
m
.
tag
.
test_value
=
10
n
.
tag
.
test_value
=
10
k
.
tag
.
test_value
=
0
out
=
at
.
tri
(
m
,
n
,
k
)
# The actual error the user will see should be jax.errors.ConcretizationTypeError, but
# the error handler raises an Attribute error first, so that's what this test needs to pass
with
pytest
.
raises
(
AttributeError
):
fgraph
=
FunctionGraph
([
m
,
n
,
k
],
[
out
])
compare_jax_and_py
(
fgraph
,
[
get_test_value
(
i
)
for
i
in
fgraph
.
inputs
])
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论