Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
79112666
提交
79112666
authored
2月 12, 2022
作者:
Brandon T. Willard
提交者:
Brandon T. Willard
2月 22, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Remove unused CheckAndRaise code and minor refactoring to MergeOptimizer
上级
48c9ef88
隐藏空白字符变更
内嵌
并排
正在显示
2 个修改的文件
包含
51 行增加
和
203 行删除
+51
-203
opt.py
aesara/graph/opt.py
+48
-199
test_opt.py
tests/graph/test_opt.py
+3
-4
没有找到文件。
aesara/graph/opt.py
浏览文件 @
79112666
...
...
@@ -529,9 +529,9 @@ class MergeFeature(Feature):
self
.
on_import
(
fgraph
,
node
,
"on_attach"
)
def
on_change_input
(
self
,
fgraph
,
node
,
i
,
r
,
new_r
,
reason
):
# If inputs to node change, it is not guaranteed that it is distinct
# from the other nodes in nodes_seen
if
node
in
self
.
nodes_seen
:
# If inputs to a node change, it's not guaranteed that the node is
# distinct from the other nodes in `self.nodes_seen`.
self
.
nodes_seen
.
discard
(
node
)
self
.
process_node
(
fgraph
,
node
)
...
...
@@ -580,52 +580,30 @@ class MergeFeature(Feature):
self
.
seen_constants
.
add
(
id
(
c
))
def
process_node
(
self
,
fgraph
,
node
):
"""Check if a `node` can be merged, and queue that replacement."""
r"""Check if a `node` can be merged, and queue that replacement.
When `node` is changed we check for other nodes (via the clients map)
that depend on the same inputs. If any of those other nodes have the
same inputs and `Op` as `node`, they are queued to be merged.
"""
if
node
in
self
.
nodes_seen
:
return
node_has_assert
=
False
# These asserts ensure that the fgraph has set the clients field
# properly.
# The clients should at least contain `node` itself!
if
node
.
inputs
:
# Take the smallest clients list. Some ops like elemwise
# have optimization that put constant as the first inputs.
# As constant have in general more clients than other type of nodes
# using always inputs[0] make us look at more nodes.
# Always pick the smallest clints list between inputs 0
# and -1 speed up optimization.
a_clients
=
fgraph
.
clients
[
node
.
inputs
[
0
]]
b_clients
=
fgraph
.
clients
[
node
.
inputs
[
-
1
]]
if
len
(
a_clients
)
<
len
(
b_clients
):
clients
=
a_clients
else
:
clients
=
b_clients
# We use the smallest clients list. Some `Op`s like `Elemwise`
# have optimizations that put constants as the first inputs. Since
# constants generally have more clients than other types of nodes,
# using `node.inputs[0]` will make us look at more nodes on
# average, so by picking the smallest clients list, we might speed
# things up?
clients
=
sorted
(
(
fgraph
.
clients
[
inp
]
for
inp
in
node
.
inputs
),
key
=
lambda
x
:
len
(
x
)
)[
0
]
assert
len
(
clients
)
>
0
merge_candidates
=
[
c
for
c
,
i
in
clients
if
c
in
self
.
nodes_seen
]
# Put all clients of `CheckAndRaise` inputs (if exist) into
# `merge_candidates`
# TODO: Deactivated for now, because it can create cycles in a
# graph. (There is a second deactivation part below.)
# for i in node.inputs:
# if i.owner and isinstance(i.owner.op, CheckAndRaise):
# node_has_assert = True
# i_clients = fgraph.clients[i.owner.inputs[0]]
# assert_clients = [c for (c, _) in i_clients if c in self.nodes_seen]
#
# for idx in range(len(assert_clients)):
# client = assert_clients[idx]
# if isinstance(i.owner.op, CheckAndRaise):
# o_clients = fgraph.clients[client.outputs[0]]
# for c in o_clients:
# if c[0] in self.nodes_seen:
# assert_clients.append(c[0])
#
# merge_candidates.extend(assert_clients)
else
:
# If two nodes have no input, but perform the same operation,
# they are not always constant-folded, so we want to merge them.
...
...
@@ -639,41 +617,9 @@ class MergeFeature(Feature):
if
len
(
node
.
inputs
)
!=
len
(
candidate
.
inputs
):
continue
cand_has_assert
=
False
# Get input list of the candidate with assert removed
cand_inputs_assert_removed
=
[]
# TODO: Deactivated while `CheckAndRaise` merging is disabled. (See
# above and below.)
# for i in candidate.inputs:
# if i.owner and isinstance(i.owner.op, CheckAndRaise):
# cand_has_assert = True
# cand_inputs_assert_removed.append(i.owner.inputs[0])
# else:
# cand_inputs_assert_removed.append(i)
# TODO: Remove this when `CheckAndRaise` merging is
# re-enabled. (See above.) Without `CheckAndRaise` merging we can
# still look for identical `CheckAndRaise`, so we should not treat
# `CheckAndRaise`s separately for now.
cand_inputs_assert_removed
=
candidate
.
inputs
# Get input list of the node with assert removed
# if node_has_assert:
# node_inputs_assert_removed = []
# for i in node.inputs:
# if i.owner and isinstance(i.owner.op, CheckAndRaise):
# node_inputs_assert_removed.append(i.owner.inputs[0])
# else:
# node_inputs_assert_removed.append(i)
# else:
node_inputs_assert_removed
=
node
.
inputs
inputs_match
=
all
(
node_in
is
cand_in
for
node_in
,
cand_in
in
zip
(
node_inputs_assert_removed
,
cand_inputs_assert_removed
)
for
node_in
,
cand_in
in
zip
(
node
.
inputs
,
candidate
.
inputs
)
)
if
inputs_match
and
node
.
op
==
candidate
.
op
:
...
...
@@ -681,51 +627,14 @@ class MergeFeature(Feature):
# They were already tried, and there was an error
continue
# replace node with candidate
if
not
node_has_assert
and
not
cand_has_assert
:
# Schedule transfer of clients from node to candidate
pairs
=
list
(
zip
(
node
.
outputs
,
candidate
.
outputs
,
[
"merge"
]
*
len
(
node
.
outputs
),
)
# Schedule transfer of clients from node to candidate
pairs
=
list
(
zip
(
node
.
outputs
,
candidate
.
outputs
,
[
"merge"
]
*
len
(
node
.
outputs
),
)
# # if the current node has assert input, it should not be
# # replaced with a candidate node which has no assert input
# elif node_has_assert and not cand_has_assert:
# pairs = list(
# zip(
# candidate.outputs,
# node.outputs,
# ["merge"] * len(node.outputs),
# )
# )
# else:
# new_inputs = self.get_merged_assert_input(node, candidate)
# new_node = node.op(*new_inputs)
# pairs = list(
# zip(
# node.outputs,
# new_node.owner.outputs,
# ["new_node"] * len(node.outputs),
# )
# ) + list(
# zip(
# candidate.outputs,
# new_node.owner.outputs,
# ["new_node"] * len(node.outputs),
# )
# )
# transfer names
for
pair
in
pairs
:
node_output
,
cand_output
=
pair
[:
2
]
# clobber old name with new one
# it's arbitrary... one of the names has to go
if
node_output
.
name
:
cand_output
.
name
=
node_output
.
name
)
replacement_candidates
.
append
(
pairs
)
...
...
@@ -736,36 +645,6 @@ class MergeFeature(Feature):
if
not
node
.
inputs
:
self
.
noinput_nodes
.
add
(
node
)
# def get_merged_assert_input(self, node, candidate):
# new_inputs = []
# for node_i, cand_i in zip(node.inputs, candidate.inputs):
# if node_i.owner and isinstance(node_i.owner.op, CheckAndRaise):
# if (
# cand_i.owner
# and isinstance(cand_i.owner.op, CheckAndRaise)
# and node_i.owner.op.exc_type == cand_i.owner.op.exc_type
# ):
# # Here two assert nodes are merged.
# # Step 1. Merge conditions of both assert nodes.
# # Step 2. Make the new assert node
# node_cond = node_i.owner.inputs[1:]
# cand_cond = cand_i.owner.inputs[1:]
# new_cond = list(set(node_cond + cand_cond))
# new_raise_op = CheckAndRaise(
# node_i.owner.op.exc_type,
# "; ".join([node_i.owner.op.msg, cand_i.owner.op.msg]),
# )
# new_inputs.append(new_raise_op(*(node_i.owner.inputs[:1] + new_cond)))
#
# # node_i is assert, cand_i is not assert
# else:
# new_inputs.append(node_i)
# else:
# # if node_i is not an assert node, append cand_i
# new_inputs.append(cand_i)
#
# return new_inputs
class
MergeOptimizer
(
GlobalOptimizer
):
r"""Merges parts of the graph that are identical and redundant.
...
...
@@ -786,10 +665,6 @@ class MergeOptimizer(GlobalOptimizer):
fgraph
.
attach_feature
(
MergeFeature
())
def
apply
(
self
,
fgraph
):
from
aesara.raise_op
import
CheckAndRaise
# Constant and non-constant are now applied in the same phase.
# I am not sure why, but it seems to be faster this way.
sched
=
fgraph
.
merge_feature
.
scheduled
nb_fail
=
0
t0
=
time
.
time
()
...
...
@@ -804,51 +679,32 @@ class MergeOptimizer(GlobalOptimizer):
pairs_list
=
sched
.
pop
()
success
=
True
for
pairs_
in
pairs_list
:
# We must check again the equivalence, as the graph
# could've changed. If so, doing the replacement can
# introduce a node that depends on itself. Doing the
# full check of such cycles every time is very time
# consuming. I think this double check is faster than
# doing the full cycle check. The full cycle check is
# skipped by validate() if the graph doesn't contain
# destroyers.
var
,
candidate
,
merge_mode
=
pairs_
[
0
]
# We must check again the equivalence, as the graph could've
# changed. If so, doing the replacement can introduce a node
# that depends on itself. Doing the full check of such cycles
# every time is very time consuming. I think this double check
# is faster than doing the full cycle check. The full cycle
# check is skipped by `Validator.validate` if the graph doesn't
# contain destroyers.
var
,
candidate_var
,
merge_mode
=
pairs_
[
0
]
if
merge_mode
==
"new_node"
and
var
in
fgraph
.
variables
:
pass
elif
var
not
in
fgraph
.
variables
or
candidate
not
in
fgraph
.
variables
:
elif
(
var
not
in
fgraph
.
variables
or
candidate_var
not
in
fgraph
.
variables
):
continue
# Keep len(item) == 2 for item in pairs
pairs
=
[
pair
[:
2
]
for
pair
in
pairs_
]
if
var
.
owner
and
candidate
.
owner
:
node
=
var
.
owner
candidate
=
candidate
.
owner
# Get input list of the candidate node with assert
# nodes removed
cand_inputs_assert_removed
=
[]
for
i
in
candidate
.
inputs
:
if
i
.
owner
and
isinstance
(
i
.
owner
.
op
,
CheckAndRaise
):
cand_inputs_assert_removed
.
append
(
i
.
owner
.
inputs
[
0
])
else
:
cand_inputs_assert_removed
.
append
(
i
)
# Get input list of the node with assert nodes removed
node_inputs_assert_removed
=
[]
for
i
in
node
.
inputs
:
if
i
.
owner
and
isinstance
(
i
.
owner
.
op
,
CheckAndRaise
):
node_inputs_assert_removed
.
append
(
i
.
owner
.
inputs
[
0
])
else
:
node_inputs_assert_removed
.
append
(
i
)
if
var
.
owner
and
candidate_var
.
owner
:
if
merge_mode
==
"new_node"
:
inputs_match
=
True
else
:
inputs_match
=
all
(
node_in
is
cand_in
for
node_in
,
cand_in
in
zip
(
node_inputs_assert_removed
,
cand_inputs_assert_removed
var
.
owner
.
inputs
,
candidate_var
.
owner
.
inputs
)
)
...
...
@@ -862,15 +718,10 @@ class MergeOptimizer(GlobalOptimizer):
clients
=
(
fgraph
.
clients
[
pairs
[
0
][
0
]]
+
fgraph
.
clients
[
pairs
[
0
][
1
]]
)
if
(
sum
(
[
i
in
flatten
(
c
.
op
.
destroy_map
.
values
())
for
c
,
i
in
clients
if
c
!=
"output"
and
c
.
op
.
destroy_map
]
)
>
1
if
any
(
i
in
flatten
(
c
.
op
.
destroy_map
.
values
())
for
c
,
i
in
clients
if
c
!=
"output"
and
c
.
op
.
destroy_map
):
continue
...
...
@@ -884,10 +735,8 @@ class MergeOptimizer(GlobalOptimizer):
pairs
=
[(
pairs
[
0
][
1
],
pairs
[
0
][
0
])]
try
:
# If all Constants, no need to call validate.
# Only need to check one of the var of each pairs.
# If it is a Constant, the other must also be a Constant as we merge them.
if
all
(
isinstance
(
old
,
Constant
)
for
old
,
new
in
pairs
):
# If they're all `Constant`s, there's no need to call validate.
if
all
(
isinstance
(
old
,
Constant
)
for
old
,
_
in
pairs
):
fgraph
.
replace_all
(
pairs
,
reason
=
"MergeOptimizer"
)
else
:
fgraph
.
replace_all_validate
(
pairs
,
reason
=
"MergeOptimizer"
)
...
...
@@ -902,7 +751,6 @@ class MergeOptimizer(GlobalOptimizer):
nb_merged
+=
len
(
pairs
)
if
isinstance
(
pairs
[
0
][
0
],
Constant
):
nb_constant
+=
1
# print pairs, pairs[0][0].type
break
if
fgraph
.
profile
:
...
...
@@ -920,8 +768,9 @@ class MergeOptimizer(GlobalOptimizer):
validate_time
=
None
callback_time
=
None
callbacks_time
=
{}
# clear blacklist
fgraph
.
merge_feature
.
blacklist
=
[]
return
(
nb_fail
,
time
.
time
()
-
t0
,
...
...
tests/graph/test_opt.py
浏览文件 @
79112666
...
...
@@ -324,7 +324,7 @@ class TestMergeOptimizer:
@pytest.mark.skip
(
reason
=
"This was disabled for some unknown reason"
)
def
test_one_assert_merge
(
self
):
# Merge two nodes, one has assert, the other not.
"""Merge two nodes, one has assert, the other not."""
x1
=
matrix
(
"x1"
)
x2
=
matrix
(
"x2"
)
e
=
dot
(
x1
,
x2
)
+
dot
(
assert_op
(
x1
,
(
x1
>
x2
)
.
all
()),
x2
)
...
...
@@ -342,8 +342,7 @@ class TestMergeOptimizer:
assert
add_inputs
[
0
]
is
add_inputs
[
1
]
def
test_both_assert_merge_identical
(
self
):
# Merge two nodes, both have assert on the same node
# with the same conditions.
"""Merge two nodes, both have `Assert`s on the same node with the same conditions."""
x1
=
matrix
(
"x1"
)
x2
=
matrix
(
"x2"
)
e
=
dot
(
assert_op
(
x1
,
(
x1
>
x2
)
.
all
()),
x2
)
+
dot
(
...
...
@@ -434,7 +433,7 @@ class TestMergeOptimizer:
assert
add_inputs
[
0
]
is
add_inputs
[
1
]
def
test_merge_noinput
(
self
):
# Check that identical Apply nodes without inputs will be merged
"""Check that identical Apply nodes without inputs will be merged."""
x
=
NoInputOp
(
param
=
0
)()
y
=
NoInputOp
(
param
=
0
)()
z
=
NoInputOp
(
param
=
1
)()
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论