Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
cfc931fa
提交
cfc931fa
authored
8月 12, 2022
作者:
Ricardo Vieira
提交者:
Ricardo Vieira
8月 12, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Allow JAX Reshape to work with constant shape inputs
上级
6c6bf08a
隐藏空白字符变更
内嵌
并排
正在显示
2 个修改的文件
包含
25 行增加
和
13 行删除
+25
-13
dispatch.py
aesara/link/jax/dispatch.py
+15
-3
test_jax.py
tests/link/test_jax.py
+10
-10
没有找到文件。
aesara/link/jax/dispatch.py
浏览文件 @
cfc931fa
...
...
@@ -11,6 +11,7 @@ from numpy.random.bit_generator import _coerce_to_uint32_array
from
aesara.compile.ops
import
DeepCopyOp
,
ViewOp
from
aesara.configdefaults
import
config
from
aesara.graph
import
Constant
from
aesara.graph.fg
import
FunctionGraph
from
aesara.ifelse
import
IfElse
from
aesara.link.utils
import
fgraph_to_python
...
...
@@ -728,9 +729,20 @@ def jax_funcify_MakeVector(op, **kwargs):
@jax_funcify.register
(
Reshape
)
def
jax_funcify_Reshape
(
op
,
**
kwargs
):
def
reshape
(
x
,
shape
):
return
jnp
.
reshape
(
x
,
shape
)
def
jax_funcify_Reshape
(
op
,
node
,
**
kwargs
):
# JAX reshape only works with constant inputs, otherwise JIT fails
shape
=
node
.
inputs
[
1
]
if
isinstance
(
shape
,
Constant
):
constant_shape
=
shape
.
data
def
reshape
(
x
,
_
):
return
jax
.
numpy
.
reshape
(
x
,
constant_shape
)
else
:
def
reshape
(
x
,
shape
):
return
jax
.
numpy
.
reshape
(
x
,
shape
)
return
reshape
...
...
tests/link/test_jax.py
浏览文件 @
cfc931fa
...
...
@@ -863,10 +863,6 @@ def test_jax_MakeVector():
compare_jax_and_py
(
x_fg
,
[])
@pytest.mark.xfail
(
version_parse
(
jax
.
__version__
)
>=
version_parse
(
"0.2.12"
),
reason
=
"Omnistaging cannot be disabled"
,
)
def
test_jax_Reshape
():
a
=
vector
(
"a"
)
x
=
reshape
(
a
,
(
2
,
2
))
...
...
@@ -877,16 +873,20 @@ def test_jax_Reshape():
# See https://github.com/tensorflow/probability/commit/782d0c64eb774b9aac54a1c8488e4f1f96fbbc68
x
=
reshape
(
a
,
(
a
.
shape
[
0
]
//
2
,
a
.
shape
[
0
]
//
2
))
x_fg
=
FunctionGraph
([
a
],
[
x
])
compare_jax_and_py
(
x_fg
,
[
np
.
r_
[
1.0
,
2.0
,
3.0
,
4.0
]
.
astype
(
config
.
floatX
)])
with
pytest
.
raises
(
TypeError
,
match
=
"Shapes must be 1D sequences of concrete values of integer type"
,
):
compare_jax_and_py
(
x_fg
,
[
np
.
r_
[
1.0
,
2.0
,
3.0
,
4.0
]
.
astype
(
config
.
floatX
)])
@pytest.mark.xfail
(
reason
=
"jax.numpy.arange requires concrete inputs"
)
def
test_jax_Reshape_nonconcrete
():
a
=
vector
(
"a"
)
b
=
iscalar
(
"b"
)
x
=
reshape
(
a
,
(
b
,
b
))
x_fg
=
FunctionGraph
([
a
,
b
],
[
x
])
compare_jax_and_py
(
x_fg
,
[
np
.
r_
[
1.0
,
2.0
,
3.0
,
4.0
]
.
astype
(
config
.
floatX
),
2
])
with
pytest
.
raises
(
TypeError
,
match
=
"Shapes must be 1D sequences of concrete values of integer type"
,
):
compare_jax_and_py
(
x_fg
,
[
np
.
r_
[
1.0
,
2.0
,
3.0
,
4.0
]
.
astype
(
config
.
floatX
),
2
])
def
test_jax_Dimshuffle
():
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论