Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
79112666
提交
79112666
authored
2月 12, 2022
作者:
Brandon T. Willard
提交者:
Brandon T. Willard
2月 22, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Remove unused CheckAndRaise code and minor refactoring to MergeOptimizer
上级
48c9ef88
显示空白字符变更
内嵌
并排
正在显示
2 个修改的文件
包含
41 行增加
和
193 行删除
+41
-193
opt.py
aesara/graph/opt.py
+38
-189
test_opt.py
tests/graph/test_opt.py
+3
-4
没有找到文件。
aesara/graph/opt.py
浏览文件 @
79112666
...
@@ -529,9 +529,9 @@ class MergeFeature(Feature):
...
@@ -529,9 +529,9 @@ class MergeFeature(Feature):
self
.
on_import
(
fgraph
,
node
,
"on_attach"
)
self
.
on_import
(
fgraph
,
node
,
"on_attach"
)
def
on_change_input
(
self
,
fgraph
,
node
,
i
,
r
,
new_r
,
reason
):
def
on_change_input
(
self
,
fgraph
,
node
,
i
,
r
,
new_r
,
reason
):
# If inputs to node change, it is not guaranteed that it is distinct
# from the other nodes in nodes_seen
if
node
in
self
.
nodes_seen
:
if
node
in
self
.
nodes_seen
:
# If inputs to a node change, it's not guaranteed that the node is
# distinct from the other nodes in `self.nodes_seen`.
self
.
nodes_seen
.
discard
(
node
)
self
.
nodes_seen
.
discard
(
node
)
self
.
process_node
(
fgraph
,
node
)
self
.
process_node
(
fgraph
,
node
)
...
@@ -580,52 +580,30 @@ class MergeFeature(Feature):
...
@@ -580,52 +580,30 @@ class MergeFeature(Feature):
self
.
seen_constants
.
add
(
id
(
c
))
self
.
seen_constants
.
add
(
id
(
c
))
def
process_node
(
self
,
fgraph
,
node
):
def
process_node
(
self
,
fgraph
,
node
):
"""Check if a `node` can be merged, and queue that replacement."""
r"""Check if a `node` can be merged, and queue that replacement.
When `node` is changed we check for other nodes (via the clients map)
that depend on the same inputs. If any of those other nodes have the
same inputs and `Op` as `node`, they are queued to be merged.
"""
if
node
in
self
.
nodes_seen
:
if
node
in
self
.
nodes_seen
:
return
return
node_has_assert
=
False
# These asserts ensure that the fgraph has set the clients field
# properly.
# The clients should at least contain `node` itself!
if
node
.
inputs
:
if
node
.
inputs
:
# Take the smallest clients list. Some ops like elemwise
# We use the smallest clients list. Some `Op`s like `Elemwise`
# have optimization that put constant as the first inputs.
# have optimizations that put constants as the first inputs. Since
# As constant have in general more clients than other type of nodes
# constants generally have more clients than other types of nodes,
# using always inputs[0] make us look at more nodes.
# using `node.inputs[0]` will make us look at more nodes on
# Always pick the smallest clints list between inputs 0
# average, so by picking the smallest clients list, we might speed
# and -1 speed up optimization.
# things up?
clients
=
sorted
(
a_clients
=
fgraph
.
clients
[
node
.
inputs
[
0
]]
(
fgraph
.
clients
[
inp
]
for
inp
in
node
.
inputs
),
key
=
lambda
x
:
len
(
x
)
b_clients
=
fgraph
.
clients
[
node
.
inputs
[
-
1
]]
)[
0
]
if
len
(
a_clients
)
<
len
(
b_clients
):
clients
=
a_clients
else
:
clients
=
b_clients
assert
len
(
clients
)
>
0
assert
len
(
clients
)
>
0
merge_candidates
=
[
c
for
c
,
i
in
clients
if
c
in
self
.
nodes_seen
]
merge_candidates
=
[
c
for
c
,
i
in
clients
if
c
in
self
.
nodes_seen
]
# Put all clients of `CheckAndRaise` inputs (if exist) into
# `merge_candidates`
# TODO: Deactivated for now, because it can create cycles in a
# graph. (There is a second deactivation part below.)
# for i in node.inputs:
# if i.owner and isinstance(i.owner.op, CheckAndRaise):
# node_has_assert = True
# i_clients = fgraph.clients[i.owner.inputs[0]]
# assert_clients = [c for (c, _) in i_clients if c in self.nodes_seen]
#
# for idx in range(len(assert_clients)):
# client = assert_clients[idx]
# if isinstance(i.owner.op, CheckAndRaise):
# o_clients = fgraph.clients[client.outputs[0]]
# for c in o_clients:
# if c[0] in self.nodes_seen:
# assert_clients.append(c[0])
#
# merge_candidates.extend(assert_clients)
else
:
else
:
# If two nodes have no input, but perform the same operation,
# If two nodes have no input, but perform the same operation,
# they are not always constant-folded, so we want to merge them.
# they are not always constant-folded, so we want to merge them.
...
@@ -639,41 +617,9 @@ class MergeFeature(Feature):
...
@@ -639,41 +617,9 @@ class MergeFeature(Feature):
if
len
(
node
.
inputs
)
!=
len
(
candidate
.
inputs
):
if
len
(
node
.
inputs
)
!=
len
(
candidate
.
inputs
):
continue
continue
cand_has_assert
=
False
# Get input list of the candidate with assert removed
cand_inputs_assert_removed
=
[]
# TODO: Deactivated while `CheckAndRaise` merging is disabled. (See
# above and below.)
# for i in candidate.inputs:
# if i.owner and isinstance(i.owner.op, CheckAndRaise):
# cand_has_assert = True
# cand_inputs_assert_removed.append(i.owner.inputs[0])
# else:
# cand_inputs_assert_removed.append(i)
# TODO: Remove this when `CheckAndRaise` merging is
# re-enabled. (See above.) Without `CheckAndRaise` merging we can
# still look for identical `CheckAndRaise`, so we should not treat
# `CheckAndRaise`s separately for now.
cand_inputs_assert_removed
=
candidate
.
inputs
# Get input list of the node with assert removed
# if node_has_assert:
# node_inputs_assert_removed = []
# for i in node.inputs:
# if i.owner and isinstance(i.owner.op, CheckAndRaise):
# node_inputs_assert_removed.append(i.owner.inputs[0])
# else:
# node_inputs_assert_removed.append(i)
# else:
node_inputs_assert_removed
=
node
.
inputs
inputs_match
=
all
(
inputs_match
=
all
(
node_in
is
cand_in
node_in
is
cand_in
for
node_in
,
cand_in
in
zip
(
for
node_in
,
cand_in
in
zip
(
node
.
inputs
,
candidate
.
inputs
)
node_inputs_assert_removed
,
cand_inputs_assert_removed
)
)
)
if
inputs_match
and
node
.
op
==
candidate
.
op
:
if
inputs_match
and
node
.
op
==
candidate
.
op
:
...
@@ -681,8 +627,6 @@ class MergeFeature(Feature):
...
@@ -681,8 +627,6 @@ class MergeFeature(Feature):
# They were already tried, and there was an error
# They were already tried, and there was an error
continue
continue
# replace node with candidate
if
not
node_has_assert
and
not
cand_has_assert
:
# Schedule transfer of clients from node to candidate
# Schedule transfer of clients from node to candidate
pairs
=
list
(
pairs
=
list
(
zip
(
zip
(
...
@@ -692,41 +636,6 @@ class MergeFeature(Feature):
...
@@ -692,41 +636,6 @@ class MergeFeature(Feature):
)
)
)
)
# # if the current node has assert input, it should not be
# # replaced with a candidate node which has no assert input
# elif node_has_assert and not cand_has_assert:
# pairs = list(
# zip(
# candidate.outputs,
# node.outputs,
# ["merge"] * len(node.outputs),
# )
# )
# else:
# new_inputs = self.get_merged_assert_input(node, candidate)
# new_node = node.op(*new_inputs)
# pairs = list(
# zip(
# node.outputs,
# new_node.owner.outputs,
# ["new_node"] * len(node.outputs),
# )
# ) + list(
# zip(
# candidate.outputs,
# new_node.owner.outputs,
# ["new_node"] * len(node.outputs),
# )
# )
# transfer names
for
pair
in
pairs
:
node_output
,
cand_output
=
pair
[:
2
]
# clobber old name with new one
# it's arbitrary... one of the names has to go
if
node_output
.
name
:
cand_output
.
name
=
node_output
.
name
replacement_candidates
.
append
(
pairs
)
replacement_candidates
.
append
(
pairs
)
if
replacement_candidates
:
if
replacement_candidates
:
...
@@ -736,36 +645,6 @@ class MergeFeature(Feature):
...
@@ -736,36 +645,6 @@ class MergeFeature(Feature):
if
not
node
.
inputs
:
if
not
node
.
inputs
:
self
.
noinput_nodes
.
add
(
node
)
self
.
noinput_nodes
.
add
(
node
)
# def get_merged_assert_input(self, node, candidate):
# new_inputs = []
# for node_i, cand_i in zip(node.inputs, candidate.inputs):
# if node_i.owner and isinstance(node_i.owner.op, CheckAndRaise):
# if (
# cand_i.owner
# and isinstance(cand_i.owner.op, CheckAndRaise)
# and node_i.owner.op.exc_type == cand_i.owner.op.exc_type
# ):
# # Here two assert nodes are merged.
# # Step 1. Merge conditions of both assert nodes.
# # Step 2. Make the new assert node
# node_cond = node_i.owner.inputs[1:]
# cand_cond = cand_i.owner.inputs[1:]
# new_cond = list(set(node_cond + cand_cond))
# new_raise_op = CheckAndRaise(
# node_i.owner.op.exc_type,
# "; ".join([node_i.owner.op.msg, cand_i.owner.op.msg]),
# )
# new_inputs.append(new_raise_op(*(node_i.owner.inputs[:1] + new_cond)))
#
# # node_i is assert, cand_i is not assert
# else:
# new_inputs.append(node_i)
# else:
# # if node_i is not an assert node, append cand_i
# new_inputs.append(cand_i)
#
# return new_inputs
class
MergeOptimizer
(
GlobalOptimizer
):
class
MergeOptimizer
(
GlobalOptimizer
):
r"""Merges parts of the graph that are identical and redundant.
r"""Merges parts of the graph that are identical and redundant.
...
@@ -786,10 +665,6 @@ class MergeOptimizer(GlobalOptimizer):
...
@@ -786,10 +665,6 @@ class MergeOptimizer(GlobalOptimizer):
fgraph
.
attach_feature
(
MergeFeature
())
fgraph
.
attach_feature
(
MergeFeature
())
def
apply
(
self
,
fgraph
):
def
apply
(
self
,
fgraph
):
from
aesara.raise_op
import
CheckAndRaise
# Constant and non-constant are now applied in the same phase.
# I am not sure why, but it seems to be faster this way.
sched
=
fgraph
.
merge_feature
.
scheduled
sched
=
fgraph
.
merge_feature
.
scheduled
nb_fail
=
0
nb_fail
=
0
t0
=
time
.
time
()
t0
=
time
.
time
()
...
@@ -804,51 +679,32 @@ class MergeOptimizer(GlobalOptimizer):
...
@@ -804,51 +679,32 @@ class MergeOptimizer(GlobalOptimizer):
pairs_list
=
sched
.
pop
()
pairs_list
=
sched
.
pop
()
success
=
True
success
=
True
for
pairs_
in
pairs_list
:
for
pairs_
in
pairs_list
:
# We must check again the equivalence, as the graph
# We must check again the equivalence, as the graph could've
# could've changed. If so, doing the replacement can
# changed. If so, doing the replacement can introduce a node
# introduce a node that depends on itself. Doing the
# that depends on itself. Doing the full check of such cycles
# full check of such cycles every time is very time
# every time is very time consuming. I think this double check
# consuming. I think this double check is faster than
# is faster than doing the full cycle check. The full cycle
# doing the full cycle check. The full cycle check is
# check is skipped by `Validator.validate` if the graph doesn't
# skipped by validate() if the graph doesn't contain
# contain destroyers.
# destroyers.
var
,
candidate_var
,
merge_mode
=
pairs_
[
0
]
var
,
candidate
,
merge_mode
=
pairs_
[
0
]
if
merge_mode
==
"new_node"
and
var
in
fgraph
.
variables
:
if
merge_mode
==
"new_node"
and
var
in
fgraph
.
variables
:
pass
pass
elif
var
not
in
fgraph
.
variables
or
candidate
not
in
fgraph
.
variables
:
elif
(
var
not
in
fgraph
.
variables
or
candidate_var
not
in
fgraph
.
variables
):
continue
continue
# Keep len(item) == 2 for item in pairs
# Keep len(item) == 2 for item in pairs
pairs
=
[
pair
[:
2
]
for
pair
in
pairs_
]
pairs
=
[
pair
[:
2
]
for
pair
in
pairs_
]
if
var
.
owner
and
candidate
.
owner
:
if
var
.
owner
and
candidate_var
.
owner
:
node
=
var
.
owner
candidate
=
candidate
.
owner
# Get input list of the candidate node with assert
# nodes removed
cand_inputs_assert_removed
=
[]
for
i
in
candidate
.
inputs
:
if
i
.
owner
and
isinstance
(
i
.
owner
.
op
,
CheckAndRaise
):
cand_inputs_assert_removed
.
append
(
i
.
owner
.
inputs
[
0
])
else
:
cand_inputs_assert_removed
.
append
(
i
)
# Get input list of the node with assert nodes removed
node_inputs_assert_removed
=
[]
for
i
in
node
.
inputs
:
if
i
.
owner
and
isinstance
(
i
.
owner
.
op
,
CheckAndRaise
):
node_inputs_assert_removed
.
append
(
i
.
owner
.
inputs
[
0
])
else
:
node_inputs_assert_removed
.
append
(
i
)
if
merge_mode
==
"new_node"
:
if
merge_mode
==
"new_node"
:
inputs_match
=
True
inputs_match
=
True
else
:
else
:
inputs_match
=
all
(
inputs_match
=
all
(
node_in
is
cand_in
node_in
is
cand_in
for
node_in
,
cand_in
in
zip
(
for
node_in
,
cand_in
in
zip
(
node_inputs_assert_removed
,
cand_inputs_assert_removed
var
.
owner
.
inputs
,
candidate_var
.
owner
.
inputs
)
)
)
)
...
@@ -862,15 +718,10 @@ class MergeOptimizer(GlobalOptimizer):
...
@@ -862,15 +718,10 @@ class MergeOptimizer(GlobalOptimizer):
clients
=
(
clients
=
(
fgraph
.
clients
[
pairs
[
0
][
0
]]
+
fgraph
.
clients
[
pairs
[
0
][
1
]]
fgraph
.
clients
[
pairs
[
0
][
0
]]
+
fgraph
.
clients
[
pairs
[
0
][
1
]]
)
)
if
(
if
any
(
sum
(
[
i
in
flatten
(
c
.
op
.
destroy_map
.
values
())
i
in
flatten
(
c
.
op
.
destroy_map
.
values
())
for
c
,
i
in
clients
for
c
,
i
in
clients
if
c
!=
"output"
and
c
.
op
.
destroy_map
if
c
!=
"output"
and
c
.
op
.
destroy_map
]
)
>
1
):
):
continue
continue
...
@@ -884,10 +735,8 @@ class MergeOptimizer(GlobalOptimizer):
...
@@ -884,10 +735,8 @@ class MergeOptimizer(GlobalOptimizer):
pairs
=
[(
pairs
[
0
][
1
],
pairs
[
0
][
0
])]
pairs
=
[(
pairs
[
0
][
1
],
pairs
[
0
][
0
])]
try
:
try
:
# If all Constants, no need to call validate.
# If they're all `Constant`s, there's no need to call validate.
# Only need to check one of the var of each pairs.
if
all
(
isinstance
(
old
,
Constant
)
for
old
,
_
in
pairs
):
# If it is a Constant, the other must also be a Constant as we merge them.
if
all
(
isinstance
(
old
,
Constant
)
for
old
,
new
in
pairs
):
fgraph
.
replace_all
(
pairs
,
reason
=
"MergeOptimizer"
)
fgraph
.
replace_all
(
pairs
,
reason
=
"MergeOptimizer"
)
else
:
else
:
fgraph
.
replace_all_validate
(
pairs
,
reason
=
"MergeOptimizer"
)
fgraph
.
replace_all_validate
(
pairs
,
reason
=
"MergeOptimizer"
)
...
@@ -902,7 +751,6 @@ class MergeOptimizer(GlobalOptimizer):
...
@@ -902,7 +751,6 @@ class MergeOptimizer(GlobalOptimizer):
nb_merged
+=
len
(
pairs
)
nb_merged
+=
len
(
pairs
)
if
isinstance
(
pairs
[
0
][
0
],
Constant
):
if
isinstance
(
pairs
[
0
][
0
],
Constant
):
nb_constant
+=
1
nb_constant
+=
1
# print pairs, pairs[0][0].type
break
break
if
fgraph
.
profile
:
if
fgraph
.
profile
:
...
@@ -920,8 +768,9 @@ class MergeOptimizer(GlobalOptimizer):
...
@@ -920,8 +768,9 @@ class MergeOptimizer(GlobalOptimizer):
validate_time
=
None
validate_time
=
None
callback_time
=
None
callback_time
=
None
callbacks_time
=
{}
callbacks_time
=
{}
# clear blacklist
fgraph
.
merge_feature
.
blacklist
=
[]
fgraph
.
merge_feature
.
blacklist
=
[]
return
(
return
(
nb_fail
,
nb_fail
,
time
.
time
()
-
t0
,
time
.
time
()
-
t0
,
...
...
tests/graph/test_opt.py
浏览文件 @
79112666
...
@@ -324,7 +324,7 @@ class TestMergeOptimizer:
...
@@ -324,7 +324,7 @@ class TestMergeOptimizer:
@pytest.mark.skip
(
reason
=
"This was disabled for some unknown reason"
)
@pytest.mark.skip
(
reason
=
"This was disabled for some unknown reason"
)
def
test_one_assert_merge
(
self
):
def
test_one_assert_merge
(
self
):
# Merge two nodes, one has assert, the other not.
"""Merge two nodes, one has assert, the other not."""
x1
=
matrix
(
"x1"
)
x1
=
matrix
(
"x1"
)
x2
=
matrix
(
"x2"
)
x2
=
matrix
(
"x2"
)
e
=
dot
(
x1
,
x2
)
+
dot
(
assert_op
(
x1
,
(
x1
>
x2
)
.
all
()),
x2
)
e
=
dot
(
x1
,
x2
)
+
dot
(
assert_op
(
x1
,
(
x1
>
x2
)
.
all
()),
x2
)
...
@@ -342,8 +342,7 @@ class TestMergeOptimizer:
...
@@ -342,8 +342,7 @@ class TestMergeOptimizer:
assert
add_inputs
[
0
]
is
add_inputs
[
1
]
assert
add_inputs
[
0
]
is
add_inputs
[
1
]
def
test_both_assert_merge_identical
(
self
):
def
test_both_assert_merge_identical
(
self
):
# Merge two nodes, both have assert on the same node
"""Merge two nodes, both have `Assert`s on the same node with the same conditions."""
# with the same conditions.
x1
=
matrix
(
"x1"
)
x1
=
matrix
(
"x1"
)
x2
=
matrix
(
"x2"
)
x2
=
matrix
(
"x2"
)
e
=
dot
(
assert_op
(
x1
,
(
x1
>
x2
)
.
all
()),
x2
)
+
dot
(
e
=
dot
(
assert_op
(
x1
,
(
x1
>
x2
)
.
all
()),
x2
)
+
dot
(
...
@@ -434,7 +433,7 @@ class TestMergeOptimizer:
...
@@ -434,7 +433,7 @@ class TestMergeOptimizer:
assert
add_inputs
[
0
]
is
add_inputs
[
1
]
assert
add_inputs
[
0
]
is
add_inputs
[
1
]
def
test_merge_noinput
(
self
):
def
test_merge_noinput
(
self
):
# Check that identical Apply nodes without inputs will be merged
"""Check that identical Apply nodes without inputs will be merged."""
x
=
NoInputOp
(
param
=
0
)()
x
=
NoInputOp
(
param
=
0
)()
y
=
NoInputOp
(
param
=
0
)()
y
=
NoInputOp
(
param
=
0
)()
z
=
NoInputOp
(
param
=
1
)()
z
=
NoInputOp
(
param
=
1
)()
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论