Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
28fc9acb
提交
28fc9acb
authored
8月 10, 2023
作者:
Ricardo Vieira
提交者:
Ricardo Vieira
8月 24, 2023
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Do not raise early when a Shape operation is an input to Arange in the JAX backend
上级
71c58f39
隐藏空白字符变更
内嵌
并排
正在显示
3 个修改的文件
包含
20 行增加
和
6 行删除
+20
-6
tensor_basic.py
pytensor/link/jax/dispatch/tensor_basic.py
+12
-5
test_slinalg.py
tests/link/jax/test_slinalg.py
+1
-1
test_tensor_basic.py
tests/link/jax/test_tensor_basic.py
+7
-0
没有找到文件。
pytensor/link/jax/dispatch/tensor_basic.py
浏览文件 @
28fc9acb
...
...
@@ -21,6 +21,7 @@ from pytensor.tensor.basic import (
get_underlying_scalar_constant_value
,
)
from
pytensor.tensor.exceptions
import
NotScalarConstantError
from
pytensor.tensor.shape
import
Shape_i
ARANGE_CONCRETE_VALUE_ERROR
=
"""JAX requires the arguments of `jax.numpy.arange`
...
...
@@ -61,14 +62,20 @@ def jax_funcify_ARange(op, node, **kwargs):
arange_args
=
node
.
inputs
constant_args
=
[]
for
arg
in
arange_args
:
if
not
isinstance
(
arg
,
Constant
):
if
arg
.
owner
and
isinstance
(
arg
.
owner
.
op
,
Shape_i
):
constant_args
.
append
(
None
)
elif
isinstance
(
arg
,
Constant
):
constant_args
.
append
(
arg
.
value
)
else
:
# TODO: This might be failing without need (e.g., if arg = shape(x)[-1] + 1)!
raise
NotImplementedError
(
ARANGE_CONCRETE_VALUE_ERROR
)
constant_args
.
append
(
arg
.
value
)
start
,
stop
,
step
=
constant_args
constant_start
,
constant_stop
,
constant_step
=
constant_args
def
arange
(
*
_
):
def
arange
(
start
,
stop
,
step
):
start
=
start
if
constant_start
is
None
else
constant_start
stop
=
stop
if
constant_stop
is
None
else
constant_stop
step
=
step
if
constant_step
is
None
else
constant_step
return
jnp
.
arange
(
start
,
stop
,
step
,
dtype
=
op
.
dtype
)
return
arange
...
...
tests/link/jax/test_slinalg.py
浏览文件 @
28fc9acb
...
...
@@ -85,7 +85,7 @@ def test_jax_basic():
],
)
out
=
at
.
diag
(
at
.
specify_shape
(
b
,
shape
=
(
10
,))
)
out
=
at
.
diag
(
b
)
out_fg
=
FunctionGraph
([
b
],
[
out
])
compare_jax_and_py
(
out_fg
,
[
np
.
arange
(
10
)
.
astype
(
config
.
floatX
)])
...
...
tests/link/jax/test_tensor_basic.py
浏览文件 @
28fc9acb
...
...
@@ -63,6 +63,13 @@ def test_arange():
compare_jax_and_py
(
fgraph
,
[])
def
test_arange_of_shape
():
x
=
vector
(
"x"
)
out
=
at
.
arange
(
1
,
x
.
shape
[
-
1
],
2
)
fgraph
=
FunctionGraph
([
x
],
[
out
])
compare_jax_and_py
(
fgraph
,
[
np
.
zeros
((
5
,))])
def
test_arange_nonconcrete
():
"""JAX cannot JIT-compile `jax.numpy.arange` when arguments are not concrete values."""
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论