Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
90ef59eb
Unverified
提交
90ef59eb
authored
10月 01, 2020
作者:
Thomas Wiecki
提交者:
GitHub
10月 01, 2020
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Add jaxification for linear algebra operations (#59)
Co-authored-by:
Brandon T. Willard
<
brandonwillard@users.noreply.github.com
>
上级
4c72bf9e
全部展开
隐藏空白字符变更
内嵌
并排
正在显示
3 个修改的文件
包含
150 行增加
和
11 行删除
+150
-11
test_jax.py
tests/sandbox/test_jax.py
+0
-0
jax_linker.py
theano/sandbox/jax_linker.py
+14
-9
jaxify.py
theano/sandbox/jaxify.py
+136
-2
没有找到文件。
tests/sandbox/test_jax.py
浏览文件 @
90ef59eb
差异被折叠。
点击展开。
theano/sandbox/jax_linker.py
浏览文件 @
90ef59eb
...
@@ -74,15 +74,12 @@ class JAXLinker(PerformLinker):
...
@@ -74,15 +74,12 @@ class JAXLinker(PerformLinker):
thunk_outputs
=
[
storage_map
[
n
]
for
n
in
node
.
outputs
]
thunk_outputs
=
[
storage_map
[
n
]
for
n
in
node
.
outputs
]
# JIT-compile the functions
if
not
isinstance
(
jax_funcs
,
Sequence
):
if
len
(
node
.
outputs
)
>
1
:
jax_funcs
=
[
jax_funcs
]
assert
len
(
jax_funcs
)
==
len
(
node
.
ouptputs
)
jax_impl_jits
=
[
jax_impl_jits
=
[
jax
.
jit
(
jax_func
,
static_argnums
)
for
jax_func
in
jax_funcs
jax
.
jit
(
jax_func
,
static_argnums
)
for
jax_func
in
jax_funcs
]
]
else
:
assert
not
isinstance
(
jax_funcs
,
Sequence
)
jax_impl_jits
=
[
jax
.
jit
(
jax_funcs
,
static_argnums
)]
def
thunk
(
def
thunk
(
node
=
node
,
jax_impl_jits
=
jax_impl_jits
,
thunk_outputs
=
thunk_outputs
node
=
node
,
jax_impl_jits
=
jax_impl_jits
,
thunk_outputs
=
thunk_outputs
...
@@ -92,6 +89,14 @@ class JAXLinker(PerformLinker):
...
@@ -92,6 +89,14 @@ class JAXLinker(PerformLinker):
for
jax_impl_jit
in
jax_impl_jits
for
jax_impl_jit
in
jax_impl_jits
]
]
if
len
(
jax_impl_jits
)
<
len
(
node
.
outputs
):
# In this case, the JAX function will output a single
# output that contains the other outputs.
# This happens for multi-output `Op`s that directly
# correspond to multi-output JAX functions (e.g. `SVD` and
# `jax.numpy.linalg.svd`).
outputs
=
outputs
[
0
]
for
o_node
,
o_storage
,
o_val
in
zip
(
for
o_node
,
o_storage
,
o_val
in
zip
(
node
.
outputs
,
thunk_outputs
,
outputs
node
.
outputs
,
thunk_outputs
,
outputs
):
):
...
...
theano/sandbox/jaxify.py
浏览文件 @
90ef59eb
...
@@ -2,6 +2,7 @@ import theano
...
@@ -2,6 +2,7 @@ import theano
import
jax
import
jax
import
jax.numpy
as
jnp
import
jax.numpy
as
jnp
import
jax.scipy
as
jsp
from
warnings
import
warn
from
warnings
import
warn
from
functools
import
update_wrapper
,
reduce
from
functools
import
update_wrapper
,
reduce
...
@@ -49,12 +50,36 @@ from theano.tensor.opt import MakeVector
...
@@ -49,12 +50,36 @@ from theano.tensor.opt import MakeVector
from
theano.tensor.nnet.sigm
import
ScalarSoftplus
from
theano.tensor.nnet.sigm
import
ScalarSoftplus
from
theano.tensor.nlinalg
import
(
Det
,
Eig
,
Eigh
,
MatrixInverse
,
QRFull
,
QRIncomplete
,
SVD
,
ExtractDiag
,
AllocDiag
,
)
from
theano.tensor.slinalg
import
(
Cholesky
,
Solve
,
)
if
theano
.
config
.
floatX
==
"float64"
:
jax
.
config
.
update
(
"jax_enable_x64"
,
True
)
else
:
jax
.
config
.
update
(
"jax_enable_x64"
,
False
)
# XXX: Enabling this will break some shape-based functionality, and severely
# XXX: Enabling this will break some shape-based functionality, and severely
# limit the types of graphs that can be converted.
# limit the types of graphs that can be converted.
# See https://github.com/google/jax/blob/4d556837cc9003492f674c012689efc3d68fdf5f/design_notes/omnistaging.md
# See https://github.com/google/jax/blob/4d556837cc9003492f674c012689efc3d68fdf5f/design_notes/omnistaging.md
jax
.
config
.
disable_omnistaging
()
# Older versions < 0.2.0 do not have this flag so we don't need to set it.
jax
.
config
.
update
(
"jax_enable_x64"
,
True
)
try
:
jax
.
config
.
disable_omnistaging
()
except
AttributeError
:
pass
subtensor_ops
=
(
Subtensor
,
AdvancedSubtensor1
,
BaseAdvancedSubtensor
)
subtensor_ops
=
(
Subtensor
,
AdvancedSubtensor1
,
BaseAdvancedSubtensor
)
incsubtensor_ops
=
(
IncSubtensor
,
AdvancedIncSubtensor1
,
BaseAdvancedIncSubtensor
)
incsubtensor_ops
=
(
IncSubtensor
,
AdvancedIncSubtensor1
,
BaseAdvancedIncSubtensor
)
...
@@ -629,3 +654,112 @@ def jax_funcify_Join(op):
...
@@ -629,3 +654,112 @@ def jax_funcify_Join(op):
return
jnp
.
concatenate
(
tensors
,
axis
=
axis
)
return
jnp
.
concatenate
(
tensors
,
axis
=
axis
)
return
join
return
join
@jax_funcify.register(ExtractDiag)
def jax_funcify_ExtractDiag(op):
    """Build a JAX implementation for the `ExtractDiag` `Op`."""
    # Freeze the Op's settings as default arguments so the returned
    # function closes over plain Python values, not the Op instance.
    op_offset = op.offset
    op_axis1 = op.axis1
    op_axis2 = op.axis2

    def extract_diag(x, offset=op_offset, axis1=op_axis1, axis2=op_axis2):
        return jnp.diagonal(x, offset=offset, axis1=axis1, axis2=axis2)

    return extract_diag
@jax_funcify.register(Cholesky)
def jax_funcify_Cholesky(op):
    """Build a JAX implementation for the `Cholesky` `Op`."""
    is_lower = op.lower

    def cholesky(a, lower=is_lower):
        # Cast back to the input dtype so the output dtype matches what
        # the Theano graph expects.
        return jsp.linalg.cholesky(a, lower=lower).astype(a.dtype)

    return cholesky
@jax_funcify.register(AllocDiag)
def jax_funcify_AllocDiag(op):
    """Build a JAX implementation for the `AllocDiag` `Op`."""

    def alloc_diag(x):
        # Place the vector `x` on the main diagonal of a square matrix.
        return jnp.diag(x)

    return alloc_diag
@jax_funcify.register(Solve)
def jax_funcify_Solve(op):
    """Build a JAX implementation for the `Solve` `Op`."""
    # NOTE(review): `jax.scipy.linalg.solve` only honours `lower` together
    # with `sym_pos`; a "lower_triangular" structure could instead dispatch
    # to `solve_triangular` — confirm before changing behavior here.
    use_lower = op.A_structure == "lower_triangular"

    def solve(a, b, lower=use_lower):
        return jsp.linalg.solve(a, b, lower=lower)

    return solve
@jax_funcify.register(Det)
def jax_funcify_Det(op):
    """Build a JAX implementation for the `Det` `Op`."""

    def det(x):
        # Defer entirely to JAX's native determinant.
        return jnp.linalg.det(x)

    return det
@jax_funcify.register(Eig)
def jax_funcify_Eig(op):
    """Build a JAX implementation for the `Eig` `Op`."""

    def eig(x):
        # Defer entirely to JAX's general eigendecomposition.
        return jnp.linalg.eig(x)

    return eig
@jax_funcify.register(Eigh)
def jax_funcify_Eigh(op):
    """Build a JAX implementation for the `Eigh` `Op`."""
    # Which triangle of the symmetric/Hermitian input to use.
    op_uplo = op.UPLO

    def eigh(x, uplo=op_uplo):
        return jnp.linalg.eigh(x, UPLO=uplo)

    return eigh
@jax_funcify.register(MatrixInverse)
def jax_funcify_MatrixInverse(op):
    """Build a JAX implementation for the `MatrixInverse` `Op`."""

    def matrix_inverse(x):
        # Defer entirely to JAX's native matrix inverse.
        return jnp.linalg.inv(x)

    return matrix_inverse
@jax_funcify.register(QRFull)
def jax_funcify_QRFull(op):
    """Build a JAX implementation for the `QRFull` `Op`."""
    # QR mode string (e.g. "reduced"/"complete") fixed at Op creation.
    op_mode = op.mode

    def qr_full(x, mode=op_mode):
        return jnp.linalg.qr(x, mode=mode)

    return qr_full
@jax_funcify.register(QRIncomplete)
def jax_funcify_QRIncomplete(op):
    """Build a JAX implementation for the `QRIncomplete` `Op`."""
    # QR mode string fixed at Op creation.
    op_mode = op.mode

    def qr_incomplete(x, mode=op_mode):
        return jnp.linalg.qr(x, mode=mode)

    return qr_incomplete
@jax_funcify.register(SVD)
def jax_funcify_SVD(op):
    """Build a JAX implementation for the `SVD` `Op`."""
    # Freeze the Op's settings as default arguments so the returned
    # function does not hold a reference to the Op itself.
    op_full_matrices = op.full_matrices
    op_compute_uv = op.compute_uv

    def svd(x, full_matrices=op_full_matrices, compute_uv=op_compute_uv):
        return jnp.linalg.svd(x, full_matrices=full_matrices, compute_uv=compute_uv)

    return svd
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论