Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
0c203e99
Unverified
提交
0c203e99
authored
11月 20, 2020
作者:
Brandon T. Willard
提交者:
GitHub
11月 20, 2020
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #169 from junpenglao/jax_scan
Implement a JAX conversion for the Scan Op
上级
454ae317
4fee2746
全部展开
显示空白字符变更
内嵌
并排
正在显示
3 个修改的文件
包含
90 行增加
和
55 行删除
+90
-55
test_jax.py
tests/sandbox/test_jax.py
+0
-0
jaxify.py
theano/sandbox/jaxify.py
+75
-40
utils.py
theano/scan/utils.py
+15
-15
没有找到文件。
tests/sandbox/test_jax.py
浏览文件 @
0c203e99
差异被折叠。
点击展开。
theano/sandbox/jaxify.py
浏览文件 @
0c203e99
...
...
@@ -177,6 +177,7 @@ def compose_jax_funcs(out_node, fgraph_inputs, memo=None):
def
jax_func
(
*
inputs
):
func_args
=
[
fn
(
*
inputs
)
for
fn
in
input_funcs
]
# func_args = jax.tree_map(lambda fn: fn(*inputs), input_funcs)
return
return_func
(
*
func_args
)
jax_funcs
.
append
(
update_wrapper
(
jax_func
,
return_func
))
...
...
@@ -420,7 +421,7 @@ def jax_funcify_Scan(op):
def
scan
(
*
outer_inputs
):
scan_args
=
ScanArgs
(
outer_inputs
,
[
None
]
*
op
.
n_outs
,
op
.
inputs
,
op
.
outputs
,
op
.
info
list
(
outer_inputs
)
,
[
None
]
*
op
.
n_outs
,
op
.
inputs
,
op
.
outputs
,
op
.
info
)
# `outer_inputs` is a list with the following composite form:
...
...
@@ -435,9 +436,9 @@ def jax_funcify_Scan(op):
n_steps
=
scan_args
.
n_steps
seqs
=
scan_args
.
outer_in_seqs
n_non_seqs
=
len
(
scan_args
.
outer_in_non_seqs
)
# TODO: mit_mots
mit_mot_in_slices
=
[]
# TODO: sit_sots
mit_sot_in_slices
=
[]
for
tap
,
seq
in
zip
(
scan_args
.
mit_sot_in_slices
,
scan_args
.
outer_in_mit_sot
):
neg_taps
=
[
abs
(
t
)
for
t
in
tap
if
t
<
0
]
...
...
@@ -447,7 +448,15 @@ def jax_funcify_Scan(op):
init_slice
=
seq
[:
max_neg
+
max_pos
]
mit_sot_in_slices
.
append
(
init_slice
)
init_carry
=
[
mit_sot_in_slices
,
scan_args
.
outer_in_non_seqs
]
sit_sot_in_slices
=
[
seq
[
0
]
for
seq
in
scan_args
.
outer_in_sit_sot
]
init_carry
=
(
mit_mot_in_slices
,
mit_sot_in_slices
,
sit_sot_in_slices
,
scan_args
.
outer_in_shared
,
scan_args
.
outer_in_non_seqs
,
)
def
jax_args_to_inner_scan
(
op
,
carry
,
x
):
# `carry` contains all inner-output taps, non_seqs, and shared
...
...
@@ -470,15 +479,22 @@ def jax_funcify_Scan(op):
# + inner_in_sit_sot
# + inner_in_shared
# + inner_in_non_seqs
inner_scan_inputs
=
[
inner_in_mit_sot_flatten
=
[]
for
array
,
index
in
zip
(
inner_in_mit_sot
,
scan_args
.
mit_sot_in_slices
):
inner_in_mit_sot_flatten
.
extend
(
array
[
index
])
inner_scan_inputs
=
sum
(
[
inner_in_seqs
,
inner_in_mit_mot
,
inner_in_mit_sot
,
inner_in_mit_sot_flatten
,
inner_in_sit_sot
,
inner_in_shared
,
inner_in_non_seqs
,
]
],
[],
)
raise
NotImplementedError
()
return
inner_scan_inputs
def
inner_scan_outs_to_jax_outs
(
...
...
@@ -486,47 +502,66 @@ def jax_funcify_Scan(op):
old_carry
,
inner_scan_outs
,
):
# `inner_scan_outs` is a list with the following
# composite form:
# outer_out_mit_mot
# + outer_out_mit_sot
# + outer_out_sit_sot
# + outer_out_nit_sot
# + outer_out_shared
# + cond
(
outer_out_mit_mot
,
outer_out_mit_sot
,
outer_out_sit_sot
,
outer_out_nit_sot
,
outer_out_shared
,
cond
,
)
=
inner_scan_outs
outer_out_non_seqs
=
old_carry
[:
-
n_non_seqs
]
inner_in_mit_mot
,
inner_in_mit_sot
,
inner_in_sit_sot
,
inner_in_shared
,
inner_in_non_seqs
,
)
=
old_carry
def
update_mit_sot
(
mit_sot
,
new_val
):
return
jnp
.
concatenate
([
mit_sot
[
1
:],
new_val
[
None
,
...
]],
axis
=
0
)
inner_out_mit_sot
=
[
update_mit_sot
(
mit_sot
,
new_val
)
for
mit_sot
,
new_val
in
zip
(
inner_in_mit_sot
,
inner_scan_outs
)
]
# This should contain all inner-output taps, non_seqs, and shared
# terms
carry
=
[
outer_out_mit_mot
,
outer_out_mit_sot
,
outer_out_sit_sot
,
outer_out_shared
,
outer_out_non_seqs
,
]
# This should contain all inner-outputs that produce
# outer-outputs
y
=
[]
if
not
inner_in_sit_sot
:
inner_out_sit_sot
=
[]
else
:
inner_out_sit_sot
=
inner_scan_outs
new_carry
=
(
inner_in_mit_mot
,
inner_out_mit_sot
,
inner_out_sit_sot
,
inner_in_shared
,
inner_in_non_seqs
,
)
raise
NotImplementedError
()
return
(
carry
,
y
)
return
new_carry
def
jax_inner_func
(
carry
,
x
):
inner_args
=
jax_args_to_inner_scan
(
op
,
carry
,
x
)
inner_scan_outs
=
jax_tt_inner_func
(
*
inner_args
)
new_carry
,
y
=
inner_scan_outs_to_jax_outs
(
op
,
inner_scan_outs
)
return
new_carry
,
y
inner_scan_outs
=
[
fn
(
*
inner_args
)
for
fn
in
jax_tt_inner_func
]
new_carry
=
inner_scan_outs_to_jax_outs
(
op
,
carry
,
inner_scan_outs
)
return
new_carry
,
inner_scan_outs
_
,
scan_out
=
jax
.
lax
.
scan
(
jax_inner_func
,
init_carry
,
seqs
,
length
=
n_steps
)
# We need to prepend the initial values so that the JAX output will
# match the raw `Scan` `Op` output and, thus, work with a downstream
# `Subtensor` `Op` introduced by the `scan` helper function.
def
append_scan_out
(
scan_in_part
,
scan_out_part
):
return
jnp
.
concatenate
([
scan_in_part
[:
-
n_steps
],
scan_out_part
],
axis
=
0
)
if
scan_args
.
outer_in_mit_sot
:
scan_out_final
=
[
append_scan_out
(
init
,
out
)
for
init
,
out
in
zip
(
scan_args
.
outer_in_mit_sot
,
scan_out
)
]
elif
scan_args
.
outer_in_sit_sot
:
scan_out_final
=
[
append_scan_out
(
init
,
out
)
for
init
,
out
in
zip
(
scan_args
.
outer_in_sit_sot
,
scan_out
)
]
return
jax
.
lax
.
scan
(
jax_inner_func
,
init_carry
,
seqs
,
length
=
n_steps
)
if
len
(
scan_out_final
)
==
1
:
scan_out_final
=
scan_out_final
[
0
]
return
scan_out_final
return
scan
...
...
theano/scan/utils.py
浏览文件 @
0c203e99
...
...
@@ -1075,8 +1075,9 @@ class scan_args:
if
k
in
info
:
self
.
other_info
[
k
]
=
info
[
k
]
inner_inputs
=
property
(
lambda
self
:
(
@property
def
inner_inputs
(
self
):
return
(
self
.
inner_in_seqs
+
sum
(
self
.
inner_in_mit_mot
,
[])
+
sum
(
self
.
inner_in_mit_sot
,
[])
...
...
@@ -1084,10 +1085,10 @@ class scan_args:
+
self
.
inner_in_shared
+
self
.
inner_in_non_seqs
)
)
outer_inputs
=
property
(
lambda
self
:
(
@property
def
outer_inputs
(
self
):
return
(
[
self
.
n_steps
]
+
self
.
outer_in_seqs
+
self
.
outer_in_mit_mot
...
...
@@ -1097,10 +1098,10 @@ class scan_args:
+
self
.
outer_in_nit_sot
+
self
.
outer_in_non_seqs
)
)
inner_outputs
=
property
(
lambda
self
:
(
@property
def
inner_outputs
(
self
):
return
(
sum
(
self
.
inner_out_mit_mot
,
[])
+
self
.
inner_out_mit_sot
+
self
.
inner_out_sit_sot
...
...
@@ -1108,20 +1109,20 @@ class scan_args:
+
self
.
inner_out_shared
+
self
.
cond
)
)
outer_outputs
=
property
(
lambda
self
:
(
@property
def
outer_outputs
(
self
):
return
(
self
.
outer_out_mit_mot
+
self
.
outer_out_mit_sot
+
self
.
outer_out_sit_sot
+
self
.
outer_out_nit_sot
+
self
.
outer_out_shared
)
)
info
=
property
(
lambda
self
:
OrderedDict
(
@property
def
info
(
self
):
return
OrderedDict
(
n_seqs
=
len
(
self
.
outer_in_seqs
),
n_mit_mot
=
len
(
self
.
outer_in_mit_mot
),
n_mit_sot
=
len
(
self
.
outer_in_mit_sot
),
...
...
@@ -1137,7 +1138,6 @@ class scan_args:
mit_mot_out_slices
=
self
.
mit_mot_out_slices
,
**
self
.
other_info
,
)
)
def
__copy__
(
self
):
res
=
object
.
__new__
(
type
(
self
))
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论