Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
54678486
提交
54678486
authored
9月 27, 2020
作者:
Brandon T. Willard
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Add JAX support for sigmoid scalar Ops
上级
8a8c7e7d
隐藏空白字符变更
内嵌
并排
正在显示
3 个修改的文件
包含
41 行增加
和
1 行删除
+41
-1
test_jax.py
tests/sandbox/test_jax.py
+23
-0
jaxify.py
theano/sandbox/jaxify.py
+14
-1
sigm.py
theano/tensor/nnet/sigm.py
+4
-0
没有找到文件。
tests/sandbox/test_jax.py
浏览文件 @
54678486
...
...
@@ -150,6 +150,10 @@ def test_jax_basic():
assert
jax_res
[
0
,
0
]
==
-
10.0
assert
jax_res
[
0
,
1
]
==
-
8.0
out
=
tt
.
clip
(
x
,
y
,
5
)
out_fg
=
theano
.
gof
.
FunctionGraph
([
x
,
y
],
[
out
])
(
jax_res
,)
=
compare_jax_and_py
(
out_fg
,
test_input_vals
)
@pytest.mark.skip
(
reason
=
"Not fully implemented, yet."
)
def
test_jax_scan
():
...
...
@@ -478,3 +482,22 @@ def test_jax_multioutput():
fgraph
=
theano
.
gof
.
FunctionGraph
([
x
,
y
],
[
w
,
v
])
_
=
compare_jax_and_py
(
fgraph
,
[
get_test_value
(
i
)
for
i
in
fgraph
.
inputs
])
def
test_nnet
():
x
=
tt
.
vector
(
"x"
)
x
.
tag
.
test_value
=
np
.
r_
[
1.0
,
2.0
]
.
astype
(
tt
.
config
.
floatX
)
out
=
tt
.
nnet
.
sigmoid
(
x
)
fgraph
=
theano
.
gof
.
FunctionGraph
([
x
],
[
out
])
_
=
compare_jax_and_py
(
fgraph
,
[
get_test_value
(
i
)
for
i
in
fgraph
.
inputs
])
out
=
tt
.
nnet
.
ultra_fast_sigmoid
(
x
)
fgraph
=
theano
.
gof
.
FunctionGraph
([
x
],
[
out
])
_
=
compare_jax_and_py
(
fgraph
,
[
get_test_value
(
i
)
for
i
in
fgraph
.
inputs
])
out
=
tt
.
nnet
.
softplus
(
x
)
fgraph
=
theano
.
gof
.
FunctionGraph
([
x
],
[
out
])
_
=
compare_jax_and_py
(
fgraph
,
[
get_test_value
(
i
)
for
i
in
fgraph
.
inputs
])
theano/sandbox/jaxify.py
浏览文件 @
54678486
...
...
@@ -50,6 +50,8 @@ from theano.compile.ops import (
)
from
theano.tensor.opt
import
MakeVector
from
theano.tensor.nnet.sigm
import
ScalarSoftplus
# XXX: Enabling this will break some shape-based functionality, and severely
# limit the types of graphs that can be converted.
...
...
@@ -165,7 +167,18 @@ def jax_funcify_ScalarOp(op):
@jax_funcify.register
(
Clip
)
def
jax_funcify_Clip
(
op
):
return
partial
(
op
.
impl
,
None
)
def
clip
(
x
,
min
,
max
):
return
jnp
.
where
(
x
<
min
,
min
,
jnp
.
where
(
x
>
max
,
max
,
x
))
return
clip
@jax_funcify.register
(
ScalarSoftplus
)
def
jax_funcify_ScalarSoftplus
(
op
):
def
scalarsoftplus
(
x
):
return
jnp
.
where
(
x
<
-
30.0
,
0.0
,
jnp
.
where
(
x
>
30.0
,
x
,
jnp
.
log1p
(
jnp
.
exp
(
x
))))
return
scalarsoftplus
@jax_funcify.register
(
AllocEmpty
)
...
...
theano/tensor/nnet/sigm.py
浏览文件 @
54678486
...
...
@@ -31,6 +31,8 @@ class ScalarSigmoid(scalar.UnaryScalarOp):
"""
nfunc_spec
=
(
"scipy.special.expit"
,
1
,
1
)
@staticmethod
def
st_impl
(
x
):
if
x
<
-
30.0
:
...
...
@@ -196,6 +198,8 @@ class UltraFastScalarSigmoid(scalar.UnaryScalarOp):
"""
nfunc_spec
=
(
"scipy.special.expit"
,
1
,
1
)
@staticmethod
def
st_impl
(
x
):
x
=
0.5
*
x
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论