Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
730f790e
提交
730f790e
authored
10月 28, 2021
作者:
Brandon T. Willard
提交者:
Brandon T. Willard
11月 18, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Convert Scan global optimizers to local optimizers
上级
5a3c0195
显示空白字符变更
内嵌
并排
正在显示
1 个修改的文件
包含
249 行增加
和
374 行删除
+249
-374
opt.py
aesara/scan/opt.py
+249
-374
没有找到文件。
aesara/scan/opt.py
浏览文件 @
730f790e
"""
"""This module provides optimizations for the `Scan` `Op`."""
This module provides optimizations for scan.
The Optimization provided in this file:
local opt: remove_constants_and_unused_inputs_scan,
constant_folding_for_scan2,
scan_merge_inouts
They are wrapped in in2out to create global opt.
global opt: ScanInplaceOptimizer,
PushOutNonSeqScan,
PushOutSeqScan,
PushOutDot1,
ScanMerge,
ScanSaveMem
How the are registered:
optdb: scan_eqopt1 (.1), scan_eqopt2(1.6), scan_inplace(75)
scan_eqopt1 -> scan_seqopt1
scan_seqopt1 -> in2out(remove_constants_and_unused_inputs_scan)(1),
PushOutNonSeqScan(2),
PushOutSeqScan(3), PushOutDot1(4)
scan_eqopt2 -> They are all global optimizer. (in2out convert local to global).
This is important, as the order is important and all global
optimizer run before local optimizer in the order they where
registered. (So don't change the order we register them!)
If we convert to local optimizer, we must convert all of them
to local optimizer. But:
1) can ScanMerge be made local? Can we keep only this one
global?
2) ScanSaveMem assert that we remove all nodes outputs,
we need to keep this.
3) It is ScanSaveMem suppose the the others ran before.
I added an assert at one place, but didn't looked for
other place.
4) Moving this to local opt could speed up significant this opt,
as we pass frequently on all nodes in the graph for no
good reason.
5) We register remove_constant_* many places, as some
opt create them and let this one clean up the mess.
Doing it that way, make things simpler for those already
complex opt.
in2out(constant_folding),
in2out(remove_constants_and_unused_inputs_scan1),
ScanMerge,
in2out(remove_constants_and_unused_inputs_scan2),
in2out(scan_merge_inouts),
ScanSaveMem,
in2out(remove_constants_and_unused_inputs_scan3)
"""
import
copy
import
copy
import
dataclasses
import
dataclasses
import
logging
import
logging
from
sys
import
maxsize
from
sys
import
maxsize
from
typing
import
Dict
,
List
,
Tuple
import
numpy
as
np
import
numpy
as
np
...
@@ -65,6 +16,7 @@ from aesara.compile.function.types import deep_copy_op
...
@@ -65,6 +16,7 @@ from aesara.compile.function.types import deep_copy_op
from
aesara.configdefaults
import
config
from
aesara.configdefaults
import
config
from
aesara.graph.basic
import
(
from
aesara.graph.basic
import
(
Constant
,
Constant
,
Node
,
Variable
,
Variable
,
clone_replace
,
clone_replace
,
equal_computations
,
equal_computations
,
...
@@ -74,7 +26,7 @@ from aesara.graph.basic import (
...
@@ -74,7 +26,7 @@ from aesara.graph.basic import (
)
)
from
aesara.graph.destroyhandler
import
DestroyHandler
from
aesara.graph.destroyhandler
import
DestroyHandler
from
aesara.graph.features
import
ReplaceValidate
from
aesara.graph.features
import
ReplaceValidate
from
aesara.graph.fg
import
InconsistencyError
from
aesara.graph.fg
import
FunctionGraph
,
InconsistencyError
from
aesara.graph.op
import
compute_test_value
from
aesara.graph.op
import
compute_test_value
from
aesara.graph.opt
import
GlobalOptimizer
,
in2out
,
local_optimizer
from
aesara.graph.opt
import
GlobalOptimizer
,
in2out
,
local_optimizer
from
aesara.graph.optdb
import
EquilibriumDB
,
SequenceDB
from
aesara.graph.optdb
import
EquilibriumDB
,
SequenceDB
...
@@ -236,38 +188,19 @@ def remove_constants_and_unused_inputs_scan(fgraph, node):
...
@@ -236,38 +188,19 @@ def remove_constants_and_unused_inputs_scan(fgraph, node):
return
False
return
False
class
PushOutNonSeqScan
(
GlobalOptimizer
):
@local_optimizer
([
Scan
])
r"""Pushing out the variables inside the `Scan` that depend only on non-sequences.
def
push_out_non_seq_scan
(
fgraph
,
node
):
r"""Push out the variables inside the `Scan` that depend only on non-sequences.
This optimizations pushes, out of `Scan`'s inner function and into the outer
This optimizations pushes, out of `Scan`'s inner function and into the outer
function, computation that depends only on non-sequence inputs. Such
function, computation that depends only on non-sequence inputs. Such
computation ends up being done every iteration on the same values so moving
computation ends up being done every iteration on the same values so moving
it to the outer function to be executed only once, before the `Scan` `Op`,
it to the outer function to be executed only once, before the `Scan` `Op`,
reduces the amount of computation that needs to be performed.
reduces the amount of computation that needs to be performed.
TODO: This is a global opt for historical reasonons. It should be possible
to change it to a local opt.
"""
def
__init__
(
self
):
super
()
.
__init__
()
def
add_requirements
(
self
,
fgraph
):
fgraph
.
attach_feature
(
ReplaceValidate
())
def
apply
(
self
,
fgraph
):
nodelist
=
[
x
for
x
in
fgraph
.
toposort
()
if
isinstance
(
x
.
op
,
Scan
)]
for
node
in
nodelist
:
self
.
process_node
(
fgraph
,
node
)
def
process_node
(
self
,
fgraph
,
node
):
"""
"""
IMPORTANT NOTE: This function uses set and dictionary data structures.
if
not
isinstance
(
node
.
op
,
Scan
):
By default they are not ordered for efficiency reasons. Take care
return
False
and make sure of changing them with their Ordered counterparts if you
need to iterate over these variables.
"""
# this flag tells if there was any change during the last iterations
# this flag tells if there was any change during the last iterations
clean_inputs
,
clean_outputs
=
reconstruct_graph
(
node
.
op
.
inputs
,
node
.
op
.
outputs
)
clean_inputs
,
clean_outputs
=
reconstruct_graph
(
node
.
op
.
inputs
,
node
.
op
.
outputs
)
...
@@ -337,16 +270,9 @@ class PushOutNonSeqScan(GlobalOptimizer):
...
@@ -337,16 +270,9 @@ class PushOutNonSeqScan(GlobalOptimizer):
elif
isinstance
(
x
,
Constant
):
elif
isinstance
(
x
,
Constant
):
outside_ins
.
append
(
x
.
clone
())
outside_ins
.
append
(
x
.
clone
())
else
:
else
:
raise
Exception
(
# TODO: Explain why is this an error, and raise an
(
# appropriate exception type.
"Error in the `scan_pushout_non_seq_"
raise
RuntimeError
()
"operations`. The optimization tries "
"to move some computation from scan "
"which is not allowed to move. Report "
"this on aesara-users list"
),
x
,
)
outside_ins
=
[
outside_ins
=
[
x
.
type
.
filter_variable
(
y
)
for
x
,
y
in
zip
(
nd
.
inputs
,
outside_ins
)
x
.
type
.
filter_variable
(
y
)
for
x
,
y
in
zip
(
nd
.
inputs
,
outside_ins
)
]
]
...
@@ -425,12 +351,9 @@ class PushOutNonSeqScan(GlobalOptimizer):
...
@@ -425,12 +351,9 @@ class PushOutNonSeqScan(GlobalOptimizer):
# Do not call make_node for test_value
# Do not call make_node for test_value
nw_node
=
nwScan
(
*
(
node
.
inputs
+
nw_outer
),
return_list
=
True
)[
0
]
.
owner
nw_node
=
nwScan
(
*
(
node
.
inputs
+
nw_outer
),
return_list
=
True
)[
0
]
.
owner
fgraph
.
replace_all_validate_remove
(
replacements
=
dict
(
zip
(
node
.
outputs
,
nw_node
.
outputs
))
list
(
zip
(
node
.
outputs
,
nw_node
.
outputs
)),
replacements
[
"remove"
]
=
[
node
]
remove
=
[
node
],
return
replacements
reason
=
"scanOp_pushout_nonseqs_ops"
,
)
return
True
elif
not
to_keep_set
:
elif
not
to_keep_set
:
# Nothing in the inner graph should be kept
# Nothing in the inner graph should be kept
replace_with
=
{}
replace_with
=
{}
...
@@ -448,11 +371,8 @@ class PushOutNonSeqScan(GlobalOptimizer):
...
@@ -448,11 +371,8 @@ class PushOutNonSeqScan(GlobalOptimizer):
if
len
(
node
.
outputs
)
==
len
(
replace_with
):
if
len
(
node
.
outputs
)
==
len
(
replace_with
):
# Every output of the node has a replacement, the Scan
# Every output of the node has a replacement, the Scan
# node can be removed from the graph
# node can be removed from the graph
fgraph
.
replace_all_validate_remove
(
replace_with
[
"remove"
]
=
[
node
]
replace_with
.
items
(),
return
replace_with
remove
=
[
node
],
reason
=
"scanOp_pushout_nonseqs_ops"
,
)
else
:
else
:
# The node has some outputs for which no replacement has
# The node has some outputs for which no replacement has
# been established. This can occur for outputs that are
# been established. This can occur for outputs that are
...
@@ -461,15 +381,14 @@ class PushOutNonSeqScan(GlobalOptimizer):
...
@@ -461,15 +381,14 @@ class PushOutNonSeqScan(GlobalOptimizer):
# passed directly as outputs. The replacements can be
# passed directly as outputs. The replacements can be
# performed but the Scan node can't be removed at this
# performed but the Scan node can't be removed at this
# point.
# point.
fgraph
.
replace_all_validate
(
return
replace_with
replace_with
.
items
(),
reason
=
"scanOp_pushout_nonseqs_ops"
)
else
:
else
:
return
False
return
False
class
PushOutSeqScan
(
GlobalOptimizer
):
@local_optimizer
([
Scan
])
def
push_out_seq_scan
(
fgraph
,
node
):
r"""Push out the variables inside the `Scan` that depend only on constants and sequences.
r"""Push out the variables inside the `Scan` that depend only on constants and sequences.
This optimization resembles `PushOutNonSeqScan` but it tries to push, out of
This optimization resembles `PushOutNonSeqScan` but it tries to push, out of
...
@@ -479,30 +398,10 @@ class PushOutSeqScan(GlobalOptimizer):
...
@@ -479,30 +398,10 @@ class PushOutSeqScan(GlobalOptimizer):
a single operation on a large tensor rather then perform that same operation
a single operation on a large tensor rather then perform that same operation
many times on many smaller tensors. In many cases, this optimization can
many times on many smaller tensors. In many cases, this optimization can
increase memory usage but, in some specific cases, it can also decrease it.
increase memory usage but, in some specific cases, it can also decrease it.
TODO: This is a global opt for historical reasonons. It should be possible
to change it to a local opt.
"""
"""
if
not
isinstance
(
node
.
op
,
Scan
):
return
False
def
__init__
(
self
):
super
()
.
__init__
()
def
add_requirements
(
self
,
fgraph
):
fgraph
.
attach_feature
(
ReplaceValidate
())
def
apply
(
self
,
fgraph
):
nodelist
=
[
x
for
x
in
fgraph
.
toposort
()
if
isinstance
(
x
.
op
,
Scan
)]
for
node
in
nodelist
:
self
.
process_node
(
fgraph
,
node
)
def
process_node
(
self
,
fgraph
,
node
):
"""
IMPORTANT NOTE: This function uses set and dictionary data structure.
By default they are not ordered for efficiency reasons. Take care
and make sure of changing them to Ordered versions if you need to
iterate over those variables.
"""
# this flag tells if there was any change during the last iterations
# this flag tells if there was any change during the last iterations
clean_inputs
,
clean_outputs
=
reconstruct_graph
(
node
.
op
.
inputs
,
node
.
op
.
outputs
)
clean_inputs
,
clean_outputs
=
reconstruct_graph
(
node
.
op
.
inputs
,
node
.
op
.
outputs
)
...
@@ -607,10 +506,7 @@ class PushOutSeqScan(GlobalOptimizer):
...
@@ -607,10 +506,7 @@ class PushOutSeqScan(GlobalOptimizer):
elif
(
elif
(
nd
not
in
to_remove_set
nd
not
in
to_remove_set
and
isinstance
(
nd
.
op
,
DimShuffle
)
and
isinstance
(
nd
.
op
,
DimShuffle
)
and
(
and
(
nd
.
inputs
[
0
]
in
inner_seqs_set
or
nd
.
inputs
[
0
]
.
owner
in
to_remove_set
)
nd
.
inputs
[
0
]
in
inner_seqs_set
or
nd
.
inputs
[
0
]
.
owner
in
to_remove_set
)
):
):
to_remove_set
.
add
(
nd
)
to_remove_set
.
add
(
nd
)
...
@@ -687,9 +583,7 @@ class PushOutSeqScan(GlobalOptimizer):
...
@@ -687,9 +583,7 @@ class PushOutSeqScan(GlobalOptimizer):
op_ins
=
nw_inner
+
clean_inputs
op_ins
=
nw_inner
+
clean_inputs
# Reconstruct node
# Reconstruct node
nw_info
=
dataclasses
.
replace
(
nw_info
=
dataclasses
.
replace
(
op
.
info
,
n_seqs
=
op
.
info
.
n_seqs
+
len
(
nw_inner
))
op
.
info
,
n_seqs
=
op
.
info
.
n_seqs
+
len
(
nw_inner
)
)
nwScan
=
Scan
(
nwScan
=
Scan
(
op_ins
,
op_ins
,
op_outs
,
op_outs
,
...
@@ -709,12 +603,10 @@ class PushOutSeqScan(GlobalOptimizer):
...
@@ -709,12 +603,10 @@ class PushOutSeqScan(GlobalOptimizer):
return_list
=
True
,
return_list
=
True
,
)[
0
]
.
owner
)[
0
]
.
owner
fgraph
.
replace_all_validate_remove
(
replacements
=
dict
(
zip
(
node
.
outputs
,
nw_node
.
outputs
))
list
(
zip
(
node
.
outputs
,
nw_node
.
outputs
)),
replacements
[
"remove"
]
=
[
node
]
remove
=
[
node
],
return
replacements
reason
=
"scanOp_pushout_seqs_ops"
,
)
return
True
elif
not
to_keep_set
and
not
op
.
as_while
and
not
op
.
outer_mitmot
(
node
.
inputs
):
elif
not
to_keep_set
and
not
op
.
as_while
and
not
op
.
outer_mitmot
(
node
.
inputs
):
# Nothing in the inner graph should be kept
# Nothing in the inner graph should be kept
replace_with
=
{}
replace_with
=
{}
...
@@ -740,154 +632,21 @@ class PushOutSeqScan(GlobalOptimizer):
...
@@ -740,154 +632,21 @@ class PushOutSeqScan(GlobalOptimizer):
# We need to add one extra dimension to the outputs
# We need to add one extra dimension to the outputs
if
replace_with
and
len
(
replace_with
)
==
len
(
node
.
outputs
):
if
replace_with
and
len
(
replace_with
)
==
len
(
node
.
outputs
):
fgraph
.
replace_all_validate_remove
(
replacements
=
dict
(
replace_with
.
items
())
list
(
replace_with
.
items
()),
replacements
[
"remove"
]
=
[
node
]
remove
=
[
node
],
return
replacements
reason
=
"scanOp_pushout_seqs_ops"
,
)
return
True
else
:
else
:
return
False
return
False
class
PushOutScanOutput
(
GlobalOptimizer
):
def
inner_sitsot_only_last_step_used
(
r"""Push operations performed at the end of the inner graph of `Scan` to outside of `Scan`.
fgraph
:
FunctionGraph
,
var
:
Variable
,
scan_args
:
ScanArgs
)
->
bool
:
This optimizations attempts to push out some of the computation at the end
of the inner function to the outer function, to be executed after the `Scan`
node. Like `PushOutSeqScan`, this optimization aims to replace many operations
on small tensors by few operations on large tensors. It can also lead to
increased memory usage.
"""
def
__init__
(
self
):
super
()
.
__init__
()
def
add_requirements
(
self
,
fgraph
):
fgraph
.
attach_feature
(
ReplaceValidate
())
def
apply
(
self
,
fgraph
):
# Don't perform the optimization on as_while scans. Because these scans
# don't run for a predetermined number of steps, handling them is
# more complicated and this optimization doesn't support it at the
# moment.
nodelist
=
[
x
for
x
in
fgraph
.
toposort
()
if
(
isinstance
(
x
.
op
,
Scan
)
and
not
x
.
op
.
as_while
)
]
for
node
in
nodelist
:
# Process the node as long as something gets optimized
while
node
is
not
None
:
node
=
self
.
process_node
(
fgraph
,
node
)
def
process_node
(
self
,
fgraph
,
node
):
op
=
node
.
op
# Use `ScanArgs` to parse the inputs and outputs of scan for ease of
# use
args
=
ScanArgs
(
node
.
inputs
,
node
.
outputs
,
op
.
inputs
,
op
.
outputs
,
op
.
info
,
op
.
as_while
)
new_scan_node
=
None
clients
=
{}
local_fgraph_topo
=
io_toposort
(
args
.
inner_inputs
,
args
.
inner_outputs
,
clients
=
clients
)
for
nd
in
local_fgraph_topo
:
if
(
isinstance
(
nd
.
op
,
Elemwise
)
and
isinstance
(
nd
.
op
.
scalar_op
,
aes
.
Add
)
and
nd
.
out
in
args
.
inner_out_sit_sot
and
self
.
inner_sitsot_only_last_step_used
(
fgraph
,
nd
.
out
,
args
)
):
# Ensure that one of the input to the add is the output of
# the add from a previous iteration of the inner function
sitsot_idx
=
args
.
inner_out_sit_sot
.
index
(
nd
.
out
)
if
args
.
inner_in_sit_sot
[
sitsot_idx
]
in
nd
.
inputs
:
# Ensure that the other input to the add is a dot product
# between 2 matrices which will become a tensor3 and a
# matrix if pushed outside of the scan. Also make sure
# that the output of the Dot is ONLY used by the 'add'
# otherwise doing a Dot in the outer graph will only
# duplicate computation.
sitsot_in_idx
=
nd
.
inputs
.
index
(
args
.
inner_in_sit_sot
[
sitsot_idx
])
# 0 if sitsot_in_idx==1, 1 if sitsot_in_idx==0
dot_in_idx
=
1
-
sitsot_in_idx
dot_input
=
nd
.
inputs
[
dot_in_idx
]
if
(
dot_input
.
owner
is
not
None
and
isinstance
(
dot_input
.
owner
.
op
,
Dot
)
and
len
(
clients
[
dot_input
])
==
1
and
dot_input
.
owner
.
inputs
[
0
]
.
ndim
==
2
and
dot_input
.
owner
.
inputs
[
1
]
.
ndim
==
2
and
self
.
get_outer_ndim
(
dot_input
.
owner
.
inputs
[
0
],
args
)
==
3
and
self
.
get_outer_ndim
(
dot_input
.
owner
.
inputs
[
1
],
args
)
==
3
):
# The optimization can be be applied in this case.
# Move out of scan the two inputs to the Dot and
# perform a dot outside of scan on these two inputs
inner_dot_inputs
=
nd
.
inputs
[
dot_in_idx
]
.
owner
.
inputs
(
outer_dot_inputs
,
new_scan_node
,
new_scan_args
,
)
=
self
.
push_out_inner_vars
(
fgraph
,
inner_dot_inputs
,
node
,
args
)
# Collapse some of the dimensions of the tensors
# so that they become matrices. This is because a
# dot is usually faster on two large matrices than
# a bunch of small ones
outer_dot_inputs
[
0
]
=
aet
.
flatten
(
outer_dot_inputs
[
0
]
.
dimshuffle
(
1
,
0
,
2
),
ndim
=
2
)
shape_input1
=
shape
(
outer_dot_inputs
[
1
])
outer_dot_inputs
[
1
]
=
outer_dot_inputs
[
1
]
.
reshape
(
(
shape_input1
[
0
]
*
shape_input1
[
1
],
shape_input1
[
2
])
)
# Perform the dot on the newly obtained matrices and
# add the initial value
outer_dot_output
=
dot
(
*
outer_dot_inputs
)
init_value
=
new_scan_args
.
outer_in_sit_sot
[
sitsot_idx
][
0
]
replacement
=
outer_dot_output
+
init_value
# Alter the outer graph to use the output of the
# external Dot instead of the output of scan
# Modify the outer graph to add the outer Dot
outer_sitsot
=
new_scan_args
.
outer_out_sit_sot
[
sitsot_idx
]
subtensor_node
=
fgraph
.
clients
[
outer_sitsot
][
0
][
0
]
outer_sitsot_last_step
=
subtensor_node
.
outputs
[
0
]
fgraph
.
replace_all
(
[(
outer_sitsot_last_step
,
replacement
)],
reason
=
"scanOp_pushout_output"
,
)
break
return
new_scan_node
def
inner_sitsot_only_last_step_used
(
self
,
fgraph
,
var
,
scan_args
):
"""
"""
Given a inner nit_sot output of scan, return True
iff the outer
Given a inner nit-sot output of `Scan`, return ``True``
iff the outer
nit_sot output has only one client and that client is a Subtensor
nit-sot output has only one client and that client is a `Subtensor`
instance that takes only the last step (last element along the first
instance that takes only the last step (last element along the first
axis).
axis).
"""
"""
idx
=
scan_args
.
inner_out_sit_sot
.
index
(
var
)
idx
=
scan_args
.
inner_out_sit_sot
.
index
(
var
)
outer_var
=
scan_args
.
outer_out_sit_sot
[
idx
]
outer_var
=
scan_args
.
outer_out_sit_sot
[
idx
]
...
@@ -901,23 +660,28 @@ class PushOutScanOutput(GlobalOptimizer):
...
@@ -901,23 +660,28 @@ class PushOutScanOutput(GlobalOptimizer):
return
False
return
False
def
get_outer_ndim
(
self
,
var
,
scan_args
):
# Given a variable, determine the number of dimension it would have if
def
get_outer_ndim
(
var
:
Variable
,
scan_args
:
ScanArgs
)
->
int
:
# it was pushed out of scan
"""Determine the number of dimension a variable would have if it was pushed out of a `Scan`."""
if
var
in
scan_args
.
inner_in_non_seqs
or
isinstance
(
var
,
Constant
):
if
var
in
scan_args
.
inner_in_non_seqs
or
isinstance
(
var
,
Constant
):
outer_ndim
=
var
.
ndim
outer_ndim
=
var
.
ndim
else
:
else
:
outer_ndim
=
var
.
ndim
+
1
outer_ndim
=
var
.
ndim
+
1
return
outer_ndim
return
outer_ndim
def
push_out_inner_vars
(
self
,
fgraph
,
inner_vars
,
old_scan_node
,
old_scan_args
):
def
push_out_inner_vars
(
fgraph
:
FunctionGraph
,
inner_vars
:
List
[
Variable
],
old_scan_node
:
Node
,
old_scan_args
:
ScanArgs
,
)
->
Tuple
[
List
[
Variable
],
ScanArgs
,
Dict
[
Variable
,
Variable
]]:
outer_vars
=
[
None
]
*
len
(
inner_vars
)
outer_vars
=
[
None
]
*
len
(
inner_vars
)
new_scan_node
=
old_scan_node
new_scan_node
=
old_scan_node
new_scan_args
=
old_scan_args
new_scan_args
=
old_scan_args
replacements
=
{}
# For the inner_vars that already exist in the outer graph,
# For the inner_vars that already exist in the outer graph,
# simply obtain a reference to them
# simply obtain a reference to them
...
@@ -942,14 +706,12 @@ class PushOutScanOutput(GlobalOptimizer):
...
@@ -942,14 +706,12 @@ class PushOutScanOutput(GlobalOptimizer):
# For the inner_vars that don't already exist in the outer graph, add
# For the inner_vars that don't already exist in the outer graph, add
# them as new nitsot outputs to the scan node.
# them as new nitsot outputs to the scan node.
idx_add_as_nitsots
=
[
idx_add_as_nitsots
=
[
i
for
i
in
range
(
len
(
outer_vars
))
if
outer_vars
[
i
]
is
None
]
i
for
i
in
range
(
len
(
outer_vars
))
if
outer_vars
[
i
]
is
None
]
add_as_nitsots
=
[
inner_vars
[
idx
]
for
idx
in
idx_add_as_nitsots
]
add_as_nitsots
=
[
inner_vars
[
idx
]
for
idx
in
idx_add_as_nitsots
]
if
len
(
add_as_nitsots
)
>
0
:
if
len
(
add_as_nitsots
)
>
0
:
new_scan_node
=
self
.
add_nitsot_outputs
(
new_scan_node
,
replacements
=
add_nitsot_outputs
(
fgraph
,
old_scan_node
,
old_scan_args
,
add_as_nitsots
fgraph
,
old_scan_node
,
old_scan_args
,
add_as_nitsots
)
)
...
@@ -966,20 +728,22 @@ class PushOutScanOutput(GlobalOptimizer):
...
@@ -966,20 +728,22 @@ class PushOutScanOutput(GlobalOptimizer):
for
i
in
range
(
len
(
new_outs
)):
for
i
in
range
(
len
(
new_outs
)):
outer_vars
[
idx_add_as_nitsots
[
i
]]
=
new_outs
[
i
]
outer_vars
[
idx_add_as_nitsots
[
i
]]
=
new_outs
[
i
]
return
outer_vars
,
new_scan_node
,
new_scan_arg
s
return
outer_vars
,
new_scan_args
,
replacement
s
def
add_nitsot_outputs
(
self
,
fgraph
,
old_scan_node
,
old_scan_args
,
new_outputs_inner
def
add_nitsot_outputs
(
):
fgraph
:
FunctionGraph
,
old_scan_node
:
Node
,
old_scan_args
:
ScanArgs
,
new_outputs_inner
,
)
->
Tuple
[
Node
,
Dict
[
Variable
,
Variable
]]:
nb_new_outs
=
len
(
new_outputs_inner
)
nb_new_outs
=
len
(
new_outputs_inner
)
# Create the initial values for the new nitsot outputs
# Create the initial values for the new nitsot outputs
# (the initial value is the nb of steps to store. For a nistot,
# (the initial value is the nb of steps to store. For a nistot,
# it should be the number of steps performed by scan)
# it should be the number of steps performed by scan)
new_nitsots_initial_value
=
[
new_nitsots_initial_value
=
[
old_scan_node
.
inputs
[
0
]
for
i
in
range
(
nb_new_outs
)]
old_scan_node
.
inputs
[
0
]
for
i
in
range
(
nb_new_outs
)
]
# Create the `ScanArgs` corresponding to the new `Scan` `Op` to create
# Create the `ScanArgs` corresponding to the new `Scan` `Op` to create
new_scan_args
=
copy
.
copy
(
old_scan_args
)
new_scan_args
=
copy
.
copy
(
old_scan_args
)
...
@@ -1002,9 +766,7 @@ class PushOutScanOutput(GlobalOptimizer):
...
@@ -1002,9 +766,7 @@ class PushOutScanOutput(GlobalOptimizer):
)
)
# Create the Apply node for the scan op
# Create the Apply node for the scan op
new_scan_node
=
new_scan_op
(
*
new_scan_args
.
outer_inputs
,
return_list
=
True
)[
new_scan_node
=
new_scan_op
(
*
new_scan_args
.
outer_inputs
,
return_list
=
True
)[
0
]
.
owner
0
]
.
owner
# Modify the outer graph to make sure the outputs of the new scan are
# Modify the outer graph to make sure the outputs of the new scan are
# used instead of the outputs of the old scan
# used instead of the outputs of the old scan
...
@@ -1017,13 +779,123 @@ class PushOutScanOutput(GlobalOptimizer):
...
@@ -1017,13 +779,123 @@ class PushOutScanOutput(GlobalOptimizer):
+
new_scan_node
.
outputs
[
new_node_new_outputs_idx
+
nb_new_outs
:]
+
new_scan_node
.
outputs
[
new_node_new_outputs_idx
+
nb_new_outs
:]
)
)
# TODO FIXME:
# replacements = dict(zip(old_scan_node.outputs, new_node_old_outputs))
# replacements["remove"] = [old_scan_node]
# return new_scan_node, replacements
fgraph
.
replace_all_validate_remove
(
fgraph
.
replace_all_validate_remove
(
list
(
zip
(
old_scan_node
.
outputs
,
new_node_old_outputs
)),
list
(
zip
(
old_scan_node
.
outputs
,
new_node_old_outputs
)),
remove
=
[
old_scan_node
],
remove
=
[
old_scan_node
],
reason
=
"scanOp
_pushout_output"
,
reason
=
"scan
_pushout_output"
,
)
)
return
new_scan_node
,
{}
@local_optimizer
([
Scan
])
def
push_out_add_scan
(
fgraph
,
node
):
r"""Push `Add` operations performed at the end of the inner graph to the outside.
return
new_scan_node
Like `push_out_seq_scan`, this optimization aims to replace many operations
on small tensors by few operations on large tensors. It can also lead to
increased memory usage.
"""
# Don't perform the optimization on `as_while` `Scan`s. Because these
# `Scan`s don't run for a predetermined number of steps, handling them is
# more complicated and this optimization doesn't support it at the moment.
if
not
(
isinstance
(
node
.
op
,
Scan
)
and
not
node
.
op
.
as_while
):
return
False
op
=
node
.
op
# Use `ScanArgs` to parse the inputs and outputs of scan for ease of
# use
args
=
ScanArgs
(
node
.
inputs
,
node
.
outputs
,
op
.
inputs
,
op
.
outputs
,
op
.
info
,
op
.
as_while
)
clients
=
{}
local_fgraph_topo
=
io_toposort
(
args
.
inner_inputs
,
args
.
inner_outputs
,
clients
=
clients
)
for
nd
in
local_fgraph_topo
:
if
(
isinstance
(
nd
.
op
,
Elemwise
)
and
isinstance
(
nd
.
op
.
scalar_op
,
aes
.
Add
)
and
nd
.
out
in
args
.
inner_out_sit_sot
and
inner_sitsot_only_last_step_used
(
fgraph
,
nd
.
out
,
args
)
):
# Ensure that one of the input to the add is the output of
# the add from a previous iteration of the inner function
sitsot_idx
=
args
.
inner_out_sit_sot
.
index
(
nd
.
out
)
if
args
.
inner_in_sit_sot
[
sitsot_idx
]
in
nd
.
inputs
:
# Ensure that the other input to the add is a dot product
# between 2 matrices which will become a tensor3 and a
# matrix if pushed outside of the scan. Also make sure
# that the output of the Dot is ONLY used by the 'add'
# otherwise doing a Dot in the outer graph will only
# duplicate computation.
sitsot_in_idx
=
nd
.
inputs
.
index
(
args
.
inner_in_sit_sot
[
sitsot_idx
])
# 0 if sitsot_in_idx==1, 1 if sitsot_in_idx==0
dot_in_idx
=
1
-
sitsot_in_idx
dot_input
=
nd
.
inputs
[
dot_in_idx
]
if
(
dot_input
.
owner
is
not
None
and
isinstance
(
dot_input
.
owner
.
op
,
Dot
)
and
len
(
clients
[
dot_input
])
==
1
and
dot_input
.
owner
.
inputs
[
0
]
.
ndim
==
2
and
dot_input
.
owner
.
inputs
[
1
]
.
ndim
==
2
and
get_outer_ndim
(
dot_input
.
owner
.
inputs
[
0
],
args
)
==
3
and
get_outer_ndim
(
dot_input
.
owner
.
inputs
[
1
],
args
)
==
3
):
# The optimization can be be applied in this case.
# Move out of scan the two inputs to the Dot and
# perform a dot outside of scan on these two inputs
inner_dot_inputs
=
nd
.
inputs
[
dot_in_idx
]
.
owner
.
inputs
(
outer_dot_inputs
,
new_scan_args
,
replacements
,
)
=
push_out_inner_vars
(
fgraph
,
inner_dot_inputs
,
node
,
args
)
# Collapse some of the dimensions of the tensors
# so that they become matrices. This is because a
# dot is usually faster on two large matrices than
# a bunch of small ones
outer_dot_inputs
[
0
]
=
aet
.
flatten
(
outer_dot_inputs
[
0
]
.
dimshuffle
(
1
,
0
,
2
),
ndim
=
2
)
shape_input1
=
shape
(
outer_dot_inputs
[
1
])
outer_dot_inputs
[
1
]
=
outer_dot_inputs
[
1
]
.
reshape
(
(
shape_input1
[
0
]
*
shape_input1
[
1
],
shape_input1
[
2
])
)
# Perform the dot on the newly obtained matrices and
# add the initial value
outer_dot_output
=
dot
(
*
outer_dot_inputs
)
init_value
=
new_scan_args
.
outer_in_sit_sot
[
sitsot_idx
][
0
]
replacement
=
outer_dot_output
+
init_value
# Alter the outer graph to use the output of the
# external Dot instead of the output of scan
# Modify the outer graph to add the outer Dot
outer_sitsot
=
new_scan_args
.
outer_out_sit_sot
[
sitsot_idx
]
subtensor_node
=
fgraph
.
clients
[
outer_sitsot
][
0
][
0
]
outer_sitsot_last_step
=
subtensor_node
.
outputs
[
0
]
replacements
[
outer_sitsot_last_step
]
=
replacement
return
replacements
return
False
class
ScanInplaceOptimizer
(
GlobalOptimizer
):
class
ScanInplaceOptimizer
(
GlobalOptimizer
):
...
@@ -1203,7 +1075,31 @@ class ScanInplaceOptimizer(GlobalOptimizer):
...
@@ -1203,7 +1075,31 @@ class ScanInplaceOptimizer(GlobalOptimizer):
node
=
self
.
attempt_scan_inplace
(
fgraph
,
node
,
[
pos
],
alloc_ops
)
node
=
self
.
attempt_scan_inplace
(
fgraph
,
node
,
[
pos
],
alloc_ops
)
class
ScanSaveMem
(
GlobalOptimizer
):
def
select_min
(
x
,
y
):
if
x
is
None
:
return
y
if
y
is
None
:
return
x
return
minimum
(
x
,
y
)
def
select_max
(
x
,
y
):
if
x
is
None
:
return
y
if
y
is
None
:
return
x
return
maximum
(
x
,
y
)
def
sanitize
(
x
):
if
x
is
None
:
return
None
else
:
return
aet
.
as_tensor_variable
(
x
)
@local_optimizer
([
Scan
])
def
save_mem_new_scan
(
fgraph
,
node
):
r"""Graph optimizer that reduces scan memory consumption.
r"""Graph optimizer that reduces scan memory consumption.
This optimizations attempts to determine if a `Scan` node, during its execution,
This optimizations attempts to determine if a `Scan` node, during its execution,
...
@@ -1224,35 +1120,8 @@ class ScanSaveMem(GlobalOptimizer):
...
@@ -1224,35 +1120,8 @@ class ScanSaveMem(GlobalOptimizer):
be kept in memory.
be kept in memory.
"""
"""
if
not
isinstance
(
node
.
op
,
Scan
):
def
__init__
(
self
):
return
False
super
()
.
__init__
()
def
add_requirements
(
self
,
fgraph
):
fgraph
.
attach_feature
(
ReplaceValidate
())
def
process_node
(
self
,
fgraph
,
node
):
# helpful functions
def
select_min
(
x
,
y
):
if
x
is
None
:
return
y
if
y
is
None
:
return
x
return
minimum
(
x
,
y
)
def
select_max
(
x
,
y
):
if
x
is
None
:
return
y
if
y
is
None
:
return
x
return
maximum
(
x
,
y
)
def
sanitize
(
x
):
if
x
is
None
:
return
None
else
:
return
aet
.
as_tensor_variable
(
x
)
if
hasattr
(
fgraph
,
"shape_feature"
):
if
hasattr
(
fgraph
,
"shape_feature"
):
shape_of
=
fgraph
.
shape_feature
.
shape_of
shape_of
=
fgraph
.
shape_feature
.
shape_of
...
@@ -1487,17 +1356,12 @@ class ScanSaveMem(GlobalOptimizer):
...
@@ -1487,17 +1356,12 @@ class ScanSaveMem(GlobalOptimizer):
first_mitsot_idx
=
node
.
op
.
n_mit_mot
first_mitsot_idx
=
node
.
op
.
n_mit_mot
last_sitsot_idx
=
(
last_sitsot_idx
=
(
node
.
op
.
n_mit_mot
node
.
op
.
n_mit_mot
+
node
.
op
.
n_mit_sot
+
node
.
op
.
n_sit_sot
-
1
+
node
.
op
.
n_mit_sot
+
node
.
op
.
n_sit_sot
-
1
)
)
preallocable_output
=
first_mitsot_idx
<=
i
<=
last_sitsot_idx
preallocable_output
=
first_mitsot_idx
<=
i
<=
last_sitsot_idx
if
prealloc_outs
and
preallocable_output
:
if
prealloc_outs
and
preallocable_output
:
pval
=
select_max
(
pval
=
select_max
(
nw_steps
-
start
+
init_l
[
i
],
init_l
[
i
]
+
1
)
nw_steps
-
start
+
init_l
[
i
],
init_l
[
i
]
+
1
)
else
:
else
:
pval
=
select_max
(
nw_steps
-
start
+
init_l
[
i
],
init_l
[
i
])
pval
=
select_max
(
nw_steps
-
start
+
init_l
[
i
],
init_l
[
i
])
...
@@ -1544,9 +1408,7 @@ class ScanSaveMem(GlobalOptimizer):
...
@@ -1544,9 +1408,7 @@ class ScanSaveMem(GlobalOptimizer):
# TODO: commit change below with Razvan
# TODO: commit change below with Razvan
if
(
if
(
nw_inputs
[
offset
+
idx
]
.
owner
nw_inputs
[
offset
+
idx
]
.
owner
and
isinstance
(
and
isinstance
(
nw_inputs
[
offset
+
idx
]
.
owner
.
op
,
IncSubtensor
)
nw_inputs
[
offset
+
idx
]
.
owner
.
op
,
IncSubtensor
)
and
isinstance
(
and
isinstance
(
nw_inputs
[
offset
+
idx
]
.
owner
.
op
.
idx_list
[
0
],
slice
nw_inputs
[
offset
+
idx
]
.
owner
.
op
.
idx_list
[
0
],
slice
)
)
...
@@ -1558,9 +1420,7 @@ class ScanSaveMem(GlobalOptimizer):
...
@@ -1558,9 +1420,7 @@ class ScanSaveMem(GlobalOptimizer):
_nw_input
=
nw_inputs
[
offset
+
idx
]
.
owner
.
inputs
[
1
]
_nw_input
=
nw_inputs
[
offset
+
idx
]
.
owner
.
inputs
[
1
]
cval
=
aet
.
as_tensor_variable
(
val
)
cval
=
aet
.
as_tensor_variable
(
val
)
initl
=
aet
.
as_tensor_variable
(
init_l
[
i
])
initl
=
aet
.
as_tensor_variable
(
init_l
[
i
])
tmp_idx
=
aet
.
switch
(
tmp_idx
=
aet
.
switch
(
cval
<
initl
,
cval
+
initl
,
cval
-
initl
)
cval
<
initl
,
cval
+
initl
,
cval
-
initl
)
nw_input
=
expand_empty
(
_nw_input
,
tmp_idx
)
nw_input
=
expand_empty
(
_nw_input
,
tmp_idx
)
else
:
else
:
tmp
=
aet
.
as_tensor_variable
(
val
)
tmp
=
aet
.
as_tensor_variable
(
val
)
...
@@ -1645,7 +1505,7 @@ class ScanSaveMem(GlobalOptimizer):
...
@@ -1645,7 +1505,7 @@ class ScanSaveMem(GlobalOptimizer):
# TODO: currently we don't support scan with 0 step. So
# TODO: currently we don't support scan with 0 step. So
# don't create one.
# don't create one.
if
aet
.
extract_constant
(
node_ins
[
0
])
==
0
:
if
aet
.
extract_constant
(
node_ins
[
0
])
==
0
:
return
return
False
# Do not call make_node for test_value
# Do not call make_node for test_value
new_op
=
Scan
(
new_op
=
Scan
(
...
@@ -1758,19 +1618,18 @@ class ScanSaveMem(GlobalOptimizer):
...
@@ -1758,19 +1618,18 @@ class ScanSaveMem(GlobalOptimizer):
]
]
if
any
(
old_scan_is_used
):
if
any
(
old_scan_is_used
):
return
False
return
False
remove
=
[
old
.
owner
for
(
old
,
new
)
in
old_new
]
replacements
=
dict
(
old_new
)
# remove = [old.owner for (old, new) in old_new]
# As Fred suggested assert that also the old node is not in
# As Fred suggested assert that also the old node is not in
# the Graph as that will make things suboptimal
# the Graph as that will make things suboptimal
remove
.
append
(
node
)
# remove.append(node)
fgraph
.
replace_all_validate_remove
(
replacements
[
"remove"
]
=
[
node
]
old_new
,
remove
,
reason
=
"scanOp_save_mem"
)
def
apply
(
self
,
fgraph
):
return
replacements
nodelist
=
[
x
for
x
in
fgraph
.
toposort
()
if
isinstance
(
x
.
op
,
Scan
)]
return
False
for
node
in
nodelist
:
self
.
process_node
(
fgraph
,
node
)
class
ScanMerge
(
GlobalOptimizer
):
class
ScanMerge
(
GlobalOptimizer
):
...
@@ -2271,27 +2130,16 @@ def scan_merge_inouts(fgraph, node):
...
@@ -2271,27 +2130,16 @@ def scan_merge_inouts(fgraph, node):
return
na
.
outer_outputs
return
na
.
outer_outputs
class
PushOutDot1
(
GlobalOptimizer
):
@local_optimizer
([
Scan
])
def
push_out_dot1_scan
(
fgraph
,
node
):
r"""
r"""
This is another optimization that attempts to detect certain patterns of
This is another optimization that attempts to detect certain patterns of
computation in a `Scan` `Op`'s inner function and move this computation to the
computation in a `Scan` `Op`'s inner function and move this computation to the
outer graph.
outer graph.
"""
"""
if
not
isinstance
(
node
.
op
,
Scan
):
return
False
def
__init__
(
self
):
super
()
.
__init__
()
def
add_requirements
(
self
,
fgraph
):
fgraph
.
attach_feature
(
ReplaceValidate
())
def
apply
(
self
,
fgraph
):
nodes
=
fgraph
.
toposort
()
scan_nodes
=
[
x
for
x
in
nodes
if
(
isinstance
(
x
.
op
,
Scan
))]
for
node
in
scan_nodes
:
self
.
apply_opt
(
fgraph
,
node
)
def
apply_opt
(
self
,
fgraph
,
node
):
# Replace pattern of the form
# Replace pattern of the form
# x[t] = x[t-1] + dot(seq[t], value)
# x[t] = x[t-1] + dot(seq[t], value)
# with Sequence.reshape((-1, seq.shape[2])) \dot Value
# with Sequence.reshape((-1, seq.shape[2])) \dot Value
...
@@ -2470,9 +2318,11 @@ class PushOutDot1(GlobalOptimizer):
...
@@ -2470,9 +2318,11 @@ class PushOutDot1(GlobalOptimizer):
old
=
fgraph
.
clients
[
node
.
outputs
[
pos
]][
0
][
0
]
.
outputs
[
0
]
old
=
fgraph
.
clients
[
node
.
outputs
[
pos
]][
0
][
0
]
.
outputs
[
0
]
old_new
.
append
((
old
,
new_out
))
old_new
.
append
((
old
,
new_out
))
old_new
+=
list
(
zip
(
node
.
outputs
[
pos
+
1
:],
new_outs
[
pos
:]))
old_new
+=
list
(
zip
(
node
.
outputs
[
pos
+
1
:],
new_outs
[
pos
:]))
fgraph
.
replace_all_validate_remove
(
replacements
=
dict
(
old_new
)
old_new
,
remove
=
[
node
],
reason
=
"scan_pushout_dot1"
replacements
[
"remove"
]
=
[
node
]
)
return
replacements
return
False
# I've added an equilibrium because later scan optimization in the sequence
# I've added an equilibrium because later scan optimization in the sequence
...
@@ -2490,7 +2340,13 @@ optdb.register("scan_eqopt1", scan_eqopt1, 0.05, "fast_run", "scan")
...
@@ -2490,7 +2340,13 @@ optdb.register("scan_eqopt1", scan_eqopt1, 0.05, "fast_run", "scan")
# but after stabilize at 1.5. Should we put it before stabilize?
# but after stabilize at 1.5. Should we put it before stabilize?
optdb
.
register
(
"scan_eqopt2"
,
scan_eqopt2
,
1.6
,
"fast_run"
,
"scan"
)
optdb
.
register
(
"scan_eqopt2"
,
scan_eqopt2
,
1.6
,
"fast_run"
,
"scan"
)
# ScanSaveMem should execute only once per node.
# ScanSaveMem should execute only once per node.
optdb
.
register
(
"scanOp_save_mem"
,
ScanSaveMem
(),
1.61
,
"fast_run"
,
"scan"
)
optdb
.
register
(
"scanOp_save_mem"
,
in2out
(
save_mem_new_scan
,
ignore_newtrees
=
True
),
1.61
,
"fast_run"
,
"scan"
,
)
optdb
.
register
(
optdb
.
register
(
"scanOp_make_inplace"
,
"scanOp_make_inplace"
,
ScanInplaceOptimizer
(
typeInfer
=
None
),
ScanInplaceOptimizer
(
typeInfer
=
None
),
...
@@ -2514,22 +2370,41 @@ scan_seqopt1.register(
...
@@ -2514,22 +2370,41 @@ scan_seqopt1.register(
scan_seqopt1
.
register
(
scan_seqopt1
.
register
(
"scanOp_pushout_nonseqs_ops"
,
PushOutNonSeqScan
(),
2
,
"fast_run"
,
"scan"
"scanOp_pushout_nonseqs_ops"
,
in2out
(
push_out_non_seq_scan
,
ignore_newtrees
=
True
),
2
,
"fast_run"
,
"scan"
,
)
)
scan_seqopt1
.
register
(
scan_seqopt1
.
register
(
"scanOp_pushout_seqs_ops"
,
PushOutSeqScan
(),
3
,
"fast_run"
,
"scan"
"scanOp_pushout_seqs_ops"
,
in2out
(
push_out_seq_scan
,
ignore_newtrees
=
True
),
3
,
"fast_run"
,
"scan"
,
)
)
scan_seqopt1
.
register
(
scan_seqopt1
.
register
(
"scan_pushout_dot1"
,
PushOutDot1
(),
4
,
"fast_run"
,
"more_mem"
,
"scan"
"scan_pushout_dot1"
,
in2out
(
push_out_dot1_scan
,
ignore_newtrees
=
True
),
4
,
"fast_run"
,
"more_mem"
,
"scan"
,
)
)
scan_seqopt1
.
register
(
scan_seqopt1
.
register
(
"scanOp_pushout_output"
,
PushOutScanOutput
(),
5
,
"fast_run"
,
"more_mem"
,
"scan"
"scanOp_pushout_output"
,
# TODO: Perhaps this should be an `EquilibriumOptimizer`?
in2out
(
push_out_add_scan
,
ignore_newtrees
=
False
),
5
,
"fast_run"
,
"more_mem"
,
"scan"
,
)
)
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论