Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
e85c7fd0
提交
e85c7fd0
authored
8月 08, 2021
作者:
Brandon T. Willard
提交者:
Brandon T. Willard
9月 15, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Replace Scan info dict with ScanInfo dataclass
上级
d83ff33c
隐藏空白字符变更
内嵌
并排
正在显示
7 个修改的文件
包含
360 行增加
和
362 行删除
+360
-362
basic.py
aesara/scan/basic.py
+25
-24
op.py
aesara/scan/op.py
+149
-162
opt.py
aesara/scan/opt.py
+78
-70
utils.py
aesara/scan/utils.py
+98
-101
requirements.txt
requirements.txt
+1
-0
setup.py
setup.py
+8
-1
test_utils.py
tests/scan/test_utils.py
+1
-4
没有找到文件。
aesara/scan/basic.py
浏览文件 @
e85c7fd0
...
...
@@ -13,7 +13,7 @@ from aesara.graph.fg import MissingInputError
from
aesara.graph.op
import
get_test_value
from
aesara.graph.utils
import
TestValueError
from
aesara.scan
import
utils
from
aesara.scan.op
import
Scan
from
aesara.scan.op
import
Scan
,
ScanInfo
from
aesara.scan.utils
import
safe_new
,
traverse
from
aesara.tensor.exceptions
import
NotScalarConstantError
from
aesara.tensor.math
import
minimum
...
...
@@ -1022,31 +1022,32 @@ def scan(
# Step 7. Create the Scan Op
##
tap_array
=
mit_sot_tap_array
+
[[
-
1
]
for
x
in
range
(
n_sit_sot
)]
tap_array
=
tuple
(
tuple
(
v
)
for
v
in
mit_sot_tap_array
)
+
tuple
(
(
-
1
,)
for
x
in
range
(
n_sit_sot
)
)
if
allow_gc
is
None
:
allow_gc
=
config
.
scan__allow_gc
info
=
OrderedDict
()
info
[
"tap_array"
]
=
tap_array
info
[
"n_seqs"
]
=
n_seqs
info
[
"n_mit_mot"
]
=
n_mit_mot
info
[
"n_mit_mot_outs"
]
=
n_mit_mot_outs
info
[
"mit_mot_out_slices"
]
=
mit_mot_out_slices
info
[
"n_mit_sot"
]
=
n_mit_sot
info
[
"n_sit_sot"
]
=
n_sit_sot
info
[
"n_shared_outs"
]
=
n_shared_outs
info
[
"n_nit_sot"
]
=
n_nit_sot
info
[
"truncate_gradient"
]
=
truncate_gradient
info
[
"name"
]
=
name
info
[
"mode"
]
=
mode
info
[
"destroy_map"
]
=
OrderedDict
()
info
[
"gpua"
]
=
False
info
[
"as_while"
]
=
as_while
info
[
"profile"
]
=
profile
info
[
"allow_gc"
]
=
allow_gc
info
[
"strict"
]
=
strict
local_op
=
Scan
(
inner_inputs
,
new_outs
,
info
)
info
=
ScanInfo
(
tap_array
=
tap_array
,
n_seqs
=
n_seqs
,
n_mit_mot
=
n_mit_mot
,
n_mit_mot_outs
=
n_mit_mot_outs
,
mit_mot_out_slices
=
tuple
(
tuple
(
v
)
for
v
in
mit_mot_out_slices
),
n_mit_sot
=
n_mit_sot
,
n_sit_sot
=
n_sit_sot
,
n_shared_outs
=
n_shared_outs
,
n_nit_sot
=
n_nit_sot
,
truncate_gradient
=
truncate_gradient
,
name
=
name
,
gpua
=
False
,
as_while
=
as_while
,
profile
=
profile
,
allow_gc
=
allow_gc
,
strict
=
strict
,
)
local_op
=
Scan
(
inner_inputs
,
new_outs
,
info
,
mode
)
##
# Step 8. Compute the outputs using the scan op
...
...
aesara/scan/op.py
浏览文件 @
e85c7fd0
...
...
@@ -44,11 +44,12 @@ relies on the following elements to work properly :
"""
import
copy
import
dataclasses
import
itertools
import
logging
import
time
from
collections
import
OrderedDict
from
typing
import
Callable
,
List
,
Optional
,
Union
import
numpy
as
np
...
...
@@ -57,7 +58,7 @@ from aesara import tensor as aet
from
aesara.compile.builders
import
infer_shape
from
aesara.compile.function
import
function
from
aesara.compile.io
import
In
,
Out
from
aesara.compile.mode
import
AddFeatureOptimizer
,
get_mode
from
aesara.compile.mode
import
AddFeatureOptimizer
,
Mode
,
get_default_mode
,
get_mode
from
aesara.compile.profiling
import
ScanProfileStats
,
register_profiler_printer
from
aesara.configdefaults
import
config
from
aesara.gradient
import
DisconnectedType
,
NullType
,
Rop
,
grad
,
grad_undefined
...
...
@@ -76,7 +77,7 @@ from aesara.graph.op import Op, ops_with_inner_function
from
aesara.link.c.basic
import
CLinker
from
aesara.link.c.exceptions
import
MissingGXX
from
aesara.link.utils
import
raise_with_op
from
aesara.scan.utils
import
Validator
,
forced_replace
,
hash_listsDictsTuples
,
safe_new
from
aesara.scan.utils
import
Validator
,
forced_replace
,
safe_new
from
aesara.tensor.basic
import
as_tensor_variable
from
aesara.tensor.math
import
minimum
from
aesara.tensor.shape
import
Shape_i
...
...
@@ -88,57 +89,90 @@ from aesara.tensor.var import TensorVariable
_logger
=
logging
.
getLogger
(
"aesara.scan.op"
)
class
Scan
(
Op
):
"""
Parameters
----------
inputs
Inputs of the inner function of scan.
outputs
Outputs of the inner function of scan.
info
Dictionary containing different properties of the scan op (like number
of different types of arguments, name, mode, if it should run on GPU or
not, etc.).
typeConstructor
Function that constructs an equivalent to Aesara TensorType.
Notes
-----
``typeConstructor`` had been added to refactor how
Aesara deals with the GPU. If it runs on the GPU, scan needs
to construct certain outputs (those who reside in the GPU
memory) as the GPU-specific type. However we can not import
gpu code in this file (as it is in sandbox, and not available
on each machine) so the workaround is that the GPU
optimization passes to the constructor of this class a
function that is able to construct a GPU type. This way the
class Scan does not need to be aware of the details for the
GPU, it just constructs any tensor using this function (which
by default constructs normal tensors).
"""
@dataclasses.dataclass
(
frozen
=
True
)
class
ScanInfo
:
tap_array
:
tuple
n_seqs
:
int
n_mit_mot
:
int
n_mit_mot_outs
:
int
mit_mot_out_slices
:
tuple
n_mit_sot
:
int
n_sit_sot
:
int
n_shared_outs
:
int
n_nit_sot
:
int
truncate_gradient
:
bool
=
False
name
:
Optional
[
str
]
=
None
gpua
:
bool
=
False
as_while
:
bool
=
False
profile
:
Optional
[
Union
[
str
,
bool
]]
=
None
allow_gc
:
bool
=
True
strict
:
bool
=
True
TensorConstructorType
=
Callable
[[
List
[
bool
],
Union
[
str
,
np
.
generic
]],
TensorType
]
class
Scan
(
Op
):
def
__init__
(
self
,
inputs
,
outputs
,
info
,
typeConstructor
=
None
,
inputs
:
List
[
Variable
],
outputs
:
List
[
Variable
],
info
:
ScanInfo
,
mode
:
Optional
[
Mode
]
=
None
,
typeConstructor
:
Optional
[
TensorConstructorType
]
=
None
,
):
r"""
Parameters
----------
inputs
Inputs of the inner function of `Scan`.
outputs
Outputs of the inner function of `Scan`.
info
Dictionary containing different properties of the `Scan` `Op` (like
number of different types of arguments, name, mode, if it should run on
GPU or not, etc.).
mode
The compilation mode for the inner graph.
typeConstructor
Function that constructs an equivalent to Aesara `TensorType`.
Notes
-----
`typeConstructor` had been added to refactor how
Aesara deals with the GPU. If it runs on the GPU, scan needs
to construct certain outputs (those who reside in the GPU
memory) as the GPU-specific type. However we can not import
gpu code in this file (as it is in sandbox, and not available
on each machine) so the workaround is that the GPU
optimization passes to the constructor of this class a
function that is able to construct a GPU type. This way the
class `Scan` does not need to be aware of the details for the
GPU, it just constructs any tensor using this function (which
by default constructs normal tensors).
"""
# adding properties into self
self
.
inputs
=
inputs
self
.
outputs
=
outputs
self
.
__dict__
.
update
(
info
)
# I keep a version of info in self, to use in __eq__ and __hash__,
# since info contains all tunable parameters of the op, so for two
# scan to be equal this tunable parameters should be the same
self
.
info
=
info
self
.
__dict__
.
update
(
dataclasses
.
asdict
(
info
))
# Clone mode_instance, altering "allow_gc" for the linker,
# and adding a message if we profile
if
self
.
name
:
message
=
self
.
name
+
" sub profile"
else
:
message
=
"Scan sub profile"
self
.
mode
=
get_default_mode
()
if
mode
is
None
else
mode
self
.
mode_instance
=
get_mode
(
self
.
mode
)
.
clone
(
link_kwargs
=
dict
(
allow_gc
=
self
.
allow_gc
),
message
=
message
)
# build a list of output types for any Apply node using this op.
self
.
output_types
=
[]
idx
=
0
jdx
=
0
def
tensorConstructor
(
broadcastable
,
dtype
):
return
TensorType
(
broadcastable
=
broadcastable
,
dtype
=
dtype
)
...
...
@@ -146,6 +180,8 @@ class Scan(Op):
if
typeConstructor
is
None
:
typeConstructor
=
tensorConstructor
idx
=
0
jdx
=
0
while
idx
<
self
.
n_mit_mot_outs
:
# Not that for mit_mot there are several output slices per
# output sequence
...
...
@@ -176,24 +212,13 @@ class Scan(Op):
if
self
.
as_while
:
self
.
output_types
=
self
.
output_types
[:
-
1
]
mode_instance
=
get_mode
(
self
.
mode
)
# Clone mode_instance, altering "allow_gc" for the linker,
# and adding a message if we profile
if
self
.
name
:
message
=
self
.
name
+
" sub profile"
else
:
message
=
"Scan sub profile"
self
.
mode_instance
=
mode_instance
.
clone
(
link_kwargs
=
dict
(
allow_gc
=
self
.
allow_gc
),
message
=
message
)
if
not
hasattr
(
self
,
"name"
)
or
self
.
name
is
None
:
self
.
name
=
"scan_fn"
# to have a fair __eq__ comparison later on, we update the info with
# the actual mode used to compile the function and the name of the
# function that we set in case none was given
self
.
info
[
"name"
]
=
self
.
name
self
.
info
=
dataclasses
.
replace
(
self
.
info
,
name
=
self
.
name
)
# Pre-computing some values to speed up perform
self
.
mintaps
=
[
np
.
min
(
x
)
for
x
in
self
.
tap_array
]
...
...
@@ -205,8 +230,8 @@ class Scan(Op):
self
.
nit_sot_arg_offset
=
self
.
shared_arg_offset
+
self
.
n_shared_outs
self
.
n_outs
=
self
.
n_mit_mot
+
self
.
n_mit_sot
+
self
.
n_sit_sot
self
.
n_tap_outs
=
self
.
n_mit_mot
+
self
.
n_mit_sot
if
self
.
info
[
"gpua"
]
:
self
.
_hash_inner_graph
=
self
.
info
[
"gpu_hash"
]
if
self
.
info
.
gpua
:
self
.
_hash_inner_graph
=
self
.
info
.
gpu_hash
else
:
# Do the missing inputs check here to have the error early.
for
var
in
graph_inputs
(
self
.
outputs
,
self
.
inputs
):
...
...
@@ -256,7 +281,7 @@ class Scan(Op):
# output with type GpuArrayType
from
aesara.gpuarray
import
GpuArrayType
if
not
self
.
info
.
g
et
(
"gpua"
,
False
)
:
if
not
self
.
info
.
g
pua
:
for
inp
in
self
.
inputs
:
if
isinstance
(
inp
.
type
,
GpuArrayType
):
raise
TypeError
(
...
...
@@ -279,9 +304,7 @@ class Scan(Op):
def
__setstate__
(
self
,
d
):
self
.
__dict__
.
update
(
d
)
if
"allow_gc"
not
in
self
.
__dict__
:
self
.
allow_gc
=
True
self
.
info
[
"allow_gc"
]
=
True
if
not
hasattr
(
self
,
"var_mappings"
):
# Generate the mappings between inner and outer inputs and outputs
# if they haven't already been generated.
...
...
@@ -731,41 +754,20 @@ class Scan(Op):
return
apply_node
def
__eq__
(
self
,
other
):
# Check if we are dealing with same type of objects
if
not
type
(
self
)
==
type
(
other
):
if
type
(
self
)
!=
type
(
other
):
return
False
if
"destroy_map"
not
in
self
.
info
:
self
.
info
[
"destroy_map"
]
=
OrderedDict
()
if
"destroy_map"
not
in
other
.
info
:
other
.
info
[
"destroy_map"
]
=
OrderedDict
()
keys_to_check
=
[
"truncate_gradient"
,
"profile"
,
"n_seqs"
,
"tap_array"
,
"as_while"
,
"n_mit_sot"
,
"destroy_map"
,
"n_nit_sot"
,
"n_shared_outs"
,
"n_sit_sot"
,
"gpua"
,
"n_mit_mot_outs"
,
"n_mit_mot"
,
"mit_mot_out_slices"
,
]
# This are some safety checks ( namely that the inner graph has the
# same number of inputs and same number of outputs )
if
not
len
(
self
.
inputs
)
==
len
(
other
.
inputs
):
if
self
.
info
!=
other
.
info
:
return
False
elif
not
len
(
self
.
outputs
)
==
len
(
other
.
outputs
):
# Compare inner graphs
# TODO: Use `self.inner_fgraph == other.inner_fgraph`
if
len
(
self
.
inputs
)
!=
len
(
other
.
inputs
):
return
False
for
key
in
keys_to_check
:
if
self
.
info
[
key
]
!=
other
.
info
[
key
]:
return
False
# If everything went OK up to here, there is still one thing to
# check. Namely, do the internal graph represent same
# computations
if
len
(
self
.
outputs
)
!=
len
(
other
.
outputs
):
return
False
for
self_in
,
other_in
in
zip
(
self
.
inputs
,
other
.
inputs
):
if
self_in
.
type
!=
other_in
.
type
:
return
False
...
...
@@ -801,15 +803,7 @@ class Scan(Op):
return
aux_txt
def
__hash__
(
self
):
return
hash
(
(
type
(
self
),
# and a hash representing the inner graph using the
# CLinker.cmodule_key_
self
.
_hash_inner_graph
,
hash_listsDictsTuples
(
self
.
info
),
)
)
return
hash
((
type
(
self
),
self
.
_hash_inner_graph
,
self
.
info
))
def
make_thunk
(
self
,
node
,
storage_map
,
compute_map
,
no_recycling
,
impl
=
None
):
"""
...
...
@@ -864,7 +858,6 @@ class Scan(Op):
wrapped_inputs
=
[
In
(
x
,
borrow
=
False
)
for
x
in
self
.
inputs
[:
self
.
n_seqs
]]
new_outputs
=
[
x
for
x
in
self
.
outputs
]
preallocated_mitmot_outs
=
[]
new_mit_mot_out_slices
=
copy
.
deepcopy
(
self
.
mit_mot_out_slices
)
input_idx
=
self
.
n_seqs
for
mitmot_idx
in
range
(
self
.
n_mit_mot
):
...
...
@@ -894,7 +887,6 @@ class Scan(Op):
)
wrapped_inputs
.
append
(
wrapped_inp
)
preallocated_mitmot_outs
.
append
(
output_idx
)
new_mit_mot_out_slices
[
mitmot_idx
]
.
remove
(
inp_tap
)
else
:
# Wrap the corresponding input as usual. Leave the
# output as-is.
...
...
@@ -1963,7 +1955,7 @@ class Scan(Op):
outer_oidx
=
0
# Handle sequences inputs
for
i
in
range
(
self
.
info
[
"n_seqs"
]
):
for
i
in
range
(
self
.
info
.
n_seqs
):
outer_input_indices
.
append
(
outer_iidx
)
inner_input_indices
.
append
([
inner_iidx
])
inner_output_indices
.
append
([])
...
...
@@ -1975,8 +1967,8 @@ class Scan(Op):
outer_oidx
+=
0
# Handle mitmots, mitsots and sitsots variables
for
i
in
range
(
len
(
self
.
info
[
"tap_array"
]
)):
nb_input_taps
=
len
(
self
.
info
[
"tap_array"
]
[
i
])
for
i
in
range
(
len
(
self
.
info
.
tap_array
)):
nb_input_taps
=
len
(
self
.
info
.
tap_array
[
i
])
if
i
<
self
.
n_mit_mot
:
nb_output_taps
=
len
(
self
.
mit_mot_out_slices
[
i
])
...
...
@@ -1999,7 +1991,7 @@ class Scan(Op):
# This is needed because, for outer inputs (and for outer inputs only)
# nitsots come *after* shared variables.
outer_iidx
+=
self
.
info
[
"n_shared_outs"
]
outer_iidx
+=
self
.
info
.
n_shared_outs
# Handle nitsots variables
for
i
in
range
(
self
.
n_nit_sot
):
...
...
@@ -2015,10 +2007,10 @@ class Scan(Op):
# This is needed because, for outer inputs (and for outer inputs only)
# nitsots come *after* shared variables.
outer_iidx
-=
self
.
info
[
"n_shared_outs"
]
+
self
.
n_nit_sot
outer_iidx
-=
self
.
info
.
n_shared_outs
+
self
.
n_nit_sot
# Handle shared states
for
i
in
range
(
self
.
info
[
"n_shared_outs"
]
):
for
i
in
range
(
self
.
info
.
n_shared_outs
):
outer_input_indices
.
append
(
outer_iidx
)
inner_input_indices
.
append
([
inner_iidx
])
inner_output_indices
.
append
([
inner_oidx
])
...
...
@@ -2158,10 +2150,10 @@ class Scan(Op):
oidx
+=
1
iidx
-=
len
(
taps
)
if
iidx
<
self
.
info
[
"n_sit_sot"
]
:
if
iidx
<
self
.
info
.
n_sit_sot
:
return
oidx
+
iidx
else
:
return
oidx
+
iidx
+
self
.
info
[
"n_nit_sot"
]
return
oidx
+
iidx
+
self
.
info
.
n_nit_sot
def
get_out_idx
(
iidx
):
oidx
=
0
...
...
@@ -2241,9 +2233,9 @@ class Scan(Op):
# "if Xt not in self.inner_nitsot_outs(self_outputs)" because
# the exact same variable can be used as multiple outputs.
idx_nitsot_start
=
(
self
.
info
[
"n_mit_mot"
]
+
self
.
info
[
"n_mit_sot"
]
+
self
.
info
[
"n_sit_sot"
]
self
.
info
.
n_mit_mot
+
self
.
info
.
n_mit_sot
+
self
.
info
.
n_sit_sot
)
idx_nitsot_end
=
idx_nitsot_start
+
self
.
info
[
"n_nit_sot"
]
idx_nitsot_end
=
idx_nitsot_start
+
self
.
info
.
n_nit_sot
if
idx
<
idx_nitsot_start
or
idx
>=
idx_nitsot_end
:
# What we do here is loop through dC_douts and collect all
# those that are connected to the specific one and do an
...
...
@@ -2668,27 +2660,23 @@ class Scan(Op):
n_sitsot_outs
=
len
(
outer_inp_sitsot
)
new_tap_array
=
mitmot_inp_taps
+
[[
-
1
]
for
k
in
range
(
n_sitsot_outs
)]
info
=
OrderedDict
()
info
[
"n_seqs"
]
=
len
(
outer_inp_seqs
)
info
[
"n_mit_sot"
]
=
0
info
[
"tap_array"
]
=
new_tap_array
info
[
"gpua"
]
=
False
info
[
"n_mit_mot"
]
=
len
(
outer_inp_mitmot
)
info
[
"n_mit_mot_outs"
]
=
n_mitmot_outs
info
[
"mit_mot_out_slices"
]
=
mitmot_out_taps
info
[
"truncate_gradient"
]
=
self
.
truncate_gradient
info
[
"n_sit_sot"
]
=
n_sitsot_outs
info
[
"n_shared_outs"
]
=
0
info
[
"n_nit_sot"
]
=
n_nit_sot
info
[
"as_while"
]
=
False
info
[
"profile"
]
=
self
.
profile
info
[
"destroy_map"
]
=
OrderedDict
()
if
self
.
name
:
info
[
"name"
]
=
"grad_of_"
+
self
.
name
else
:
info
[
"name"
]
=
None
info
[
"mode"
]
=
self
.
mode
info
[
"allow_gc"
]
=
self
.
allow_gc
info
=
ScanInfo
(
n_seqs
=
len
(
outer_inp_seqs
),
n_mit_sot
=
0
,
tap_array
=
tuple
(
tuple
(
v
)
for
v
in
new_tap_array
),
gpua
=
False
,
n_mit_mot
=
len
(
outer_inp_mitmot
),
n_mit_mot_outs
=
n_mitmot_outs
,
mit_mot_out_slices
=
tuple
(
tuple
(
v
)
for
v
in
mitmot_out_taps
),
truncate_gradient
=
self
.
truncate_gradient
,
n_sit_sot
=
n_sitsot_outs
,
n_shared_outs
=
0
,
n_nit_sot
=
n_nit_sot
,
as_while
=
False
,
profile
=
self
.
profile
,
name
=
f
"grad_of_{self.name}"
if
self
.
name
else
None
,
allow_gc
=
self
.
allow_gc
,
)
outer_inputs
=
(
[
grad_steps
]
...
...
@@ -2709,7 +2697,7 @@ class Scan(Op):
)
inner_gfn_outs
=
inner_out_mitmot
+
inner_out_sitsot
+
inner_out_nitsot
local_op
=
Scan
(
inner_gfn_ins
,
inner_gfn_outs
,
info
)
local_op
=
Scan
(
inner_gfn_ins
,
inner_gfn_outs
,
info
,
self
.
mode
)
outputs
=
local_op
(
*
outer_inputs
)
if
type
(
outputs
)
not
in
(
list
,
tuple
):
outputs
=
[
outputs
]
...
...
@@ -2852,8 +2840,8 @@ class Scan(Op):
rop_self_outputs
=
self_outputs
[:
-
1
]
else
:
rop_self_outputs
=
self_outputs
if
self
.
info
[
"n_shared_outs"
]
>
0
:
rop_self_outputs
=
rop_self_outputs
[:
-
self
.
info
[
"n_shared_outs"
]
]
if
self
.
info
.
n_shared_outs
>
0
:
rop_self_outputs
=
rop_self_outputs
[:
-
self
.
info
.
n_shared_outs
]
rop_outs
=
Rop
(
rop_self_outputs
,
rop_of_inputs
,
inner_eval_points
)
if
type
(
rop_outs
)
not
in
(
list
,
tuple
):
rop_outs
=
[
rop_outs
]
...
...
@@ -2867,25 +2855,7 @@ class Scan(Op):
# The only exception is the eval point for the number of sequences, and
# evan point for the number of nit_sot which I think should just be
# ignored (?)
info
=
OrderedDict
()
info
[
"n_seqs"
]
=
self
.
n_seqs
*
2
info
[
"n_mit_sot"
]
=
self
.
n_mit_sot
*
2
info
[
"n_sit_sot"
]
=
self
.
n_sit_sot
*
2
info
[
"n_mit_mot"
]
=
self
.
n_mit_mot
*
2
info
[
"n_nit_sot"
]
=
self
.
n_nit_sot
*
2
info
[
"n_shared_outs"
]
=
self
.
n_shared_outs
info
[
"gpua"
]
=
False
info
[
"as_while"
]
=
self
.
as_while
info
[
"profile"
]
=
self
.
profile
info
[
"truncate_gradient"
]
=
self
.
truncate_gradient
if
self
.
name
:
info
[
"name"
]
=
"rop_of_"
+
self
.
name
else
:
info
[
"name"
]
=
None
info
[
"mode"
]
=
self
.
mode
info
[
"allow_gc"
]
=
self
.
allow_gc
info
[
"mit_mot_out_slices"
]
=
self
.
mit_mot_out_slices
*
2
info
[
"destroy_map"
]
=
OrderedDict
()
new_tap_array
=
[]
b
=
0
e
=
self
.
n_mit_mot
...
...
@@ -2896,7 +2866,6 @@ class Scan(Op):
b
=
e
e
+=
self
.
n_sit_sot
new_tap_array
+=
self
.
tap_array
[
b
:
e
]
*
2
info
[
"tap_array"
]
=
new_tap_array
# Sequences ...
b
=
1
...
...
@@ -2993,7 +2962,7 @@ class Scan(Op):
# Outputs
n_mit_mot_outs
=
int
(
np
.
sum
([
len
(
x
)
for
x
in
self
.
mit_mot_out_slices
]))
info
[
"n_mit_mot_outs"
]
=
n_mit_mot_outs
*
2
b
=
0
e
=
n_mit_mot_outs
inner_out_mit_mot
=
self_outputs
[
b
:
e
]
+
rop_outs
[
b
:
e
]
...
...
@@ -3039,7 +3008,25 @@ class Scan(Op):
+
scan_other
)
local_op
=
Scan
(
inner_ins
,
inner_outs
,
info
)
info
=
ScanInfo
(
n_seqs
=
self
.
n_seqs
*
2
,
n_mit_sot
=
self
.
n_mit_sot
*
2
,
n_sit_sot
=
self
.
n_sit_sot
*
2
,
n_mit_mot
=
self
.
n_mit_mot
*
2
,
n_nit_sot
=
self
.
n_nit_sot
*
2
,
n_shared_outs
=
self
.
n_shared_outs
,
n_mit_mot_outs
=
n_mit_mot_outs
*
2
,
gpua
=
False
,
as_while
=
self
.
as_while
,
profile
=
self
.
profile
,
truncate_gradient
=
self
.
truncate_gradient
,
name
=
f
"rop_of_{self.name}"
if
self
.
name
else
None
,
allow_gc
=
self
.
allow_gc
,
tap_array
=
tuple
(
tuple
(
v
)
for
v
in
new_tap_array
),
mit_mot_out_slices
=
tuple
(
tuple
(
v
)
for
v
in
self
.
mit_mot_out_slices
)
*
2
,
)
local_op
=
Scan
(
inner_ins
,
inner_outs
,
info
,
self
.
mode
)
outputs
=
local_op
(
*
scan_inputs
)
if
type
(
outputs
)
not
in
(
list
,
tuple
):
outputs
=
[
outputs
]
...
...
aesara/scan/opt.py
浏览文件 @
e85c7fd0
...
...
@@ -51,8 +51,8 @@ scan_eqopt2 -> They are all global optimizer. (in2out convert local to global).
"""
import
copy
import
dataclasses
import
logging
from
collections
import
OrderedDict
from
sys
import
maxsize
import
numpy
as
np
...
...
@@ -78,7 +78,7 @@ from aesara.graph.fg import InconsistencyError
from
aesara.graph.op
import
compute_test_value
from
aesara.graph.opt
import
GlobalOptimizer
,
in2out
,
local_optimizer
from
aesara.graph.optdb
import
EquilibriumDB
,
SequenceDB
from
aesara.scan.op
import
Scan
from
aesara.scan.op
import
Scan
,
ScanInfo
from
aesara.scan.utils
import
(
ScanArgs
,
compress_outs
,
...
...
@@ -156,7 +156,7 @@ def remove_constants_and_unused_inputs_scan(fgraph, node):
out_stuff_outer
=
node
.
inputs
[
1
+
op
.
n_seqs
:
st
]
# To replace constants in the outer graph by clones in the inner graph
givens
=
OrderedDict
()
givens
=
{}
# All the inputs of the inner graph of the new scan
nw_inner
=
[]
# Same for the outer graph, initialized w/ number of steps
...
...
@@ -217,12 +217,10 @@ def remove_constants_and_unused_inputs_scan(fgraph, node):
if
len
(
nw_inner
)
!=
len
(
op_ins
):
op_outs
=
clone_replace
(
op_outs
,
replace
=
givens
)
nw_info
=
copy
.
deepcopy
(
op
.
info
)
nw_info
[
"n_seqs"
]
=
nw_n_seqs
# DEBUG CHECK
nwScan
=
Scan
(
nw_inner
,
op_outs
,
nw_info
)
nw_info
=
dataclasses
.
replace
(
op
.
info
,
n_seqs
=
nw_n_seqs
)
nwScan
=
Scan
(
nw_inner
,
op_outs
,
nw_info
,
op
.
mode
)
nw_outs
=
nwScan
(
*
nw_outer
,
**
dict
(
return_list
=
True
))
return
OrderedD
ict
([(
"remove"
,
[
node
])]
+
list
(
zip
(
node
.
outputs
,
nw_outs
)))
return
d
ict
([(
"remove"
,
[
node
])]
+
list
(
zip
(
node
.
outputs
,
nw_outs
)))
else
:
return
False
...
...
@@ -263,7 +261,7 @@ class PushOutNonSeqScan(GlobalOptimizer):
to_remove_set
=
set
()
to_replace_set
=
set
()
to_replace_map
=
OrderedDict
()
to_replace_map
=
{}
def
add_to_replace
(
y
):
to_replace_set
.
add
(
y
)
...
...
@@ -377,7 +375,7 @@ class PushOutNonSeqScan(GlobalOptimizer):
if
len
(
clean_to_replace
)
>
0
:
# We can finally put an end to all this madness
givens
=
OrderedDict
()
givens
=
{}
nw_outer
=
[]
nw_inner
=
[]
for
to_repl
,
repl_in
,
repl_out
in
zip
(
...
...
@@ -394,7 +392,7 @@ class PushOutNonSeqScan(GlobalOptimizer):
op_ins
=
clean_inputs
+
nw_inner
# Reconstruct node
nwScan
=
Scan
(
op_ins
,
op_outs
,
op
.
info
)
nwScan
=
Scan
(
op_ins
,
op_outs
,
op
.
info
,
op
.
mode
)
# Do not call make_node for test_value
nw_node
=
nwScan
(
*
(
node
.
inputs
+
nw_outer
),
**
dict
(
return_list
=
True
))[
...
...
@@ -409,7 +407,7 @@ class PushOutNonSeqScan(GlobalOptimizer):
return
True
elif
not
to_keep_set
:
# Nothing in the inner graph should be kept
replace_with
=
OrderedDict
()
replace_with
=
{}
for
out
,
idx
in
to_replace_map
.
items
():
if
out
in
local_fgraph_outs_set
:
x
=
node
.
outputs
[
local_fgraph_outs_map
[
out
]]
...
...
@@ -481,7 +479,7 @@ class PushOutSeqScan(GlobalOptimizer):
to_remove_set
=
set
()
to_replace_set
=
set
()
to_replace_map
=
OrderedDict
()
to_replace_map
=
{}
def
add_to_replace
(
y
):
to_replace_set
.
add
(
y
)
...
...
@@ -638,7 +636,7 @@ class PushOutSeqScan(GlobalOptimizer):
if
len
(
clean_to_replace
)
>
0
:
# We can finally put an end to all this madness
givens
=
OrderedDict
()
givens
=
{}
nw_outer
=
[]
nw_inner
=
[]
for
to_repl
,
repl_in
,
repl_out
in
zip
(
...
...
@@ -656,9 +654,10 @@ class PushOutSeqScan(GlobalOptimizer):
op_ins
=
nw_inner
+
clean_inputs
# Reconstruct node
nw_info
=
op
.
info
.
copy
()
nw_info
[
"n_seqs"
]
+=
len
(
nw_inner
)
nwScan
=
Scan
(
op_ins
,
op_outs
,
nw_info
)
nw_info
=
dataclasses
.
replace
(
op
.
info
,
n_seqs
=
op
.
info
.
n_seqs
+
len
(
nw_inner
)
)
nwScan
=
Scan
(
op_ins
,
op_outs
,
nw_info
,
op
.
mode
)
# Do not call make_node for test_value
nw_node
=
nwScan
(
*
(
node
.
inputs
[:
1
]
+
nw_outer
+
node
.
inputs
[
1
:]),
...
...
@@ -673,7 +672,7 @@ class PushOutSeqScan(GlobalOptimizer):
return
True
elif
not
to_keep_set
and
not
op
.
as_while
and
not
op
.
outer_mitmot
(
node
):
# Nothing in the inner graph should be kept
replace_with
=
OrderedDict
()
replace_with
=
{}
for
out
,
idx
in
to_replace_map
.
items
():
if
out
in
local_fgraph_outs_set
:
x
=
node
.
outputs
[
local_fgraph_outs_map
[
out
]]
...
...
@@ -937,7 +936,10 @@ class PushOutScanOutput(GlobalOptimizer):
# Create the `Scan` `Op` from the `ScanArgs`
new_scan_op
=
Scan
(
new_scan_args
.
inner_inputs
,
new_scan_args
.
inner_outputs
,
new_scan_args
.
info
new_scan_args
.
inner_inputs
,
new_scan_args
.
inner_outputs
,
new_scan_args
.
info
,
old_scan_node
.
op
.
mode
,
)
# Create the Apply node for the scan op
...
...
@@ -1000,13 +1002,6 @@ class ScanInplaceOptimizer(GlobalOptimizer):
op
=
node
.
op
info
=
copy
.
deepcopy
(
op
.
info
)
if
"destroy_map"
not
in
info
:
info
[
"destroy_map"
]
=
OrderedDict
()
for
out_idx
in
output_indices
:
info
[
"destroy_map"
][
out_idx
]
=
[
out_idx
+
1
+
op
.
info
[
"n_seqs"
]]
# inputs corresponding to sequences and n_steps
ls_begin
=
node
.
inputs
[:
1
+
op
.
n_seqs
]
ls
=
op
.
outer_mitmot
(
node
.
inputs
)
...
...
@@ -1048,7 +1043,15 @@ class ScanInplaceOptimizer(GlobalOptimizer):
else
:
typeConstructor
=
self
.
typeInfer
(
node
)
new_op
=
Scan
(
op
.
inputs
,
op
.
outputs
,
info
,
typeConstructor
=
typeConstructor
)
new_op
=
Scan
(
op
.
inputs
,
op
.
outputs
,
op
.
info
,
op
.
mode
,
typeConstructor
=
typeConstructor
)
destroy_map
=
op
.
destroy_map
.
copy
()
for
out_idx
in
output_indices
:
destroy_map
[
out_idx
]
=
[
out_idx
+
1
+
op
.
info
.
n_seqs
]
new_op
.
destroy_map
=
destroy_map
# Do not call make_node for test_value
new_outs
=
new_op
(
*
inputs
,
**
dict
(
return_list
=
True
))
...
...
@@ -1070,7 +1073,7 @@ class ScanInplaceOptimizer(GlobalOptimizer):
scan_nodes
=
[
x
for
x
in
nodes
if
(
isinstance
(
x
.
op
,
Scan
)
and
x
.
op
.
info
[
"gpua"
]
==
self
.
gpua_flag
)
if
(
isinstance
(
x
.
op
,
Scan
)
and
x
.
op
.
info
.
gpua
==
self
.
gpua_flag
)
]
for
scan_idx
in
range
(
len
(
scan_nodes
)):
...
...
@@ -1080,7 +1083,7 @@ class ScanInplaceOptimizer(GlobalOptimizer):
# them.
original_node
=
scan_nodes
[
scan_idx
]
op
=
original_node
.
op
n_outs
=
op
.
info
[
"n_mit_mot"
]
+
op
.
info
[
"n_mit_sot"
]
+
op
.
info
[
"n_sit_sot"
]
n_outs
=
op
.
info
.
n_mit_mot
+
op
.
info
.
n_mit_sot
+
op
.
info
.
n_sit_sot
# Generate a list of outputs on which the node could potentially
# operate inplace.
...
...
@@ -1171,7 +1174,7 @@ class ScanSaveMem(GlobalOptimizer):
# Each access to shape_of is in a try..except block in order to
# use a default version when the variable is not in the shape_of
# dictionary.
shape_of
=
OrderedDict
()
shape_of
=
{}
# 1. Initialization of variables
# Note 1) We do not actually care about outputs representing shared
# variables (those have no intermediate values) so it is safer to
...
...
@@ -1548,7 +1551,7 @@ class ScanSaveMem(GlobalOptimizer):
(
inps
,
outs
,
info
,
node_ins
,
compress_map
)
=
compress_outs
(
op
,
not_required
,
nw_inputs
)
inv_compress_map
=
OrderedDict
()
inv_compress_map
=
{}
for
k
,
v
in
compress_map
.
items
():
inv_compress_map
[
v
]
=
k
...
...
@@ -1559,7 +1562,9 @@ class ScanSaveMem(GlobalOptimizer):
return
# Do not call make_node for test_value
new_outs
=
Scan
(
inps
,
outs
,
info
)(
*
node_ins
,
**
dict
(
return_list
=
True
))
new_outs
=
Scan
(
inps
,
outs
,
info
,
op
.
mode
)(
*
node_ins
,
**
dict
(
return_list
=
True
)
)
old_new
=
[]
# 3.7 Get replace pairs for those outputs that do not change
...
...
@@ -1688,24 +1693,6 @@ class ScanMerge(GlobalOptimizer):
else
:
as_while
=
False
info
=
OrderedDict
()
info
[
"tap_array"
]
=
[]
info
[
"n_seqs"
]
=
sum
([
nd
.
op
.
n_seqs
for
nd
in
nodes
])
info
[
"n_mit_mot"
]
=
sum
([
nd
.
op
.
n_mit_mot
for
nd
in
nodes
])
info
[
"n_mit_mot_outs"
]
=
sum
([
nd
.
op
.
n_mit_mot_outs
for
nd
in
nodes
])
info
[
"mit_mot_out_slices"
]
=
[]
info
[
"n_mit_sot"
]
=
sum
([
nd
.
op
.
n_mit_sot
for
nd
in
nodes
])
info
[
"n_sit_sot"
]
=
sum
([
nd
.
op
.
n_sit_sot
for
nd
in
nodes
])
info
[
"n_shared_outs"
]
=
sum
([
nd
.
op
.
n_shared_outs
for
nd
in
nodes
])
info
[
"n_nit_sot"
]
=
sum
([
nd
.
op
.
n_nit_sot
for
nd
in
nodes
])
info
[
"truncate_gradient"
]
=
nodes
[
0
]
.
op
.
truncate_gradient
info
[
"name"
]
=
"&"
.
join
([
nd
.
op
.
name
for
nd
in
nodes
])
info
[
"mode"
]
=
nodes
[
0
]
.
op
.
mode
info
[
"gpua"
]
=
False
info
[
"as_while"
]
=
as_while
info
[
"profile"
]
=
nodes
[
0
]
.
op
.
profile
info
[
"allow_gc"
]
=
nodes
[
0
]
.
op
.
allow_gc
# We keep the inner_ins and inner_outs of each original node separated.
# To be able to recombine them in the right order after the clone,
# we also need to split them by types (seq, mitmot, ...).
...
...
@@ -1725,12 +1712,15 @@ class ScanMerge(GlobalOptimizer):
inner_ins
[
idx
]
.
append
(
rename
(
nd
.
op
.
inner_seqs
(
nd
.
op
.
inputs
),
idx
))
outer_ins
+=
rename
(
nd
.
op
.
outer_seqs
(
nd
.
inputs
),
idx
)
tap_array
=
()
mit_mot_out_slices
=
()
for
idx
,
nd
in
enumerate
(
nodes
):
# MitMot
inner_ins
[
idx
]
.
append
(
rename
(
nd
.
op
.
inner_mitmot
(
nd
.
op
.
inputs
),
idx
))
inner_outs
[
idx
]
.
append
(
nd
.
op
.
inner_mitmot_outs
(
nd
.
op
.
outputs
))
info
[
"tap_array"
]
+=
nd
.
op
.
mitmot_taps
()
info
[
"mit_mot_out_slices"
]
+=
nd
.
op
.
mitmot_out_taps
()
tap_array
+=
nd
.
op
.
mitmot_taps
()
mit_mot_out_slices
+=
nd
.
op
.
mitmot_out_taps
()
outer_ins
+=
rename
(
nd
.
op
.
outer_mitmot
(
nd
.
inputs
),
idx
)
outer_outs
+=
nd
.
op
.
outer_mitmot_outs
(
nd
.
outputs
)
...
...
@@ -1738,14 +1728,14 @@ class ScanMerge(GlobalOptimizer):
# MitSot
inner_ins
[
idx
]
.
append
(
rename
(
nd
.
op
.
inner_mitsot
(
nd
.
op
.
inputs
),
idx
))
inner_outs
[
idx
]
.
append
(
nd
.
op
.
inner_mitsot_outs
(
nd
.
op
.
outputs
))
info
[
"tap_array"
]
+=
nd
.
op
.
mitsot_taps
()
tap_array
+=
nd
.
op
.
mitsot_taps
()
outer_ins
+=
rename
(
nd
.
op
.
outer_mitsot
(
nd
.
inputs
),
idx
)
outer_outs
+=
nd
.
op
.
outer_mitsot_outs
(
nd
.
outputs
)
for
idx
,
nd
in
enumerate
(
nodes
):
# SitSot
inner_ins
[
idx
]
.
append
(
rename
(
nd
.
op
.
inner_sitsot
(
nd
.
op
.
inputs
),
idx
))
info
[
"tap_array"
]
+=
[[
-
1
]
for
x
in
range
(
nd
.
op
.
n_sit_sot
)]
tap_array
+=
tuple
((
-
1
,)
for
x
in
range
(
nd
.
op
.
n_sit_sot
))
inner_outs
[
idx
]
.
append
(
nd
.
op
.
inner_sitsot_outs
(
nd
.
op
.
outputs
))
outer_ins
+=
rename
(
nd
.
op
.
outer_sitsot
(
nd
.
inputs
),
idx
)
outer_outs
+=
nd
.
op
.
outer_sitsot_outs
(
nd
.
outputs
)
...
...
@@ -1834,7 +1824,25 @@ class ScanMerge(GlobalOptimizer):
else
:
new_inner_outs
+=
inner_outs
[
idx
][
gr_idx
]
new_op
=
Scan
(
new_inner_ins
,
new_inner_outs
,
info
)
info
=
ScanInfo
(
tap_array
=
tap_array
,
n_seqs
=
sum
([
nd
.
op
.
n_seqs
for
nd
in
nodes
]),
n_mit_mot
=
sum
([
nd
.
op
.
n_mit_mot
for
nd
in
nodes
]),
n_mit_mot_outs
=
sum
([
nd
.
op
.
n_mit_mot_outs
for
nd
in
nodes
]),
mit_mot_out_slices
=
mit_mot_out_slices
,
n_mit_sot
=
sum
([
nd
.
op
.
n_mit_sot
for
nd
in
nodes
]),
n_sit_sot
=
sum
([
nd
.
op
.
n_sit_sot
for
nd
in
nodes
]),
n_shared_outs
=
sum
([
nd
.
op
.
n_shared_outs
for
nd
in
nodes
]),
n_nit_sot
=
sum
([
nd
.
op
.
n_nit_sot
for
nd
in
nodes
]),
truncate_gradient
=
nodes
[
0
]
.
op
.
truncate_gradient
,
name
=
"&"
.
join
([
nd
.
op
.
name
for
nd
in
nodes
]),
gpua
=
False
,
as_while
=
as_while
,
profile
=
nodes
[
0
]
.
op
.
profile
,
allow_gc
=
nodes
[
0
]
.
op
.
allow_gc
,
)
new_op
=
Scan
(
new_inner_ins
,
new_inner_outs
,
info
,
nodes
[
0
]
.
op
.
mode
)
new_outs
=
new_op
(
*
outer_ins
)
if
not
isinstance
(
new_outs
,
(
list
,
tuple
)):
...
...
@@ -1932,7 +1940,7 @@ def make_equiv(lo, li):
the equivalence of their corresponding outer inputs.
"""
seeno
=
OrderedDict
()
seeno
=
{}
left
=
[]
right
=
[]
for
o
,
i
in
zip
(
lo
,
li
):
...
...
@@ -1956,7 +1964,7 @@ def scan_merge_inouts(fgraph, node):
node
.
inputs
,
node
.
outputs
,
node
.
op
.
inputs
,
node
.
op
.
outputs
,
node
.
op
.
info
)
inp_equiv
=
OrderedDict
()
inp_equiv
=
{}
if
has_duplicates
(
a
.
outer_in_seqs
):
new_outer_seqs
=
[]
...
...
@@ -1992,7 +2000,7 @@ def scan_merge_inouts(fgraph, node):
a_inner_outs
=
a
.
inner_outputs
inner_outputs
=
clone_replace
(
a_inner_outs
,
replace
=
inp_equiv
)
op
=
Scan
(
inner_inputs
,
inner_outputs
,
info
)
op
=
Scan
(
inner_inputs
,
inner_outputs
,
info
,
node
.
op
.
mode
)
outputs
=
op
(
*
outer_inputs
)
if
not
isinstance
(
outputs
,
(
list
,
tuple
)):
...
...
@@ -2019,7 +2027,7 @@ def scan_merge_inouts(fgraph, node):
left
+=
_left
right
+=
_right
if
has_duplicates
(
na
.
outer_in_mit_mot
):
seen
=
OrderedDict
()
seen
=
{}
for
omm
,
imm
,
_sl
in
zip
(
na
.
outer_in_mit_mot
,
na
.
inner_in_mit_mot
,
na
.
mit_mot_in_slices
):
...
...
@@ -2032,7 +2040,7 @@ def scan_merge_inouts(fgraph, node):
seen
[(
omm
,
sl
)]
=
imm
if
has_duplicates
(
na
.
outer_in_mit_sot
):
seen
=
OrderedDict
()
seen
=
{}
for
oms
,
ims
,
_sl
in
zip
(
na
.
outer_in_mit_sot
,
na
.
inner_in_mit_sot
,
na
.
mit_sot_in_slices
):
...
...
@@ -2117,9 +2125,7 @@ def scan_merge_inouts(fgraph, node):
new_outer_out_mit_mot
.
append
(
outer_omm
)
na
.
outer_out_mit_mot
=
new_outer_out_mit_mot
if
remove
:
return
OrderedDict
(
[(
"remove"
,
remove
)]
+
list
(
zip
(
node
.
outputs
,
na
.
outer_outputs
))
)
return
dict
([(
"remove"
,
remove
)]
+
list
(
zip
(
node
.
outputs
,
na
.
outer_outputs
)))
return
na
.
outer_outputs
...
...
@@ -2214,15 +2220,17 @@ class PushOutDot1(GlobalOptimizer):
inner_non_seqs
=
op
.
inner_non_seqs
(
op
.
inputs
)
outer_non_seqs
=
op
.
outer_non_seqs
(
node
)
new_info
=
op
.
info
.
copy
()
st
=
len
(
op
.
mitmot_taps
())
+
len
(
op
.
mitsot_taps
())
new_info
[
"tap_array"
]
=
(
new_info
[
"tap_array"
][:
st
+
idx
]
+
new_info
[
"tap_array"
][
st
+
idx
+
1
:]
new_info
=
dataclasses
.
replace
(
op
.
info
,
tap_array
=
(
op
.
info
.
tap_array
[:
st
+
idx
]
+
op
.
info
.
tap_array
[
st
+
idx
+
1
:]
),
n_sit_sot
=
op
.
info
.
n_sit_sot
-
1
,
n_nit_sot
=
op
.
info
.
n_nit_sot
+
1
,
)
new_info
[
"n_sit_sot"
]
-=
1
new_info
[
"n_nit_sot"
]
+=
1
inner_sitsot
=
inner_sitsot
[:
idx
]
+
inner_sitsot
[
idx
+
1
:]
outer_sitsot
=
outer_sitsot
[:
idx
]
+
outer_sitsot
[
idx
+
1
:]
inner_sitsot_outs
=
(
...
...
@@ -2249,7 +2257,7 @@ class PushOutDot1(GlobalOptimizer):
new_inner_inps
,
new_inner_outs
=
reconstruct_graph
(
_new_inner_inps
,
_new_inner_outs
)
new_op
=
Scan
(
new_inner_inps
,
new_inner_outs
,
new_info
)
new_op
=
Scan
(
new_inner_inps
,
new_inner_outs
,
new_info
,
op
.
mode
)
_scan_inputs
=
(
[
node
.
inputs
[
0
]]
+
outer_seqs
...
...
aesara/scan/utils.py
浏览文件 @
e85c7fd0
"""This module provides utility functions for the `Scan` `Op`."""
import
copy
import
dataclasses
import
logging
import
warnings
from
collections
import
OrderedDict
,
namedtuple
...
...
@@ -157,21 +158,6 @@ def traverse(out, x, x_copy, d, visited=None):
return
d
# Hashing a dictionary/list/tuple by xoring the hash of each element
def
hash_listsDictsTuples
(
x
):
hash_value
=
0
if
isinstance
(
x
,
dict
):
for
k
,
v
in
x
.
items
():
hash_value
^=
hash_listsDictsTuples
(
k
)
hash_value
^=
hash_listsDictsTuples
(
v
)
elif
isinstance
(
x
,
(
list
,
tuple
)):
for
v
in
x
:
hash_value
^=
hash_listsDictsTuples
(
v
)
else
:
hash_value
^=
hash
(
x
)
return
hash_value
def
map_variables
(
replacer
,
graphs
,
additional_inputs
=
None
):
"""Construct new graphs based on 'graphs' with some variables replaced
according to 'replacer'.
...
...
@@ -264,6 +250,7 @@ def map_variables(replacer, graphs, additional_inputs=None):
new_inner_inputs
,
new_inner_outputs
,
node
.
op
.
info
,
node
.
op
.
mode
,
# FIXME: infer this someday?
typeConstructor
=
None
,
)
...
...
@@ -669,7 +656,7 @@ def scan_can_remove_outs(op, out_idxs):
offset
=
op
.
n_seqs
lim
=
op
.
n_mit_mot
+
op
.
n_mit_sot
+
op
.
n_sit_sot
for
idx
in
range
(
lim
):
n_ins
=
len
(
op
.
info
[
"tap_array"
]
[
idx
])
n_ins
=
len
(
op
.
info
.
tap_array
[
idx
])
out_ins
+=
[
op
.
inputs
[
offset
:
offset
+
n_ins
]]
offset
+=
n_ins
out_ins
+=
[[]
for
k
in
range
(
op
.
n_nit_sot
)]
...
...
@@ -702,23 +689,25 @@ def compress_outs(op, not_required, inputs):
node inputs, and changing the dictionary.
"""
info
=
OrderedDict
()
info
[
"tap_array"
]
=
[]
info
[
"n_seqs"
]
=
op
.
info
[
"n_seqs"
]
info
[
"n_mit_mot"
]
=
0
info
[
"n_mit_mot_outs"
]
=
0
info
[
"mit_mot_out_slices"
]
=
[]
info
[
"n_mit_sot"
]
=
0
info
[
"n_sit_sot"
]
=
0
info
[
"n_shared_outs"
]
=
0
info
[
"n_nit_sot"
]
=
0
info
[
"truncate_gradient"
]
=
op
.
info
[
"truncate_gradient"
]
info
[
"name"
]
=
op
.
info
[
"name"
]
info
[
"gpua"
]
=
op
.
info
[
"gpua"
]
info
[
"mode"
]
=
op
.
info
[
"mode"
]
info
[
"as_while"
]
=
op
.
info
[
"as_while"
]
info
[
"profile"
]
=
op
.
info
[
"profile"
]
info
[
"allow_gc"
]
=
op
.
info
[
"allow_gc"
]
from
aesara.scan.op
import
ScanInfo
info
=
ScanInfo
(
tap_array
=
(),
n_seqs
=
op
.
info
.
n_seqs
,
n_mit_mot
=
0
,
n_mit_mot_outs
=
0
,
mit_mot_out_slices
=
(),
n_mit_sot
=
0
,
n_sit_sot
=
0
,
n_shared_outs
=
0
,
n_nit_sot
=
0
,
truncate_gradient
=
op
.
info
.
truncate_gradient
,
name
=
op
.
info
.
name
,
gpua
=
op
.
info
.
gpua
,
as_while
=
op
.
info
.
as_while
,
profile
=
op
.
info
.
profile
,
allow_gc
=
op
.
info
.
allow_gc
,
)
op_inputs
=
op
.
inputs
[:
op
.
n_seqs
]
op_outputs
=
[]
...
...
@@ -730,13 +719,17 @@ def compress_outs(op, not_required, inputs):
i_offset
=
op
.
n_seqs
o_offset
=
0
curr_pos
=
0
for
idx
in
range
(
op
.
info
[
"n_mit_mot"
]
):
for
idx
in
range
(
op
.
info
.
n_mit_mot
):
if
offset
+
idx
not
in
not_required
:
map_old_new
[
offset
+
idx
]
=
curr_pos
curr_pos
+=
1
info
[
"n_mit_mot"
]
+=
1
info
[
"tap_array"
]
+=
[
op
.
tap_array
[
offset
+
idx
]]
info
[
"mit_mot_out_slices"
]
+=
[
op
.
mit_mot_out_slices
[
offset
+
idx
]]
info
=
dataclasses
.
replace
(
info
,
n_mit_mot
=
info
.
n_mit_mot
+
1
,
tap_array
=
info
.
tap_array
+
(
tuple
(
op
.
tap_array
[
offset
+
idx
]),),
mit_mot_out_slices
=
info
.
mit_mot_out_slices
+
(
tuple
(
op
.
mit_mot_out_slices
[
offset
+
idx
]),),
)
# input taps
for
jdx
in
op
.
tap_array
[
offset
+
idx
]:
op_inputs
+=
[
op
.
inputs
[
i_offset
]]
...
...
@@ -750,16 +743,19 @@ def compress_outs(op, not_required, inputs):
else
:
o_offset
+=
len
(
op
.
mit_mot_out_slices
[
offset
+
idx
])
i_offset
+=
len
(
op
.
tap_array
[
offset
+
idx
])
info
[
"n_mit_mot_outs"
]
=
len
(
op_outputs
)
info
=
dataclasses
.
replace
(
info
,
n_mit_mot_outs
=
len
(
op_outputs
)
)
offset
+=
op
.
n_mit_mot
ni_offset
+=
op
.
n_mit_mot
for
idx
in
range
(
op
.
info
[
"n_mit_sot"
]
):
for
idx
in
range
(
op
.
info
.
n_mit_sot
):
if
offset
+
idx
not
in
not_required
:
map_old_new
[
offset
+
idx
]
=
curr_pos
curr_pos
+=
1
info
[
"n_mit_sot"
]
+=
1
info
[
"tap_array"
]
+=
[
op
.
tap_array
[
offset
+
idx
]]
info
=
dataclasses
.
replace
(
info
,
n_mit_sot
=
info
.
n_mit_sot
+
1
,
tap_array
=
info
.
tap_array
+
(
tuple
(
op
.
tap_array
[
offset
+
idx
]),),
)
# input taps
for
jdx
in
op
.
tap_array
[
offset
+
idx
]:
op_inputs
+=
[
op
.
inputs
[
i_offset
]]
...
...
@@ -775,12 +771,15 @@ def compress_outs(op, not_required, inputs):
offset
+=
op
.
n_mit_sot
ni_offset
+=
op
.
n_mit_sot
for
idx
in
range
(
op
.
info
[
"n_sit_sot"
]
):
for
idx
in
range
(
op
.
info
.
n_sit_sot
):
if
offset
+
idx
not
in
not_required
:
map_old_new
[
offset
+
idx
]
=
curr_pos
curr_pos
+=
1
info
[
"n_sit_sot"
]
+=
1
info
[
"tap_array"
]
+=
[
op
.
tap_array
[
offset
+
idx
]]
info
=
dataclasses
.
replace
(
info
,
n_sit_sot
=
info
.
n_sit_sot
+
1
,
tap_array
=
info
.
tap_array
+
(
tuple
(
op
.
tap_array
[
offset
+
idx
]),),
)
# input taps
op_inputs
+=
[
op
.
inputs
[
i_offset
]]
i_offset
+=
1
...
...
@@ -796,11 +795,11 @@ def compress_outs(op, not_required, inputs):
offset
+=
op
.
n_sit_sot
ni_offset
+=
op
.
n_sit_sot
nit_sot_ins
=
[]
for
idx
in
range
(
op
.
info
[
"n_nit_sot"
]
):
for
idx
in
range
(
op
.
info
.
n_nit_sot
):
if
offset
+
idx
not
in
not_required
:
map_old_new
[
offset
+
idx
]
=
curr_pos
curr_pos
+=
1
info
[
"n_nit_sot"
]
+=
1
info
=
dataclasses
.
replace
(
info
,
n_nit_sot
=
info
.
n_nit_sot
+
1
)
op_outputs
+=
[
op
.
outputs
[
o_offset
]]
o_offset
+=
1
nit_sot_ins
+=
[
inputs
[
ni_offset
+
idx
+
op
.
n_shared_outs
]]
...
...
@@ -809,11 +808,11 @@ def compress_outs(op, not_required, inputs):
offset
+=
op
.
n_nit_sot
shared_ins
=
[]
for
idx
in
range
(
op
.
info
[
"n_shared_outs"
]
):
for
idx
in
range
(
op
.
info
.
n_shared_outs
):
if
offset
+
idx
not
in
not_required
:
map_old_new
[
offset
+
idx
]
=
curr_pos
curr_pos
+=
1
info
[
"n_shared_outs"
]
+=
1
info
=
dataclasses
.
replace
(
info
,
n_shared_outs
=
info
.
n_shared_outs
+
1
)
op_outputs
+=
[
op
.
outputs
[
o_offset
]]
o_offset
+=
1
op_inputs
+=
[
op
.
inputs
[
i_offset
]]
...
...
@@ -896,7 +895,7 @@ class ScanArgs:
else
:
rval
=
(
_inner_inputs
,
_inner_outputs
)
if
info
[
"as_while"
]
:
if
info
.
as_while
:
self
.
cond
=
[
rval
[
1
][
-
1
]]
inner_outputs
=
rval
[
1
][:
-
1
]
else
:
...
...
@@ -907,17 +906,17 @@ class ScanArgs:
p
=
1
q
=
0
n_seqs
=
info
[
"n_seqs"
]
n_seqs
=
info
.
n_seqs
self
.
outer_in_seqs
=
outer_inputs
[
p
:
p
+
n_seqs
]
self
.
inner_in_seqs
=
inner_inputs
[
q
:
q
+
n_seqs
]
p
+=
n_seqs
q
+=
n_seqs
n_mit_mot
=
info
[
"n_mit_mot"
]
n_mit_sot
=
info
[
"n_mit_sot"
]
n_mit_mot
=
info
.
n_mit_mot
n_mit_sot
=
info
.
n_mit_sot
self
.
mit_mot_in_slices
=
info
[
"tap_array"
]
[:
n_mit_mot
]
self
.
mit_sot_in_slices
=
info
[
"tap_array"
]
[
n_mit_mot
:
n_mit_mot
+
n_mit_sot
]
self
.
mit_mot_in_slices
=
info
.
tap_array
[:
n_mit_mot
]
self
.
mit_sot_in_slices
=
info
.
tap_array
[
n_mit_mot
:
n_mit_mot
+
n_mit_sot
]
n_mit_mot_ins
=
sum
(
len
(
s
)
for
s
in
self
.
mit_mot_in_slices
)
n_mit_sot_ins
=
sum
(
len
(
s
)
for
s
in
self
.
mit_sot_in_slices
)
...
...
@@ -943,19 +942,19 @@ class ScanArgs:
self
.
outer_in_mit_sot
=
outer_inputs
[
p
:
p
+
n_mit_sot
]
p
+=
n_mit_sot
n_sit_sot
=
info
[
"n_sit_sot"
]
n_sit_sot
=
info
.
n_sit_sot
self
.
outer_in_sit_sot
=
outer_inputs
[
p
:
p
+
n_sit_sot
]
self
.
inner_in_sit_sot
=
inner_inputs
[
q
:
q
+
n_sit_sot
]
p
+=
n_sit_sot
q
+=
n_sit_sot
n_shared_outs
=
info
[
"n_shared_outs"
]
n_shared_outs
=
info
.
n_shared_outs
self
.
outer_in_shared
=
outer_inputs
[
p
:
p
+
n_shared_outs
]
self
.
inner_in_shared
=
inner_inputs
[
q
:
q
+
n_shared_outs
]
p
+=
n_shared_outs
q
+=
n_shared_outs
n_nit_sot
=
info
[
"n_nit_sot"
]
n_nit_sot
=
info
.
n_nit_sot
self
.
outer_in_nit_sot
=
outer_inputs
[
p
:
p
+
n_nit_sot
]
p
+=
n_nit_sot
...
...
@@ -966,14 +965,14 @@ class ScanArgs:
p
=
0
q
=
0
self
.
mit_mot_out_slices
=
info
[
"mit_mot_out_slices"
]
n_mit_mot_outs
=
info
[
"n_mit_mot_outs"
]
self
.
mit_mot_out_slices
=
info
.
mit_mot_out_slices
n_mit_mot_outs
=
info
.
n_mit_mot_outs
self
.
outer_out_mit_mot
=
outer_outputs
[
p
:
p
+
n_mit_mot
]
iomm
=
inner_outputs
[
q
:
q
+
n_mit_mot_outs
]
self
.
inner_out_mit_mot
=
[]
self
.
inner_out_mit_mot
=
()
qq
=
0
for
sl
in
self
.
mit_mot_out_slices
:
self
.
inner_out_mit_mot
.
append
(
iomm
[
qq
:
qq
+
len
(
sl
)]
)
self
.
inner_out_mit_mot
+=
(
iomm
[
qq
:
qq
+
len
(
sl
)],
)
qq
+=
len
(
sl
)
p
+=
n_mit_mot
q
+=
n_mit_mot_outs
...
...
@@ -1001,19 +1000,17 @@ class ScanArgs:
assert
p
==
len
(
outer_outputs
)
assert
q
==
len
(
inner_outputs
)
self
.
other_info
=
OrderedDict
()
for
k
in
(
"truncate_gradient"
,
"name"
,
"mode"
,
"destroy_map"
,
"gpua"
,
"as_while"
,
"profile"
,
"allow_gc"
,
):
if
k
in
info
:
self
.
other_info
[
k
]
=
info
[
k
]
self
.
other_info
=
{
k
:
getattr
(
info
,
k
)
for
k
in
(
"truncate_gradient"
,
"name"
,
"gpua"
,
"as_while"
,
"profile"
,
"allow_gc"
,
)
}
@staticmethod
def
from_node
(
node
,
clone
=
False
):
...
...
@@ -1032,26 +1029,24 @@ class ScanArgs:
@classmethod
def
create_empty
(
cls
):
info
=
OrderedDict
(
[
(
"n_seqs"
,
0
),
(
"n_mit_mot"
,
0
),
(
"n_mit_sot"
,
0
),
(
"tap_array"
,
[]),
(
"n_sit_sot"
,
0
),
(
"n_nit_sot"
,
0
),
(
"n_shared_outs"
,
0
),
(
"n_mit_mot_outs"
,
0
),
(
"mit_mot_out_slices"
,
[]),
(
"truncate_gradient"
,
-
1
),
(
"name"
,
None
),
(
"mode"
,
None
),
(
"destroy_map"
,
OrderedDict
()),
(
"gpua"
,
False
),
(
"as_while"
,
False
),
(
"profile"
,
False
),
(
"allow_gc"
,
False
),
]
from
aesara.scan.op
import
ScanInfo
info
=
ScanInfo
(
n_seqs
=
0
,
n_mit_mot
=
0
,
n_mit_sot
=
0
,
tap_array
=
(),
n_sit_sot
=
0
,
n_nit_sot
=
0
,
n_shared_outs
=
0
,
n_mit_mot_outs
=
0
,
mit_mot_out_slices
=
(),
truncate_gradient
=-
1
,
name
=
None
,
gpua
=
False
,
as_while
=
False
,
profile
=
False
,
allow_gc
=
False
,
)
res
=
cls
([
1
],
[],
[],
[],
info
)
res
.
n_steps
=
None
...
...
@@ -1060,7 +1055,7 @@ class ScanArgs:
@property
def
n_nit_sot
(
self
):
# This is just a hack that allows us to use `Scan.get_oinp_iinp_iout_oout_mappings`
return
self
.
info
[
"n_nit_sot"
]
return
self
.
info
.
n_nit_sot
@property
def
inputs
(
self
):
...
...
@@ -1070,7 +1065,7 @@ class ScanArgs:
@property
def
n_mit_mot
(
self
):
# This is just a hack that allows us to use `Scan.get_oinp_iinp_iout_oout_mappings`
return
self
.
info
[
"n_mit_mot"
]
return
self
.
info
.
n_mit_mot
@property
def
var_mappings
(
self
):
...
...
@@ -1141,20 +1136,22 @@ class ScanArgs:
@property
def
info
(
self
):
return
OrderedDict
(
from
aesara.scan.op
import
ScanInfo
return
ScanInfo
(
n_seqs
=
len
(
self
.
outer_in_seqs
),
n_mit_mot
=
len
(
self
.
outer_in_mit_mot
),
n_mit_sot
=
len
(
self
.
outer_in_mit_sot
),
tap_array
=
(
self
.
mit_mot_in_slices
+
self
.
mit_sot_in_slices
+
[[
-
1
]]
*
len
(
self
.
inner_in_sit_sot
)
tuple
(
tuple
(
v
)
for
v
in
self
.
mit_mot_in_slices
)
+
tuple
(
tuple
(
v
)
for
v
in
self
.
mit_sot_in_slices
)
+
((
-
1
,),)
*
len
(
self
.
inner_in_sit_sot
)
),
n_sit_sot
=
len
(
self
.
outer_in_sit_sot
),
n_nit_sot
=
len
(
self
.
outer_in_nit_sot
),
n_shared_outs
=
len
(
self
.
outer_in_shared
),
n_mit_mot_outs
=
sum
(
len
(
s
)
for
s
in
self
.
mit_mot_out_slices
),
mit_mot_out_slices
=
self
.
mit_mot_out_slices
,
mit_mot_out_slices
=
tuple
(
self
.
mit_mot_out_slices
)
,
**
self
.
other_info
,
)
...
...
requirements.txt
浏览文件 @
e85c7fd0
-e ./
dataclasses
>=0.7; python_version < '3.7'
filelock
flake8
==3.8.4
pep8
...
...
setup.py
浏览文件 @
e85c7fd0
#!/usr/bin/env python
import
sys
from
setuptools
import
find_packages
,
setup
import
versioneer
...
...
@@ -43,6 +45,11 @@ Programming Language :: Python :: 3.9
"""
CLASSIFIERS
=
[
_f
for
_f
in
CLASSIFIERS
.
split
(
"
\n
"
)
if
_f
]
install_requires
=
[
"numpy>=1.17.0"
,
"scipy>=0.14"
,
"filelock"
]
if
sys
.
version_info
[
0
:
2
]
<
(
3
,
7
):
install_requires
+=
[
"dataclasses"
]
if
__name__
==
"__main__"
:
setup
(
name
=
NAME
,
...
...
@@ -57,7 +64,7 @@ if __name__ == "__main__":
license
=
LICENSE
,
platforms
=
PLATFORMS
,
packages
=
find_packages
(
exclude
=
[
"tests"
,
"tests.*"
]),
install_requires
=
[
"numpy>=1.17.0"
,
"scipy>=0.14"
,
"filelock"
]
,
install_requires
=
install_requires
,
package_data
=
{
""
:
[
"*.txt"
,
...
...
tests/scan/test_utils.py
浏览文件 @
e85c7fd0
...
...
@@ -252,10 +252,7 @@ def test_ScanArgs():
# The `scan_args` base class always clones the inner-graph;
# here we make sure it doesn't (and that all the inputs are the same)
assert
scan_args
.
inputs
==
scan_op
.
inputs
scan_op_info
=
dict
(
scan_op
.
info
)
# The `ScanInfo` dictionary has the wrong order and an extra entry
del
scan_op_info
[
"strict"
]
assert
dict
(
scan_args
.
info
)
==
scan_op_info
assert
scan_args
.
info
==
scan_op
.
info
assert
scan_args
.
var_mappings
==
scan_op
.
var_mappings
# Check that `ScanArgs.find_among_fields` works
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论