Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
ce01b113
Unverified
提交
ce01b113
authored
3月 20, 2021
作者:
Adrian Seyboldt
提交者:
GitHub
3月 20, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Add JAX conversion for LogSoftmax and a fix for jax_funcify_join (#343)
上级
219a9516
隐藏空白字符变更
内嵌
并排
正在显示
2 个修改的文件
包含
54 行增加
和
1 行删除
+54
-1
jax_dispatch.py
aesara/link/jax/jax_dispatch.py
+11
-1
test_jax.py
tests/link/test_jax.py
+43
-0
没有找到文件。
aesara/link/jax/jax_dispatch.py
浏览文件 @
ce01b113
...
@@ -52,7 +52,7 @@ from aesara.tensor.nlinalg import (
...
@@ -52,7 +52,7 @@ from aesara.tensor.nlinalg import (
QRFull
,
QRFull
,
QRIncomplete
,
QRIncomplete
,
)
)
from
aesara.tensor.nnet.basic
import
Softmax
from
aesara.tensor.nnet.basic
import
LogSoftmax
,
Softmax
from
aesara.tensor.nnet.sigm
import
ScalarSoftplus
from
aesara.tensor.nnet.sigm
import
ScalarSoftplus
from
aesara.tensor.random.op
import
RandomVariable
from
aesara.tensor.random.op
import
RandomVariable
from
aesara.tensor.shape
import
Reshape
,
Shape
,
Shape_i
,
SpecifyShape
from
aesara.tensor.shape
import
Reshape
,
Shape
,
Shape_i
,
SpecifyShape
...
@@ -275,6 +275,14 @@ def jax_funcify_Softmax(op):
...
@@ -275,6 +275,14 @@ def jax_funcify_Softmax(op):
return
softmax
return
softmax
@jax_funcify.register
(
LogSoftmax
)
def
jax_funcify_LogSoftmax
(
op
):
def
log_softmax
(
x
):
return
jax
.
nn
.
log_softmax
(
x
)
return
log_softmax
@jax_funcify.register
(
ScalarSoftplus
)
@jax_funcify.register
(
ScalarSoftplus
)
def
jax_funcify_ScalarSoftplus
(
op
):
def
jax_funcify_ScalarSoftplus
(
op
):
def
scalarsoftplus
(
x
):
def
scalarsoftplus
(
x
):
...
@@ -786,6 +794,8 @@ def jax_funcify_DimShuffle(op):
...
@@ -786,6 +794,8 @@ def jax_funcify_DimShuffle(op):
@jax_funcify.register
(
Join
)
@jax_funcify.register
(
Join
)
def
jax_funcify_Join
(
op
):
def
jax_funcify_Join
(
op
):
def
join
(
axis
,
*
tensors
):
def
join
(
axis
,
*
tensors
):
# tensors could also be tuples, and in this case they don't have a ndim
tensors
=
[
jnp
.
asarray
(
tensor
)
for
tensor
in
tensors
]
view
=
op
.
view
view
=
op
.
view
if
(
view
!=
-
1
)
and
all
(
if
(
view
!=
-
1
)
and
all
(
[
[
...
...
tests/link/test_jax.py
浏览文件 @
ce01b113
...
@@ -703,6 +703,45 @@ def test_jax_Dimshuffle():
...
@@ -703,6 +703,45 @@ def test_jax_Dimshuffle():
compare_jax_and_py
(
x_fg
,
[
np
.
c_
[[
1.0
,
2.0
,
3.0
,
4.0
]]
.
astype
(
config
.
floatX
)])
compare_jax_and_py
(
x_fg
,
[
np
.
c_
[[
1.0
,
2.0
,
3.0
,
4.0
]]
.
astype
(
config
.
floatX
)])
def
test_jax_Join
():
a
=
matrix
(
"a"
)
b
=
matrix
(
"b"
)
x
=
aet
.
join
(
0
,
a
,
b
)
x_fg
=
FunctionGraph
([
a
,
b
],
[
x
])
compare_jax_and_py
(
x_fg
,
[
np
.
c_
[[
1.0
,
2.0
,
3.0
]]
.
astype
(
config
.
floatX
),
np
.
c_
[[
4.0
,
5.0
,
6.0
]]
.
astype
(
config
.
floatX
),
],
)
compare_jax_and_py
(
x_fg
,
[
np
.
c_
[[
1.0
,
2.0
,
3.0
]]
.
astype
(
config
.
floatX
),
np
.
c_
[[
4.0
,
5.0
]]
.
astype
(
config
.
floatX
),
],
)
x
=
aet
.
join
(
1
,
a
,
b
)
x_fg
=
FunctionGraph
([
a
,
b
],
[
x
])
compare_jax_and_py
(
x_fg
,
[
np
.
c_
[[
1.0
,
2.0
,
3.0
]]
.
astype
(
config
.
floatX
),
np
.
c_
[[
4.0
,
5.0
,
6.0
]]
.
astype
(
config
.
floatX
),
],
)
compare_jax_and_py
(
x_fg
,
[
np
.
c_
[[
1.0
,
2.0
],
[
3.0
,
4.0
]]
.
astype
(
config
.
floatX
),
np
.
c_
[[
5.0
,
6.0
]]
.
astype
(
config
.
floatX
),
],
)
def
test_jax_variadic_Scalar
():
def
test_jax_variadic_Scalar
():
mu
=
vector
(
"mu"
,
dtype
=
config
.
floatX
)
mu
=
vector
(
"mu"
,
dtype
=
config
.
floatX
)
mu
.
tag
.
test_value
=
np
.
r_
[
0.1
,
1.1
]
.
astype
(
config
.
floatX
)
mu
.
tag
.
test_value
=
np
.
r_
[
0.1
,
1.1
]
.
astype
(
config
.
floatX
)
...
@@ -777,6 +816,10 @@ def test_nnet():
...
@@ -777,6 +816,10 @@ def test_nnet():
fgraph
=
FunctionGraph
([
x
],
[
out
])
fgraph
=
FunctionGraph
([
x
],
[
out
])
compare_jax_and_py
(
fgraph
,
[
get_test_value
(
i
)
for
i
in
fgraph
.
inputs
])
compare_jax_and_py
(
fgraph
,
[
get_test_value
(
i
)
for
i
in
fgraph
.
inputs
])
out
=
aet_nnet
.
logsoftmax
(
x
)
fgraph
=
FunctionGraph
([
x
],
[
out
])
compare_jax_and_py
(
fgraph
,
[
get_test_value
(
i
)
for
i
in
fgraph
.
inputs
])
def
test_tensor_basics
():
def
test_tensor_basics
():
y
=
vector
(
"y"
)
y
=
vector
(
"y"
)
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论