Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
91d3b7c0
提交
91d3b7c0
authored
11月 22, 2023
作者:
Ricardo Vieira
提交者:
Ricardo Vieira
11月 23, 2023
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Do not merge while scans with different until condition
The rewrite did not check if nominal variables in the graph of the until condition corresponded to the equivalent outer variables
上级
eb552eef
显示空白字符变更
内嵌
并排
正在显示
2 个修改的文件
包含
169 行增加
和
47 行删除
+169
-47
rewriting.py
pytensor/scan/rewriting.py
+41
-10
test_rewriting.py
tests/scan/test_rewriting.py
+128
-37
没有找到文件。
pytensor/scan/rewriting.py
浏览文件 @
91d3b7c0
...
...
@@ -17,7 +17,9 @@ from pytensor.configdefaults import config
from
pytensor.graph.basic
import
(
Apply
,
Constant
,
NominalVariable
,
Variable
,
ancestors
,
apply_depends_on
,
equal_computations
,
graph_inputs
,
...
...
@@ -1950,11 +1952,13 @@ class ScanMerge(GraphRewriter):
Questionable, we should also consider profile ?
"""
rep
=
set_nodes
[
0
]
op
=
node
.
op
rep_node
=
set_nodes
[
0
]
rep_op
=
rep_node
.
op
if
(
rep
.
op
.
info
.
as_while
!=
node
.
op
.
info
.
as_while
or
node
.
op
.
truncate_gradient
!=
rep
.
op
.
truncate_gradient
or
node
.
op
.
mode
!=
rep
.
op
.
mode
op
.
info
.
as_while
!=
rep_
op
.
info
.
as_while
or
op
.
truncate_gradient
!=
rep_
op
.
truncate_gradient
or
op
.
mode
!=
rep_
op
.
mode
):
return
False
...
...
@@ -1964,7 +1968,7 @@ class ScanMerge(GraphRewriter):
except
NotScalarConstantError
:
pass
rep_nsteps
=
rep
.
inputs
[
0
]
rep_nsteps
=
rep
_node
.
inputs
[
0
]
try
:
rep_nsteps
=
int
(
get_underlying_scalar_constant_value
(
rep_nsteps
))
except
NotScalarConstantError
:
...
...
@@ -1978,14 +1982,41 @@ class ScanMerge(GraphRewriter):
if
apply_depends_on
(
node
,
nd
)
or
apply_depends_on
(
nd
,
node
):
return
False
if
not
node
.
op
.
info
.
as_while
:
if
not
op
.
info
.
as_while
:
return
True
cond
=
node
.
op
.
inner_outputs
[
-
1
]
rep_cond
=
rep
.
op
.
inner_outputs
[
-
1
]
return
equal_computations
(
[
cond
],
[
rep_cond
],
node
.
op
.
inner_inputs
,
rep
.
op
.
inner_inputs
# We need to check the while conditions are identical
conds
=
[
op
.
inner_outputs
[
-
1
]]
rep_conds
=
[
rep_op
.
inner_outputs
[
-
1
]]
if
not
equal_computations
(
conds
,
rep_conds
,
op
.
inner_inputs
,
rep_op
.
inner_inputs
):
return
False
# If they depend on inner inputs we need to check for equivalence on the respective outer inputs
nominal_inputs
=
[
a
for
a
in
ancestors
(
conds
)
if
isinstance
(
a
,
NominalVariable
)]
if
not
nominal_inputs
:
return
True
rep_nominal_inputs
=
[
a
for
a
in
ancestors
(
rep_conds
)
if
isinstance
(
a
,
NominalVariable
)
]
conds
=
[]
rep_conds
=
[]
mapping
=
op
.
get_oinp_iinp_iout_oout_mappings
()[
"outer_inp_from_inner_inp"
]
rep_mapping
=
rep_op
.
get_oinp_iinp_iout_oout_mappings
()[
"outer_inp_from_inner_inp"
]
inner_inputs
=
op
.
inner_inputs
rep_inner_inputs
=
rep_op
.
inner_inputs
for
nominal_input
,
rep_nominal_input
in
zip
(
nominal_inputs
,
rep_nominal_inputs
):
conds
.
append
(
node
.
inputs
[
mapping
[
inner_inputs
.
index
(
nominal_input
)]])
rep_conds
.
append
(
rep_node
.
inputs
[
rep_mapping
[
rep_inner_inputs
.
index
(
rep_nominal_input
)]]
)
return
equal_computations
(
conds
,
rep_conds
)
def
apply
(
self
,
fgraph
):
# Collect all scan nodes ordered according to toposort
scan_nodes
=
[
nd
for
nd
in
fgraph
.
toposort
()
if
isinstance
(
nd
.
op
,
Scan
)]
...
...
tests/scan/test_rewriting.py
浏览文件 @
91d3b7c0
...
...
@@ -15,6 +15,7 @@ from pytensor.graph.replace import clone_replace
from
pytensor.scan.op
import
Scan
from
pytensor.scan.rewriting
import
ScanInplaceOptimizer
,
ScanMerge
from
pytensor.scan.utils
import
until
from
pytensor.tensor
import
stack
from
pytensor.tensor.blas
import
Dot22
from
pytensor.tensor.elemwise
import
Elemwise
from
pytensor.tensor.math
import
Dot
,
dot
,
sigmoid
...
...
@@ -796,7 +797,13 @@ class TestPushOutAddScan:
class
TestScanMerge
:
mode
=
get_default_mode
()
.
including
(
"scan"
)
mode
=
get_default_mode
()
.
including
(
"scan"
)
.
excluding
(
"scan_pushout_seqs_ops"
)
@staticmethod
def
count_scans
(
fn
):
nodes
=
fn
.
maker
.
fgraph
.
apply_nodes
scans
=
[
node
for
node
in
nodes
if
isinstance
(
node
.
op
,
Scan
)]
return
len
(
scans
)
def
test_basic
(
self
):
x
=
vector
()
...
...
@@ -808,56 +815,38 @@ class TestScanMerge:
sx
,
upx
=
scan
(
sum
,
sequences
=
[
x
])
sy
,
upy
=
scan
(
sum
,
sequences
=
[
y
])
f
=
function
(
[
x
,
y
],
[
sx
,
sy
],
mode
=
self
.
mode
.
excluding
(
"scan_pushout_seqs_ops"
)
)
topo
=
f
.
maker
.
fgraph
.
toposort
()
scans
=
[
n
for
n
in
topo
if
isinstance
(
n
.
op
,
Scan
)]
assert
len
(
scans
)
==
2
f
=
function
([
x
,
y
],
[
sx
,
sy
],
mode
=
self
.
mode
)
assert
self
.
count_scans
(
f
)
==
2
sx
,
upx
=
scan
(
sum
,
sequences
=
[
x
],
n_steps
=
2
)
sy
,
upy
=
scan
(
sum
,
sequences
=
[
y
],
n_steps
=
3
)
f
=
function
(
[
x
,
y
],
[
sx
,
sy
],
mode
=
self
.
mode
.
excluding
(
"scan_pushout_seqs_ops"
)
)
topo
=
f
.
maker
.
fgraph
.
toposort
()
scans
=
[
n
for
n
in
topo
if
isinstance
(
n
.
op
,
Scan
)]
assert
len
(
scans
)
==
2
f
=
function
([
x
,
y
],
[
sx
,
sy
],
mode
=
self
.
mode
)
assert
self
.
count_scans
(
f
)
==
2
sx
,
upx
=
scan
(
sum
,
sequences
=
[
x
],
n_steps
=
4
)
sy
,
upy
=
scan
(
sum
,
sequences
=
[
y
],
n_steps
=
4
)
f
=
function
(
[
x
,
y
],
[
sx
,
sy
],
mode
=
self
.
mode
.
excluding
(
"scan_pushout_seqs_ops"
)
)
topo
=
f
.
maker
.
fgraph
.
toposort
()
scans
=
[
n
for
n
in
topo
if
isinstance
(
n
.
op
,
Scan
)]
assert
len
(
scans
)
==
1
f
=
function
([
x
,
y
],
[
sx
,
sy
],
mode
=
self
.
mode
)
assert
self
.
count_scans
(
f
)
==
1
sx
,
upx
=
scan
(
sum
,
sequences
=
[
x
])
sy
,
upy
=
scan
(
sum
,
sequences
=
[
x
])
f
=
function
([
x
],
[
sx
,
sy
],
mode
=
self
.
mode
.
excluding
(
"scan_pushout_seqs_ops"
))
topo
=
f
.
maker
.
fgraph
.
toposort
()
scans
=
[
n
for
n
in
topo
if
isinstance
(
n
.
op
,
Scan
)]
assert
len
(
scans
)
==
1
f
=
function
([
x
],
[
sx
,
sy
],
mode
=
self
.
mode
)
assert
self
.
count_scans
(
f
)
==
1
sx
,
upx
=
scan
(
sum
,
sequences
=
[
x
])
sy
,
upy
=
scan
(
sum
,
sequences
=
[
x
],
mode
=
"FAST_COMPILE"
)
f
=
function
([
x
],
[
sx
,
sy
],
mode
=
self
.
mode
.
excluding
(
"scan_pushout_seqs_ops"
))
topo
=
f
.
maker
.
fgraph
.
toposort
()
scans
=
[
n
for
n
in
topo
if
isinstance
(
n
.
op
,
Scan
)]
assert
len
(
scans
)
==
1
f
=
function
([
x
],
[
sx
,
sy
],
mode
=
self
.
mode
)
assert
self
.
count_scans
(
f
)
==
1
sx
,
upx
=
scan
(
sum
,
sequences
=
[
x
])
sy
,
upy
=
scan
(
sum
,
sequences
=
[
x
],
truncate_gradient
=
1
)
f
=
function
([
x
],
[
sx
,
sy
],
mode
=
self
.
mode
.
excluding
(
"scan_pushout_seqs_ops"
))
topo
=
f
.
maker
.
fgraph
.
toposort
()
scans
=
[
n
for
n
in
topo
if
isinstance
(
n
.
op
,
Scan
)]
assert
len
(
scans
)
==
2
f
=
function
([
x
],
[
sx
,
sy
],
mode
=
self
.
mode
)
assert
self
.
count_scans
(
f
)
==
2
def
test_three_scans
(
self
):
r"""
...
...
@@ -877,12 +866,8 @@ class TestScanMerge:
sy
,
upy
=
scan
(
sum
,
sequences
=
[
2
*
y
+
2
],
n_steps
=
4
,
name
=
"Y"
)
sz
,
upz
=
scan
(
sum
,
sequences
=
[
sx
],
n_steps
=
4
,
name
=
"Z"
)
f
=
function
(
[
x
,
y
],
[
sy
,
sz
],
mode
=
self
.
mode
.
excluding
(
"scan_pushout_seqs_ops"
)
)
topo
=
f
.
maker
.
fgraph
.
toposort
()
scans
=
[
n
for
n
in
topo
if
isinstance
(
n
.
op
,
Scan
)]
assert
len
(
scans
)
==
2
f
=
function
([
x
,
y
],
[
sy
,
sz
],
mode
=
self
.
mode
)
assert
self
.
count_scans
(
f
)
==
2
rng
=
np
.
random
.
default_rng
(
utt
.
fetch_seed
())
x_val
=
rng
.
uniform
(
size
=
(
4
,))
.
astype
(
config
.
floatX
)
...
...
@@ -913,6 +898,112 @@ class TestScanMerge:
assert
not
opt_obj
.
belongs_to_set
(
scan_node1
,
[
scan_node2
])
assert
not
opt_obj
.
belongs_to_set
(
scan_node2
,
[
scan_node1
])
@config.change_flags
(
cxx
=
""
)
# Just for faster compilation
def
test_while_scan
(
self
):
x
=
vector
(
"x"
)
y
=
vector
(
"y"
)
def
add
(
s
):
return
s
+
1
,
until
(
s
>
5
)
def
sub
(
s
):
return
s
-
1
,
until
(
s
>
5
)
def
sub_alt
(
s
):
return
s
-
1
,
until
(
s
>
4
)
sx
,
upx
=
scan
(
add
,
sequences
=
[
x
])
sy
,
upy
=
scan
(
sub
,
sequences
=
[
y
])
f
=
function
([
x
,
y
],
[
sx
,
sy
],
mode
=
self
.
mode
)
assert
self
.
count_scans
(
f
)
==
2
sx
,
upx
=
scan
(
add
,
sequences
=
[
x
])
sy
,
upy
=
scan
(
sub
,
sequences
=
[
x
])
f
=
function
([
x
],
[
sx
,
sy
],
mode
=
self
.
mode
)
assert
self
.
count_scans
(
f
)
==
1
sx
,
upx
=
scan
(
add
,
sequences
=
[
x
])
sy
,
upy
=
scan
(
sub_alt
,
sequences
=
[
x
])
f
=
function
([
x
],
[
sx
,
sy
],
mode
=
self
.
mode
)
assert
self
.
count_scans
(
f
)
==
2
@config.change_flags
(
cxx
=
""
)
# Just for faster compilation
def
test_while_scan_nominal_dependency
(
self
):
"""Test case where condition depends on nominal variables.
This is a regression test for #509
"""
c1
=
scalar
(
"c1"
)
c2
=
scalar
(
"c2"
)
x
=
vector
(
"x"
,
shape
=
(
5
,))
y
=
vector
(
"y"
,
shape
=
(
5
,))
z
=
vector
(
"z"
,
shape
=
(
5
,))
def
add
(
s1
,
s2
,
const
):
return
s1
+
1
,
until
(
s2
>
const
)
def
sub
(
s1
,
s2
,
const
):
return
s1
-
1
,
until
(
s2
>
const
)
sx
,
_
=
scan
(
add
,
sequences
=
[
x
,
z
],
non_sequences
=
[
c1
])
sy
,
_
=
scan
(
sub
,
sequences
=
[
y
,
-
z
],
non_sequences
=
[
c1
])
f
=
pytensor
.
function
(
inputs
=
[
x
,
y
,
z
,
c1
],
outputs
=
[
sx
,
sy
],
mode
=
self
.
mode
)
assert
self
.
count_scans
(
f
)
==
2
res_sx
,
res_sy
=
f
(
x
=
[
0
,
0
,
0
,
0
,
0
],
y
=
[
0
,
0
,
0
,
0
,
0
],
z
=
[
0
,
1
,
2
,
3
,
4
],
c1
=
0
,
)
np
.
testing
.
assert_array_equal
(
res_sx
,
[
1
,
1
])
np
.
testing
.
assert_array_equal
(
res_sy
,
[
-
1
,
-
1
,
-
1
,
-
1
,
-
1
])
sx
,
_
=
scan
(
add
,
sequences
=
[
x
,
z
],
non_sequences
=
[
c1
])
sy
,
_
=
scan
(
sub
,
sequences
=
[
y
,
z
],
non_sequences
=
[
c2
])
f
=
pytensor
.
function
(
inputs
=
[
x
,
y
,
z
,
c1
,
c2
],
outputs
=
[
sx
,
sy
],
mode
=
self
.
mode
)
assert
self
.
count_scans
(
f
)
==
2
res_sx
,
res_sy
=
f
(
x
=
[
0
,
0
,
0
,
0
,
0
],
y
=
[
0
,
0
,
0
,
0
,
0
],
z
=
[
0
,
1
,
2
,
3
,
4
],
c1
=
3
,
c2
=
1
,
)
np
.
testing
.
assert_array_equal
(
res_sx
,
[
1
,
1
,
1
,
1
,
1
])
np
.
testing
.
assert_array_equal
(
res_sy
,
[
-
1
,
-
1
,
-
1
])
sx
,
_
=
scan
(
add
,
sequences
=
[
x
,
z
],
non_sequences
=
[
c1
])
sy
,
_
=
scan
(
sub
,
sequences
=
[
y
,
z
],
non_sequences
=
[
c1
])
f
=
pytensor
.
function
(
inputs
=
[
x
,
y
,
z
,
c1
],
outputs
=
[
sx
,
sy
],
mode
=
self
.
mode
)
assert
self
.
count_scans
(
f
)
==
1
def
nested_scan
(
c
,
x
,
z
):
sx
,
_
=
scan
(
add
,
sequences
=
[
x
,
z
],
non_sequences
=
[
c
])
sy
,
_
=
scan
(
sub
,
sequences
=
[
x
,
z
],
non_sequences
=
[
c
])
return
sx
.
sum
()
+
sy
.
sum
()
sz
,
_
=
scan
(
nested_scan
,
sequences
=
[
stack
([
c1
,
c2
])],
non_sequences
=
[
x
,
z
],
mode
=
self
.
mode
,
)
f
=
pytensor
.
function
(
inputs
=
[
x
,
z
,
c1
,
c2
],
outputs
=
sz
,
mode
=
mode
)
[
scan_node
]
=
[
node
for
node
in
f
.
maker
.
fgraph
.
apply_nodes
if
isinstance
(
node
.
op
,
Scan
)
]
inner_f
=
scan_node
.
op
.
fn
assert
self
.
count_scans
(
inner_f
)
==
1
class
TestScanInplaceOptimizer
:
mode
=
get_default_mode
()
.
including
(
"scan_make_inplace"
,
"inplace"
)
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论