Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
a62dec23
提交
a62dec23
authored
9月 27, 2020
作者:
Brandon T. Willard
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Add JAX conversions for dot and arange
上级
02c02d72
显示空白字符变更
内嵌
并排
正在显示
2 个修改的文件
包含
51 行增加
和
0 行删除
+51
-0
test_jax.py
tests/sandbox/test_jax.py
+32
-0
jaxify.py
theano/sandbox/jaxify.py
+19
-0
没有找到文件。
tests/sandbox/test_jax.py
浏览文件 @
a62dec23
...
@@ -500,3 +500,35 @@ def test_nnet():
...
@@ -500,3 +500,35 @@ def test_nnet():
out
=
tt
.
nnet
.
softplus
(
x
)
out
=
tt
.
nnet
.
softplus
(
x
)
fgraph
=
theano
.
gof
.
FunctionGraph
([
x
],
[
out
])
fgraph
=
theano
.
gof
.
FunctionGraph
([
x
],
[
out
])
_
=
compare_jax_and_py
(
fgraph
,
[
get_test_value
(
i
)
for
i
in
fgraph
.
inputs
])
_
=
compare_jax_and_py
(
fgraph
,
[
get_test_value
(
i
)
for
i
in
fgraph
.
inputs
])
def
test_tensor_basics
():
y
=
tt
.
vector
(
"y"
)
y
.
tag
.
test_value
=
np
.
r_
[
1.0
,
2.0
]
.
astype
(
theano
.
config
.
floatX
)
x
=
tt
.
vector
(
"x"
)
x
.
tag
.
test_value
=
np
.
r_
[
3.0
,
4.0
]
.
astype
(
theano
.
config
.
floatX
)
A
=
tt
.
matrix
(
"A"
)
A
.
tag
.
test_value
=
np
.
empty
((
2
,
2
),
dtype
=
theano
.
config
.
floatX
)
alpha
=
tt
.
scalar
(
"alpha"
)
alpha
.
tag
.
test_value
=
np
.
array
(
3.0
,
dtype
=
theano
.
config
.
floatX
)
beta
=
tt
.
scalar
(
"beta"
)
beta
.
tag
.
test_value
=
np
.
array
(
5.0
,
dtype
=
theano
.
config
.
floatX
)
# This should be converted into a `Gemv` `Op` when the non-JAX compatible
# optimizations are turned on; however, when using JAX mode, it should
# leave the expression alone.
out
=
y
.
dot
(
alpha
*
A
)
.
dot
(
x
)
+
beta
*
y
fgraph
=
theano
.
gof
.
FunctionGraph
([
y
,
x
,
A
,
alpha
,
beta
],
[
out
])
_
=
compare_jax_and_py
(
fgraph
,
[
get_test_value
(
i
)
for
i
in
fgraph
.
inputs
])
@pytest.mark.xfail
(
reason
=
"jax.numpy.arange requires concrete inputs"
)
def
test_arange
():
a
=
tt
.
scalar
(
"a"
)
a
.
tag
.
test_value
=
10
out
=
tt
.
arange
(
a
)
fgraph
=
theano
.
gof
.
FunctionGraph
([
a
],
[
out
])
_
=
compare_jax_and_py
(
fgraph
,
[
get_test_value
(
i
)
for
i
in
fgraph
.
inputs
])
theano/sandbox/jaxify.py
浏览文件 @
a62dec23
...
@@ -26,6 +26,8 @@ from theano.tensor.subtensor import (
...
@@ -26,6 +26,8 @@ from theano.tensor.subtensor import (
from
theano.scan_module.scan_op
import
Scan
from
theano.scan_module.scan_op
import
Scan
from
theano.scan_module.scan_utils
import
scan_args
as
ScanArgs
from
theano.scan_module.scan_utils
import
scan_args
as
ScanArgs
from
theano.tensor.basic
import
(
from
theano.tensor.basic
import
(
Dot
,
ARange
,
TensorFromScalar
,
TensorFromScalar
,
ScalarFromTensor
,
ScalarFromTensor
,
AllocEmpty
,
AllocEmpty
,
...
@@ -198,6 +200,23 @@ def jax_funcify_Alloc(op):
...
@@ -198,6 +200,23 @@ def jax_funcify_Alloc(op):
return
alloc
return
alloc
@jax_funcify.register
(
Dot
)
def
jax_funcify_Dot
(
op
):
def
dot
(
x
,
y
):
return
jnp
.
dot
(
x
,
y
)
return
dot
@jax_funcify.register
(
ARange
)
def
jax_funcify_ARange
(
op
):
# XXX: This currently requires concrete arguments.
def
arange
(
start
,
stop
,
step
):
return
jnp
.
arange
(
start
,
stop
,
step
,
dtype
=
op
.
dtype
)
return
arange
def
jnp_safe_copy
(
x
):
def
jnp_safe_copy
(
x
):
try
:
try
:
res
=
jnp
.
copy
(
x
)
res
=
jnp
.
copy
(
x
)
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论