Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
e85c7fd0
提交
e85c7fd0
authored
8月 08, 2021
作者:
Brandon T. Willard
提交者:
Brandon T. Willard
9月 15, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Replace Scan info dict with ScanInfo dataclass
上级
d83ff33c
显示空白字符变更
内嵌
并排
正在显示
7 个修改的文件
包含
334 行增加
和
336 行删除
+334
-336
basic.py
aesara/scan/basic.py
+25
-24
op.py
aesara/scan/op.py
+130
-143
opt.py
aesara/scan/opt.py
+78
-70
utils.py
aesara/scan/utils.py
+91
-94
requirements.txt
requirements.txt
+1
-0
setup.py
setup.py
+8
-1
test_utils.py
tests/scan/test_utils.py
+1
-4
没有找到文件。
aesara/scan/basic.py
浏览文件 @
e85c7fd0
...
@@ -13,7 +13,7 @@ from aesara.graph.fg import MissingInputError
...
@@ -13,7 +13,7 @@ from aesara.graph.fg import MissingInputError
from
aesara.graph.op
import
get_test_value
from
aesara.graph.op
import
get_test_value
from
aesara.graph.utils
import
TestValueError
from
aesara.graph.utils
import
TestValueError
from
aesara.scan
import
utils
from
aesara.scan
import
utils
from
aesara.scan.op
import
Scan
from
aesara.scan.op
import
Scan
,
ScanInfo
from
aesara.scan.utils
import
safe_new
,
traverse
from
aesara.scan.utils
import
safe_new
,
traverse
from
aesara.tensor.exceptions
import
NotScalarConstantError
from
aesara.tensor.exceptions
import
NotScalarConstantError
from
aesara.tensor.math
import
minimum
from
aesara.tensor.math
import
minimum
...
@@ -1022,31 +1022,32 @@ def scan(
...
@@ -1022,31 +1022,32 @@ def scan(
# Step 7. Create the Scan Op
# Step 7. Create the Scan Op
##
##
tap_array
=
mit_sot_tap_array
+
[[
-
1
]
for
x
in
range
(
n_sit_sot
)]
tap_array
=
tuple
(
tuple
(
v
)
for
v
in
mit_sot_tap_array
)
+
tuple
(
(
-
1
,)
for
x
in
range
(
n_sit_sot
)
)
if
allow_gc
is
None
:
if
allow_gc
is
None
:
allow_gc
=
config
.
scan__allow_gc
allow_gc
=
config
.
scan__allow_gc
info
=
OrderedDict
()
info
=
ScanInfo
(
info
[
"tap_array"
]
=
tap_array
tap_array
=
tap_array
,
info
[
"n_seqs"
]
=
n_seqs
n_seqs
=
n_seqs
,
info
[
"n_mit_mot"
]
=
n_mit_mot
n_mit_mot
=
n_mit_mot
,
info
[
"n_mit_mot_outs"
]
=
n_mit_mot_outs
n_mit_mot_outs
=
n_mit_mot_outs
,
info
[
"mit_mot_out_slices"
]
=
mit_mot_out_slices
mit_mot_out_slices
=
tuple
(
tuple
(
v
)
for
v
in
mit_mot_out_slices
),
info
[
"n_mit_sot"
]
=
n_mit_sot
n_mit_sot
=
n_mit_sot
,
info
[
"n_sit_sot"
]
=
n_sit_sot
n_sit_sot
=
n_sit_sot
,
info
[
"n_shared_outs"
]
=
n_shared_outs
n_shared_outs
=
n_shared_outs
,
info
[
"n_nit_sot"
]
=
n_nit_sot
n_nit_sot
=
n_nit_sot
,
info
[
"truncate_gradient"
]
=
truncate_gradient
truncate_gradient
=
truncate_gradient
,
info
[
"name"
]
=
name
name
=
name
,
info
[
"mode"
]
=
mode
gpua
=
False
,
info
[
"destroy_map"
]
=
OrderedDict
()
as_while
=
as_while
,
info
[
"gpua"
]
=
False
profile
=
profile
,
info
[
"as_while"
]
=
as_while
allow_gc
=
allow_gc
,
info
[
"profile"
]
=
profile
strict
=
strict
,
info
[
"allow_gc"
]
=
allow_gc
)
info
[
"strict"
]
=
strict
local_op
=
Scan
(
inner_inputs
,
new_outs
,
info
,
mode
)
local_op
=
Scan
(
inner_inputs
,
new_outs
,
info
)
##
##
# Step 8. Compute the outputs using the scan op
# Step 8. Compute the outputs using the scan op
...
...
aesara/scan/op.py
浏览文件 @
e85c7fd0
...
@@ -44,11 +44,12 @@ relies on the following elements to work properly :
...
@@ -44,11 +44,12 @@ relies on the following elements to work properly :
"""
"""
import
copy
import
dataclasses
import
itertools
import
itertools
import
logging
import
logging
import
time
import
time
from
collections
import
OrderedDict
from
collections
import
OrderedDict
from
typing
import
Callable
,
List
,
Optional
,
Union
import
numpy
as
np
import
numpy
as
np
...
@@ -57,7 +58,7 @@ from aesara import tensor as aet
...
@@ -57,7 +58,7 @@ from aesara import tensor as aet
from
aesara.compile.builders
import
infer_shape
from
aesara.compile.builders
import
infer_shape
from
aesara.compile.function
import
function
from
aesara.compile.function
import
function
from
aesara.compile.io
import
In
,
Out
from
aesara.compile.io
import
In
,
Out
from
aesara.compile.mode
import
AddFeatureOptimizer
,
get_mode
from
aesara.compile.mode
import
AddFeatureOptimizer
,
Mode
,
get_default_mode
,
get_mode
from
aesara.compile.profiling
import
ScanProfileStats
,
register_profiler_printer
from
aesara.compile.profiling
import
ScanProfileStats
,
register_profiler_printer
from
aesara.configdefaults
import
config
from
aesara.configdefaults
import
config
from
aesara.gradient
import
DisconnectedType
,
NullType
,
Rop
,
grad
,
grad_undefined
from
aesara.gradient
import
DisconnectedType
,
NullType
,
Rop
,
grad
,
grad_undefined
...
@@ -76,7 +77,7 @@ from aesara.graph.op import Op, ops_with_inner_function
...
@@ -76,7 +77,7 @@ from aesara.graph.op import Op, ops_with_inner_function
from
aesara.link.c.basic
import
CLinker
from
aesara.link.c.basic
import
CLinker
from
aesara.link.c.exceptions
import
MissingGXX
from
aesara.link.c.exceptions
import
MissingGXX
from
aesara.link.utils
import
raise_with_op
from
aesara.link.utils
import
raise_with_op
from
aesara.scan.utils
import
Validator
,
forced_replace
,
hash_listsDictsTuples
,
safe_new
from
aesara.scan.utils
import
Validator
,
forced_replace
,
safe_new
from
aesara.tensor.basic
import
as_tensor_variable
from
aesara.tensor.basic
import
as_tensor_variable
from
aesara.tensor.math
import
minimum
from
aesara.tensor.math
import
minimum
from
aesara.tensor.shape
import
Shape_i
from
aesara.tensor.shape
import
Shape_i
...
@@ -88,25 +89,58 @@ from aesara.tensor.var import TensorVariable
...
@@ -88,25 +89,58 @@ from aesara.tensor.var import TensorVariable
_logger
=
logging
.
getLogger
(
"aesara.scan.op"
)
_logger
=
logging
.
getLogger
(
"aesara.scan.op"
)
@dataclasses.dataclass
(
frozen
=
True
)
class
ScanInfo
:
tap_array
:
tuple
n_seqs
:
int
n_mit_mot
:
int
n_mit_mot_outs
:
int
mit_mot_out_slices
:
tuple
n_mit_sot
:
int
n_sit_sot
:
int
n_shared_outs
:
int
n_nit_sot
:
int
truncate_gradient
:
bool
=
False
name
:
Optional
[
str
]
=
None
gpua
:
bool
=
False
as_while
:
bool
=
False
profile
:
Optional
[
Union
[
str
,
bool
]]
=
None
allow_gc
:
bool
=
True
strict
:
bool
=
True
TensorConstructorType
=
Callable
[[
List
[
bool
],
Union
[
str
,
np
.
generic
]],
TensorType
]
class
Scan
(
Op
):
class
Scan
(
Op
):
"""
def
__init__
(
self
,
inputs
:
List
[
Variable
],
outputs
:
List
[
Variable
],
info
:
ScanInfo
,
mode
:
Optional
[
Mode
]
=
None
,
typeConstructor
:
Optional
[
TensorConstructorType
]
=
None
,
):
r"""
Parameters
Parameters
----------
----------
inputs
inputs
Inputs of the inner function of scan
.
Inputs of the inner function of `Scan`
.
outputs
outputs
Outputs of the inner function of scan
.
Outputs of the inner function of `Scan`
.
info
info
Dictionary containing different properties of the scan op (like number
Dictionary containing different properties of the `Scan` `Op` (like
of different types of arguments, name, mode, if it should run on GPU or
number of different types of arguments, name, mode, if it should run on
not, etc.).
GPU or not, etc.).
mode
The compilation mode for the inner graph.
typeConstructor
typeConstructor
Function that constructs an equivalent to Aesara TensorType
.
Function that constructs an equivalent to Aesara `TensorType`
.
Notes
Notes
-----
-----
``typeConstructor`
` had been added to refactor how
`typeConstructor
` had been added to refactor how
Aesara deals with the GPU. If it runs on the GPU, scan needs
Aesara deals with the GPU. If it runs on the GPU, scan needs
to construct certain outputs (those who reside in the GPU
to construct certain outputs (those who reside in the GPU
memory) as the GPU-specific type. However we can not import
memory) as the GPU-specific type. However we can not import
...
@@ -114,31 +148,31 @@ class Scan(Op):
...
@@ -114,31 +148,31 @@ class Scan(Op):
on each machine) so the workaround is that the GPU
on each machine) so the workaround is that the GPU
optimization passes to the constructor of this class a
optimization passes to the constructor of this class a
function that is able to construct a GPU type. This way the
function that is able to construct a GPU type. This way the
class Scan
does not need to be aware of the details for the
class `Scan`
does not need to be aware of the details for the
GPU, it just constructs any tensor using this function (which
GPU, it just constructs any tensor using this function (which
by default constructs normal tensors).
by default constructs normal tensors).
"""
"""
def
__init__
(
self
,
inputs
,
outputs
,
info
,
typeConstructor
=
None
,
):
# adding properties into self
# adding properties into self
self
.
inputs
=
inputs
self
.
inputs
=
inputs
self
.
outputs
=
outputs
self
.
outputs
=
outputs
self
.
__dict__
.
update
(
info
)
# I keep a version of info in self, to use in __eq__ and __hash__,
# since info contains all tunable parameters of the op, so for two
# scan to be equal this tunable parameters should be the same
self
.
info
=
info
self
.
info
=
info
self
.
__dict__
.
update
(
dataclasses
.
asdict
(
info
))
# Clone mode_instance, altering "allow_gc" for the linker,
# and adding a message if we profile
if
self
.
name
:
message
=
self
.
name
+
" sub profile"
else
:
message
=
"Scan sub profile"
self
.
mode
=
get_default_mode
()
if
mode
is
None
else
mode
self
.
mode_instance
=
get_mode
(
self
.
mode
)
.
clone
(
link_kwargs
=
dict
(
allow_gc
=
self
.
allow_gc
),
message
=
message
)
# build a list of output types for any Apply node using this op.
# build a list of output types for any Apply node using this op.
self
.
output_types
=
[]
self
.
output_types
=
[]
idx
=
0
jdx
=
0
def
tensorConstructor
(
broadcastable
,
dtype
):
def
tensorConstructor
(
broadcastable
,
dtype
):
return
TensorType
(
broadcastable
=
broadcastable
,
dtype
=
dtype
)
return
TensorType
(
broadcastable
=
broadcastable
,
dtype
=
dtype
)
...
@@ -146,6 +180,8 @@ class Scan(Op):
...
@@ -146,6 +180,8 @@ class Scan(Op):
if
typeConstructor
is
None
:
if
typeConstructor
is
None
:
typeConstructor
=
tensorConstructor
typeConstructor
=
tensorConstructor
idx
=
0
jdx
=
0
while
idx
<
self
.
n_mit_mot_outs
:
while
idx
<
self
.
n_mit_mot_outs
:
# Not that for mit_mot there are several output slices per
# Not that for mit_mot there are several output slices per
# output sequence
# output sequence
...
@@ -176,24 +212,13 @@ class Scan(Op):
...
@@ -176,24 +212,13 @@ class Scan(Op):
if
self
.
as_while
:
if
self
.
as_while
:
self
.
output_types
=
self
.
output_types
[:
-
1
]
self
.
output_types
=
self
.
output_types
[:
-
1
]
mode_instance
=
get_mode
(
self
.
mode
)
# Clone mode_instance, altering "allow_gc" for the linker,
# and adding a message if we profile
if
self
.
name
:
message
=
self
.
name
+
" sub profile"
else
:
message
=
"Scan sub profile"
self
.
mode_instance
=
mode_instance
.
clone
(
link_kwargs
=
dict
(
allow_gc
=
self
.
allow_gc
),
message
=
message
)
if
not
hasattr
(
self
,
"name"
)
or
self
.
name
is
None
:
if
not
hasattr
(
self
,
"name"
)
or
self
.
name
is
None
:
self
.
name
=
"scan_fn"
self
.
name
=
"scan_fn"
# to have a fair __eq__ comparison later on, we update the info with
# to have a fair __eq__ comparison later on, we update the info with
# the actual mode used to compile the function and the name of the
# the actual mode used to compile the function and the name of the
# function that we set in case none was given
# function that we set in case none was given
self
.
info
[
"name"
]
=
self
.
name
self
.
info
=
dataclasses
.
replace
(
self
.
info
,
name
=
self
.
name
)
# Pre-computing some values to speed up perform
# Pre-computing some values to speed up perform
self
.
mintaps
=
[
np
.
min
(
x
)
for
x
in
self
.
tap_array
]
self
.
mintaps
=
[
np
.
min
(
x
)
for
x
in
self
.
tap_array
]
...
@@ -205,8 +230,8 @@ class Scan(Op):
...
@@ -205,8 +230,8 @@ class Scan(Op):
self
.
nit_sot_arg_offset
=
self
.
shared_arg_offset
+
self
.
n_shared_outs
self
.
nit_sot_arg_offset
=
self
.
shared_arg_offset
+
self
.
n_shared_outs
self
.
n_outs
=
self
.
n_mit_mot
+
self
.
n_mit_sot
+
self
.
n_sit_sot
self
.
n_outs
=
self
.
n_mit_mot
+
self
.
n_mit_sot
+
self
.
n_sit_sot
self
.
n_tap_outs
=
self
.
n_mit_mot
+
self
.
n_mit_sot
self
.
n_tap_outs
=
self
.
n_mit_mot
+
self
.
n_mit_sot
if
self
.
info
[
"gpua"
]
:
if
self
.
info
.
gpua
:
self
.
_hash_inner_graph
=
self
.
info
[
"gpu_hash"
]
self
.
_hash_inner_graph
=
self
.
info
.
gpu_hash
else
:
else
:
# Do the missing inputs check here to have the error early.
# Do the missing inputs check here to have the error early.
for
var
in
graph_inputs
(
self
.
outputs
,
self
.
inputs
):
for
var
in
graph_inputs
(
self
.
outputs
,
self
.
inputs
):
...
@@ -256,7 +281,7 @@ class Scan(Op):
...
@@ -256,7 +281,7 @@ class Scan(Op):
# output with type GpuArrayType
# output with type GpuArrayType
from
aesara.gpuarray
import
GpuArrayType
from
aesara.gpuarray
import
GpuArrayType
if
not
self
.
info
.
g
et
(
"gpua"
,
False
)
:
if
not
self
.
info
.
g
pua
:
for
inp
in
self
.
inputs
:
for
inp
in
self
.
inputs
:
if
isinstance
(
inp
.
type
,
GpuArrayType
):
if
isinstance
(
inp
.
type
,
GpuArrayType
):
raise
TypeError
(
raise
TypeError
(
...
@@ -279,9 +304,7 @@ class Scan(Op):
...
@@ -279,9 +304,7 @@ class Scan(Op):
def
__setstate__
(
self
,
d
):
def
__setstate__
(
self
,
d
):
self
.
__dict__
.
update
(
d
)
self
.
__dict__
.
update
(
d
)
if
"allow_gc"
not
in
self
.
__dict__
:
self
.
allow_gc
=
True
self
.
info
[
"allow_gc"
]
=
True
if
not
hasattr
(
self
,
"var_mappings"
):
if
not
hasattr
(
self
,
"var_mappings"
):
# Generate the mappings between inner and outer inputs and outputs
# Generate the mappings between inner and outer inputs and outputs
# if they haven't already been generated.
# if they haven't already been generated.
...
@@ -731,41 +754,20 @@ class Scan(Op):
...
@@ -731,41 +754,20 @@ class Scan(Op):
return
apply_node
return
apply_node
def
__eq__
(
self
,
other
):
def
__eq__
(
self
,
other
):
# Check if we are dealing with same type of objects
if
type
(
self
)
!=
type
(
other
):
if
not
type
(
self
)
==
type
(
other
):
return
False
return
False
if
"destroy_map"
not
in
self
.
info
:
self
.
info
[
"destroy_map"
]
=
OrderedDict
()
if
self
.
info
!=
other
.
info
:
if
"destroy_map"
not
in
other
.
info
:
other
.
info
[
"destroy_map"
]
=
OrderedDict
()
keys_to_check
=
[
"truncate_gradient"
,
"profile"
,
"n_seqs"
,
"tap_array"
,
"as_while"
,
"n_mit_sot"
,
"destroy_map"
,
"n_nit_sot"
,
"n_shared_outs"
,
"n_sit_sot"
,
"gpua"
,
"n_mit_mot_outs"
,
"n_mit_mot"
,
"mit_mot_out_slices"
,
]
# This are some safety checks ( namely that the inner graph has the
# same number of inputs and same number of outputs )
if
not
len
(
self
.
inputs
)
==
len
(
other
.
inputs
):
return
False
return
False
elif
not
len
(
self
.
outputs
)
==
len
(
other
.
outputs
):
# Compare inner graphs
# TODO: Use `self.inner_fgraph == other.inner_fgraph`
if
len
(
self
.
inputs
)
!=
len
(
other
.
inputs
):
return
False
return
False
for
key
in
keys_to_check
:
if
self
.
info
[
key
]
!=
other
.
info
[
key
]
:
if
len
(
self
.
outputs
)
!=
len
(
other
.
outputs
)
:
return
False
return
False
# If everything went OK up to here, there is still one thing to
# check. Namely, do the internal graph represent same
# computations
for
self_in
,
other_in
in
zip
(
self
.
inputs
,
other
.
inputs
):
for
self_in
,
other_in
in
zip
(
self
.
inputs
,
other
.
inputs
):
if
self_in
.
type
!=
other_in
.
type
:
if
self_in
.
type
!=
other_in
.
type
:
return
False
return
False
...
@@ -801,15 +803,7 @@ class Scan(Op):
...
@@ -801,15 +803,7 @@ class Scan(Op):
return
aux_txt
return
aux_txt
def
__hash__
(
self
):
def
__hash__
(
self
):
return
hash
(
return
hash
((
type
(
self
),
self
.
_hash_inner_graph
,
self
.
info
))
(
type
(
self
),
# and a hash representing the inner graph using the
# CLinker.cmodule_key_
self
.
_hash_inner_graph
,
hash_listsDictsTuples
(
self
.
info
),
)
)
def
make_thunk
(
self
,
node
,
storage_map
,
compute_map
,
no_recycling
,
impl
=
None
):
def
make_thunk
(
self
,
node
,
storage_map
,
compute_map
,
no_recycling
,
impl
=
None
):
"""
"""
...
@@ -864,7 +858,6 @@ class Scan(Op):
...
@@ -864,7 +858,6 @@ class Scan(Op):
wrapped_inputs
=
[
In
(
x
,
borrow
=
False
)
for
x
in
self
.
inputs
[:
self
.
n_seqs
]]
wrapped_inputs
=
[
In
(
x
,
borrow
=
False
)
for
x
in
self
.
inputs
[:
self
.
n_seqs
]]
new_outputs
=
[
x
for
x
in
self
.
outputs
]
new_outputs
=
[
x
for
x
in
self
.
outputs
]
preallocated_mitmot_outs
=
[]
preallocated_mitmot_outs
=
[]
new_mit_mot_out_slices
=
copy
.
deepcopy
(
self
.
mit_mot_out_slices
)
input_idx
=
self
.
n_seqs
input_idx
=
self
.
n_seqs
for
mitmot_idx
in
range
(
self
.
n_mit_mot
):
for
mitmot_idx
in
range
(
self
.
n_mit_mot
):
...
@@ -894,7 +887,6 @@ class Scan(Op):
...
@@ -894,7 +887,6 @@ class Scan(Op):
)
)
wrapped_inputs
.
append
(
wrapped_inp
)
wrapped_inputs
.
append
(
wrapped_inp
)
preallocated_mitmot_outs
.
append
(
output_idx
)
preallocated_mitmot_outs
.
append
(
output_idx
)
new_mit_mot_out_slices
[
mitmot_idx
]
.
remove
(
inp_tap
)
else
:
else
:
# Wrap the corresponding input as usual. Leave the
# Wrap the corresponding input as usual. Leave the
# output as-is.
# output as-is.
...
@@ -1963,7 +1955,7 @@ class Scan(Op):
...
@@ -1963,7 +1955,7 @@ class Scan(Op):
outer_oidx
=
0
outer_oidx
=
0
# Handle sequences inputs
# Handle sequences inputs
for
i
in
range
(
self
.
info
[
"n_seqs"
]
):
for
i
in
range
(
self
.
info
.
n_seqs
):
outer_input_indices
.
append
(
outer_iidx
)
outer_input_indices
.
append
(
outer_iidx
)
inner_input_indices
.
append
([
inner_iidx
])
inner_input_indices
.
append
([
inner_iidx
])
inner_output_indices
.
append
([])
inner_output_indices
.
append
([])
...
@@ -1975,8 +1967,8 @@ class Scan(Op):
...
@@ -1975,8 +1967,8 @@ class Scan(Op):
outer_oidx
+=
0
outer_oidx
+=
0
# Handle mitmots, mitsots and sitsots variables
# Handle mitmots, mitsots and sitsots variables
for
i
in
range
(
len
(
self
.
info
[
"tap_array"
]
)):
for
i
in
range
(
len
(
self
.
info
.
tap_array
)):
nb_input_taps
=
len
(
self
.
info
[
"tap_array"
]
[
i
])
nb_input_taps
=
len
(
self
.
info
.
tap_array
[
i
])
if
i
<
self
.
n_mit_mot
:
if
i
<
self
.
n_mit_mot
:
nb_output_taps
=
len
(
self
.
mit_mot_out_slices
[
i
])
nb_output_taps
=
len
(
self
.
mit_mot_out_slices
[
i
])
...
@@ -1999,7 +1991,7 @@ class Scan(Op):
...
@@ -1999,7 +1991,7 @@ class Scan(Op):
# This is needed because, for outer inputs (and for outer inputs only)
# This is needed because, for outer inputs (and for outer inputs only)
# nitsots come *after* shared variables.
# nitsots come *after* shared variables.
outer_iidx
+=
self
.
info
[
"n_shared_outs"
]
outer_iidx
+=
self
.
info
.
n_shared_outs
# Handle nitsots variables
# Handle nitsots variables
for
i
in
range
(
self
.
n_nit_sot
):
for
i
in
range
(
self
.
n_nit_sot
):
...
@@ -2015,10 +2007,10 @@ class Scan(Op):
...
@@ -2015,10 +2007,10 @@ class Scan(Op):
# This is needed because, for outer inputs (and for outer inputs only)
# This is needed because, for outer inputs (and for outer inputs only)
# nitsots come *after* shared variables.
# nitsots come *after* shared variables.
outer_iidx
-=
self
.
info
[
"n_shared_outs"
]
+
self
.
n_nit_sot
outer_iidx
-=
self
.
info
.
n_shared_outs
+
self
.
n_nit_sot
# Handle shared states
# Handle shared states
for
i
in
range
(
self
.
info
[
"n_shared_outs"
]
):
for
i
in
range
(
self
.
info
.
n_shared_outs
):
outer_input_indices
.
append
(
outer_iidx
)
outer_input_indices
.
append
(
outer_iidx
)
inner_input_indices
.
append
([
inner_iidx
])
inner_input_indices
.
append
([
inner_iidx
])
inner_output_indices
.
append
([
inner_oidx
])
inner_output_indices
.
append
([
inner_oidx
])
...
@@ -2158,10 +2150,10 @@ class Scan(Op):
...
@@ -2158,10 +2150,10 @@ class Scan(Op):
oidx
+=
1
oidx
+=
1
iidx
-=
len
(
taps
)
iidx
-=
len
(
taps
)
if
iidx
<
self
.
info
[
"n_sit_sot"
]
:
if
iidx
<
self
.
info
.
n_sit_sot
:
return
oidx
+
iidx
return
oidx
+
iidx
else
:
else
:
return
oidx
+
iidx
+
self
.
info
[
"n_nit_sot"
]
return
oidx
+
iidx
+
self
.
info
.
n_nit_sot
def
get_out_idx
(
iidx
):
def
get_out_idx
(
iidx
):
oidx
=
0
oidx
=
0
...
@@ -2241,9 +2233,9 @@ class Scan(Op):
...
@@ -2241,9 +2233,9 @@ class Scan(Op):
# "if Xt not in self.inner_nitsot_outs(self_outputs)" because
# "if Xt not in self.inner_nitsot_outs(self_outputs)" because
# the exact same variable can be used as multiple outputs.
# the exact same variable can be used as multiple outputs.
idx_nitsot_start
=
(
idx_nitsot_start
=
(
self
.
info
[
"n_mit_mot"
]
+
self
.
info
[
"n_mit_sot"
]
+
self
.
info
[
"n_sit_sot"
]
self
.
info
.
n_mit_mot
+
self
.
info
.
n_mit_sot
+
self
.
info
.
n_sit_sot
)
)
idx_nitsot_end
=
idx_nitsot_start
+
self
.
info
[
"n_nit_sot"
]
idx_nitsot_end
=
idx_nitsot_start
+
self
.
info
.
n_nit_sot
if
idx
<
idx_nitsot_start
or
idx
>=
idx_nitsot_end
:
if
idx
<
idx_nitsot_start
or
idx
>=
idx_nitsot_end
:
# What we do here is loop through dC_douts and collect all
# What we do here is loop through dC_douts and collect all
# those that are connected to the specific one and do an
# those that are connected to the specific one and do an
...
@@ -2668,27 +2660,23 @@ class Scan(Op):
...
@@ -2668,27 +2660,23 @@ class Scan(Op):
n_sitsot_outs
=
len
(
outer_inp_sitsot
)
n_sitsot_outs
=
len
(
outer_inp_sitsot
)
new_tap_array
=
mitmot_inp_taps
+
[[
-
1
]
for
k
in
range
(
n_sitsot_outs
)]
new_tap_array
=
mitmot_inp_taps
+
[[
-
1
]
for
k
in
range
(
n_sitsot_outs
)]
info
=
OrderedDict
()
info
=
ScanInfo
(
info
[
"n_seqs"
]
=
len
(
outer_inp_seqs
)
n_seqs
=
len
(
outer_inp_seqs
),
info
[
"n_mit_sot"
]
=
0
n_mit_sot
=
0
,
info
[
"tap_array"
]
=
new_tap_array
tap_array
=
tuple
(
tuple
(
v
)
for
v
in
new_tap_array
),
info
[
"gpua"
]
=
False
gpua
=
False
,
info
[
"n_mit_mot"
]
=
len
(
outer_inp_mitmot
)
n_mit_mot
=
len
(
outer_inp_mitmot
),
info
[
"n_mit_mot_outs"
]
=
n_mitmot_outs
n_mit_mot_outs
=
n_mitmot_outs
,
info
[
"mit_mot_out_slices"
]
=
mitmot_out_taps
mit_mot_out_slices
=
tuple
(
tuple
(
v
)
for
v
in
mitmot_out_taps
),
info
[
"truncate_gradient"
]
=
self
.
truncate_gradient
truncate_gradient
=
self
.
truncate_gradient
,
info
[
"n_sit_sot"
]
=
n_sitsot_outs
n_sit_sot
=
n_sitsot_outs
,
info
[
"n_shared_outs"
]
=
0
n_shared_outs
=
0
,
info
[
"n_nit_sot"
]
=
n_nit_sot
n_nit_sot
=
n_nit_sot
,
info
[
"as_while"
]
=
False
as_while
=
False
,
info
[
"profile"
]
=
self
.
profile
profile
=
self
.
profile
,
info
[
"destroy_map"
]
=
OrderedDict
()
name
=
f
"grad_of_{self.name}"
if
self
.
name
else
None
,
if
self
.
name
:
allow_gc
=
self
.
allow_gc
,
info
[
"name"
]
=
"grad_of_"
+
self
.
name
)
else
:
info
[
"name"
]
=
None
info
[
"mode"
]
=
self
.
mode
info
[
"allow_gc"
]
=
self
.
allow_gc
outer_inputs
=
(
outer_inputs
=
(
[
grad_steps
]
[
grad_steps
]
...
@@ -2709,7 +2697,7 @@ class Scan(Op):
...
@@ -2709,7 +2697,7 @@ class Scan(Op):
)
)
inner_gfn_outs
=
inner_out_mitmot
+
inner_out_sitsot
+
inner_out_nitsot
inner_gfn_outs
=
inner_out_mitmot
+
inner_out_sitsot
+
inner_out_nitsot
local_op
=
Scan
(
inner_gfn_ins
,
inner_gfn_outs
,
info
)
local_op
=
Scan
(
inner_gfn_ins
,
inner_gfn_outs
,
info
,
self
.
mode
)
outputs
=
local_op
(
*
outer_inputs
)
outputs
=
local_op
(
*
outer_inputs
)
if
type
(
outputs
)
not
in
(
list
,
tuple
):
if
type
(
outputs
)
not
in
(
list
,
tuple
):
outputs
=
[
outputs
]
outputs
=
[
outputs
]
...
@@ -2852,8 +2840,8 @@ class Scan(Op):
...
@@ -2852,8 +2840,8 @@ class Scan(Op):
rop_self_outputs
=
self_outputs
[:
-
1
]
rop_self_outputs
=
self_outputs
[:
-
1
]
else
:
else
:
rop_self_outputs
=
self_outputs
rop_self_outputs
=
self_outputs
if
self
.
info
[
"n_shared_outs"
]
>
0
:
if
self
.
info
.
n_shared_outs
>
0
:
rop_self_outputs
=
rop_self_outputs
[:
-
self
.
info
[
"n_shared_outs"
]
]
rop_self_outputs
=
rop_self_outputs
[:
-
self
.
info
.
n_shared_outs
]
rop_outs
=
Rop
(
rop_self_outputs
,
rop_of_inputs
,
inner_eval_points
)
rop_outs
=
Rop
(
rop_self_outputs
,
rop_of_inputs
,
inner_eval_points
)
if
type
(
rop_outs
)
not
in
(
list
,
tuple
):
if
type
(
rop_outs
)
not
in
(
list
,
tuple
):
rop_outs
=
[
rop_outs
]
rop_outs
=
[
rop_outs
]
...
@@ -2867,25 +2855,7 @@ class Scan(Op):
...
@@ -2867,25 +2855,7 @@ class Scan(Op):
# The only exception is the eval point for the number of sequences, and
# The only exception is the eval point for the number of sequences, and
# evan point for the number of nit_sot which I think should just be
# evan point for the number of nit_sot which I think should just be
# ignored (?)
# ignored (?)
info
=
OrderedDict
()
info
[
"n_seqs"
]
=
self
.
n_seqs
*
2
info
[
"n_mit_sot"
]
=
self
.
n_mit_sot
*
2
info
[
"n_sit_sot"
]
=
self
.
n_sit_sot
*
2
info
[
"n_mit_mot"
]
=
self
.
n_mit_mot
*
2
info
[
"n_nit_sot"
]
=
self
.
n_nit_sot
*
2
info
[
"n_shared_outs"
]
=
self
.
n_shared_outs
info
[
"gpua"
]
=
False
info
[
"as_while"
]
=
self
.
as_while
info
[
"profile"
]
=
self
.
profile
info
[
"truncate_gradient"
]
=
self
.
truncate_gradient
if
self
.
name
:
info
[
"name"
]
=
"rop_of_"
+
self
.
name
else
:
info
[
"name"
]
=
None
info
[
"mode"
]
=
self
.
mode
info
[
"allow_gc"
]
=
self
.
allow_gc
info
[
"mit_mot_out_slices"
]
=
self
.
mit_mot_out_slices
*
2
info
[
"destroy_map"
]
=
OrderedDict
()
new_tap_array
=
[]
new_tap_array
=
[]
b
=
0
b
=
0
e
=
self
.
n_mit_mot
e
=
self
.
n_mit_mot
...
@@ -2896,7 +2866,6 @@ class Scan(Op):
...
@@ -2896,7 +2866,6 @@ class Scan(Op):
b
=
e
b
=
e
e
+=
self
.
n_sit_sot
e
+=
self
.
n_sit_sot
new_tap_array
+=
self
.
tap_array
[
b
:
e
]
*
2
new_tap_array
+=
self
.
tap_array
[
b
:
e
]
*
2
info
[
"tap_array"
]
=
new_tap_array
# Sequences ...
# Sequences ...
b
=
1
b
=
1
...
@@ -2993,7 +2962,7 @@ class Scan(Op):
...
@@ -2993,7 +2962,7 @@ class Scan(Op):
# Outputs
# Outputs
n_mit_mot_outs
=
int
(
np
.
sum
([
len
(
x
)
for
x
in
self
.
mit_mot_out_slices
]))
n_mit_mot_outs
=
int
(
np
.
sum
([
len
(
x
)
for
x
in
self
.
mit_mot_out_slices
]))
info
[
"n_mit_mot_outs"
]
=
n_mit_mot_outs
*
2
b
=
0
b
=
0
e
=
n_mit_mot_outs
e
=
n_mit_mot_outs
inner_out_mit_mot
=
self_outputs
[
b
:
e
]
+
rop_outs
[
b
:
e
]
inner_out_mit_mot
=
self_outputs
[
b
:
e
]
+
rop_outs
[
b
:
e
]
...
@@ -3039,7 +3008,25 @@ class Scan(Op):
...
@@ -3039,7 +3008,25 @@ class Scan(Op):
+
scan_other
+
scan_other
)
)
local_op
=
Scan
(
inner_ins
,
inner_outs
,
info
)
info
=
ScanInfo
(
n_seqs
=
self
.
n_seqs
*
2
,
n_mit_sot
=
self
.
n_mit_sot
*
2
,
n_sit_sot
=
self
.
n_sit_sot
*
2
,
n_mit_mot
=
self
.
n_mit_mot
*
2
,
n_nit_sot
=
self
.
n_nit_sot
*
2
,
n_shared_outs
=
self
.
n_shared_outs
,
n_mit_mot_outs
=
n_mit_mot_outs
*
2
,
gpua
=
False
,
as_while
=
self
.
as_while
,
profile
=
self
.
profile
,
truncate_gradient
=
self
.
truncate_gradient
,
name
=
f
"rop_of_{self.name}"
if
self
.
name
else
None
,
allow_gc
=
self
.
allow_gc
,
tap_array
=
tuple
(
tuple
(
v
)
for
v
in
new_tap_array
),
mit_mot_out_slices
=
tuple
(
tuple
(
v
)
for
v
in
self
.
mit_mot_out_slices
)
*
2
,
)
local_op
=
Scan
(
inner_ins
,
inner_outs
,
info
,
self
.
mode
)
outputs
=
local_op
(
*
scan_inputs
)
outputs
=
local_op
(
*
scan_inputs
)
if
type
(
outputs
)
not
in
(
list
,
tuple
):
if
type
(
outputs
)
not
in
(
list
,
tuple
):
outputs
=
[
outputs
]
outputs
=
[
outputs
]
...
...
aesara/scan/opt.py
浏览文件 @
e85c7fd0
...
@@ -51,8 +51,8 @@ scan_eqopt2 -> They are all global optimizer. (in2out convert local to global).
...
@@ -51,8 +51,8 @@ scan_eqopt2 -> They are all global optimizer. (in2out convert local to global).
"""
"""
import
copy
import
copy
import
dataclasses
import
logging
import
logging
from
collections
import
OrderedDict
from
sys
import
maxsize
from
sys
import
maxsize
import
numpy
as
np
import
numpy
as
np
...
@@ -78,7 +78,7 @@ from aesara.graph.fg import InconsistencyError
...
@@ -78,7 +78,7 @@ from aesara.graph.fg import InconsistencyError
from
aesara.graph.op
import
compute_test_value
from
aesara.graph.op
import
compute_test_value
from
aesara.graph.opt
import
GlobalOptimizer
,
in2out
,
local_optimizer
from
aesara.graph.opt
import
GlobalOptimizer
,
in2out
,
local_optimizer
from
aesara.graph.optdb
import
EquilibriumDB
,
SequenceDB
from
aesara.graph.optdb
import
EquilibriumDB
,
SequenceDB
from
aesara.scan.op
import
Scan
from
aesara.scan.op
import
Scan
,
ScanInfo
from
aesara.scan.utils
import
(
from
aesara.scan.utils
import
(
ScanArgs
,
ScanArgs
,
compress_outs
,
compress_outs
,
...
@@ -156,7 +156,7 @@ def remove_constants_and_unused_inputs_scan(fgraph, node):
...
@@ -156,7 +156,7 @@ def remove_constants_and_unused_inputs_scan(fgraph, node):
out_stuff_outer
=
node
.
inputs
[
1
+
op
.
n_seqs
:
st
]
out_stuff_outer
=
node
.
inputs
[
1
+
op
.
n_seqs
:
st
]
# To replace constants in the outer graph by clones in the inner graph
# To replace constants in the outer graph by clones in the inner graph
givens
=
OrderedDict
()
givens
=
{}
# All the inputs of the inner graph of the new scan
# All the inputs of the inner graph of the new scan
nw_inner
=
[]
nw_inner
=
[]
# Same for the outer graph, initialized w/ number of steps
# Same for the outer graph, initialized w/ number of steps
...
@@ -217,12 +217,10 @@ def remove_constants_and_unused_inputs_scan(fgraph, node):
...
@@ -217,12 +217,10 @@ def remove_constants_and_unused_inputs_scan(fgraph, node):
if
len
(
nw_inner
)
!=
len
(
op_ins
):
if
len
(
nw_inner
)
!=
len
(
op_ins
):
op_outs
=
clone_replace
(
op_outs
,
replace
=
givens
)
op_outs
=
clone_replace
(
op_outs
,
replace
=
givens
)
nw_info
=
copy
.
deepcopy
(
op
.
info
)
nw_info
=
dataclasses
.
replace
(
op
.
info
,
n_seqs
=
nw_n_seqs
)
nw_info
[
"n_seqs"
]
=
nw_n_seqs
nwScan
=
Scan
(
nw_inner
,
op_outs
,
nw_info
,
op
.
mode
)
# DEBUG CHECK
nwScan
=
Scan
(
nw_inner
,
op_outs
,
nw_info
)
nw_outs
=
nwScan
(
*
nw_outer
,
**
dict
(
return_list
=
True
))
nw_outs
=
nwScan
(
*
nw_outer
,
**
dict
(
return_list
=
True
))
return
OrderedD
ict
([(
"remove"
,
[
node
])]
+
list
(
zip
(
node
.
outputs
,
nw_outs
)))
return
d
ict
([(
"remove"
,
[
node
])]
+
list
(
zip
(
node
.
outputs
,
nw_outs
)))
else
:
else
:
return
False
return
False
...
@@ -263,7 +261,7 @@ class PushOutNonSeqScan(GlobalOptimizer):
...
@@ -263,7 +261,7 @@ class PushOutNonSeqScan(GlobalOptimizer):
to_remove_set
=
set
()
to_remove_set
=
set
()
to_replace_set
=
set
()
to_replace_set
=
set
()
to_replace_map
=
OrderedDict
()
to_replace_map
=
{}
def
add_to_replace
(
y
):
def
add_to_replace
(
y
):
to_replace_set
.
add
(
y
)
to_replace_set
.
add
(
y
)
...
@@ -377,7 +375,7 @@ class PushOutNonSeqScan(GlobalOptimizer):
...
@@ -377,7 +375,7 @@ class PushOutNonSeqScan(GlobalOptimizer):
if
len
(
clean_to_replace
)
>
0
:
if
len
(
clean_to_replace
)
>
0
:
# We can finally put an end to all this madness
# We can finally put an end to all this madness
givens
=
OrderedDict
()
givens
=
{}
nw_outer
=
[]
nw_outer
=
[]
nw_inner
=
[]
nw_inner
=
[]
for
to_repl
,
repl_in
,
repl_out
in
zip
(
for
to_repl
,
repl_in
,
repl_out
in
zip
(
...
@@ -394,7 +392,7 @@ class PushOutNonSeqScan(GlobalOptimizer):
...
@@ -394,7 +392,7 @@ class PushOutNonSeqScan(GlobalOptimizer):
op_ins
=
clean_inputs
+
nw_inner
op_ins
=
clean_inputs
+
nw_inner
# Reconstruct node
# Reconstruct node
nwScan
=
Scan
(
op_ins
,
op_outs
,
op
.
info
)
nwScan
=
Scan
(
op_ins
,
op_outs
,
op
.
info
,
op
.
mode
)
# Do not call make_node for test_value
# Do not call make_node for test_value
nw_node
=
nwScan
(
*
(
node
.
inputs
+
nw_outer
),
**
dict
(
return_list
=
True
))[
nw_node
=
nwScan
(
*
(
node
.
inputs
+
nw_outer
),
**
dict
(
return_list
=
True
))[
...
@@ -409,7 +407,7 @@ class PushOutNonSeqScan(GlobalOptimizer):
...
@@ -409,7 +407,7 @@ class PushOutNonSeqScan(GlobalOptimizer):
return
True
return
True
elif
not
to_keep_set
:
elif
not
to_keep_set
:
# Nothing in the inner graph should be kept
# Nothing in the inner graph should be kept
replace_with
=
OrderedDict
()
replace_with
=
{}
for
out
,
idx
in
to_replace_map
.
items
():
for
out
,
idx
in
to_replace_map
.
items
():
if
out
in
local_fgraph_outs_set
:
if
out
in
local_fgraph_outs_set
:
x
=
node
.
outputs
[
local_fgraph_outs_map
[
out
]]
x
=
node
.
outputs
[
local_fgraph_outs_map
[
out
]]
...
@@ -481,7 +479,7 @@ class PushOutSeqScan(GlobalOptimizer):
...
@@ -481,7 +479,7 @@ class PushOutSeqScan(GlobalOptimizer):
to_remove_set
=
set
()
to_remove_set
=
set
()
to_replace_set
=
set
()
to_replace_set
=
set
()
to_replace_map
=
OrderedDict
()
to_replace_map
=
{}
def
add_to_replace
(
y
):
def
add_to_replace
(
y
):
to_replace_set
.
add
(
y
)
to_replace_set
.
add
(
y
)
...
@@ -638,7 +636,7 @@ class PushOutSeqScan(GlobalOptimizer):
...
@@ -638,7 +636,7 @@ class PushOutSeqScan(GlobalOptimizer):
if
len
(
clean_to_replace
)
>
0
:
if
len
(
clean_to_replace
)
>
0
:
# We can finally put an end to all this madness
# We can finally put an end to all this madness
givens
=
OrderedDict
()
givens
=
{}
nw_outer
=
[]
nw_outer
=
[]
nw_inner
=
[]
nw_inner
=
[]
for
to_repl
,
repl_in
,
repl_out
in
zip
(
for
to_repl
,
repl_in
,
repl_out
in
zip
(
...
@@ -656,9 +654,10 @@ class PushOutSeqScan(GlobalOptimizer):
...
@@ -656,9 +654,10 @@ class PushOutSeqScan(GlobalOptimizer):
op_ins
=
nw_inner
+
clean_inputs
op_ins
=
nw_inner
+
clean_inputs
# Reconstruct node
# Reconstruct node
nw_info
=
op
.
info
.
copy
()
nw_info
=
dataclasses
.
replace
(
nw_info
[
"n_seqs"
]
+=
len
(
nw_inner
)
op
.
info
,
n_seqs
=
op
.
info
.
n_seqs
+
len
(
nw_inner
)
nwScan
=
Scan
(
op_ins
,
op_outs
,
nw_info
)
)
nwScan
=
Scan
(
op_ins
,
op_outs
,
nw_info
,
op
.
mode
)
# Do not call make_node for test_value
# Do not call make_node for test_value
nw_node
=
nwScan
(
nw_node
=
nwScan
(
*
(
node
.
inputs
[:
1
]
+
nw_outer
+
node
.
inputs
[
1
:]),
*
(
node
.
inputs
[:
1
]
+
nw_outer
+
node
.
inputs
[
1
:]),
...
@@ -673,7 +672,7 @@ class PushOutSeqScan(GlobalOptimizer):
...
@@ -673,7 +672,7 @@ class PushOutSeqScan(GlobalOptimizer):
return
True
return
True
elif
not
to_keep_set
and
not
op
.
as_while
and
not
op
.
outer_mitmot
(
node
):
elif
not
to_keep_set
and
not
op
.
as_while
and
not
op
.
outer_mitmot
(
node
):
# Nothing in the inner graph should be kept
# Nothing in the inner graph should be kept
replace_with
=
OrderedDict
()
replace_with
=
{}
for
out
,
idx
in
to_replace_map
.
items
():
for
out
,
idx
in
to_replace_map
.
items
():
if
out
in
local_fgraph_outs_set
:
if
out
in
local_fgraph_outs_set
:
x
=
node
.
outputs
[
local_fgraph_outs_map
[
out
]]
x
=
node
.
outputs
[
local_fgraph_outs_map
[
out
]]
...
@@ -937,7 +936,10 @@ class PushOutScanOutput(GlobalOptimizer):
...
@@ -937,7 +936,10 @@ class PushOutScanOutput(GlobalOptimizer):
# Create the `Scan` `Op` from the `ScanArgs`
# Create the `Scan` `Op` from the `ScanArgs`
new_scan_op
=
Scan
(
new_scan_op
=
Scan
(
new_scan_args
.
inner_inputs
,
new_scan_args
.
inner_outputs
,
new_scan_args
.
info
new_scan_args
.
inner_inputs
,
new_scan_args
.
inner_outputs
,
new_scan_args
.
info
,
old_scan_node
.
op
.
mode
,
)
)
# Create the Apply node for the scan op
# Create the Apply node for the scan op
...
@@ -1000,13 +1002,6 @@ class ScanInplaceOptimizer(GlobalOptimizer):
...
@@ -1000,13 +1002,6 @@ class ScanInplaceOptimizer(GlobalOptimizer):
op
=
node
.
op
op
=
node
.
op
info
=
copy
.
deepcopy
(
op
.
info
)
if
"destroy_map"
not
in
info
:
info
[
"destroy_map"
]
=
OrderedDict
()
for
out_idx
in
output_indices
:
info
[
"destroy_map"
][
out_idx
]
=
[
out_idx
+
1
+
op
.
info
[
"n_seqs"
]]
# inputs corresponding to sequences and n_steps
# inputs corresponding to sequences and n_steps
ls_begin
=
node
.
inputs
[:
1
+
op
.
n_seqs
]
ls_begin
=
node
.
inputs
[:
1
+
op
.
n_seqs
]
ls
=
op
.
outer_mitmot
(
node
.
inputs
)
ls
=
op
.
outer_mitmot
(
node
.
inputs
)
...
@@ -1048,7 +1043,15 @@ class ScanInplaceOptimizer(GlobalOptimizer):
...
@@ -1048,7 +1043,15 @@ class ScanInplaceOptimizer(GlobalOptimizer):
else
:
else
:
typeConstructor
=
self
.
typeInfer
(
node
)
typeConstructor
=
self
.
typeInfer
(
node
)
new_op
=
Scan
(
op
.
inputs
,
op
.
outputs
,
info
,
typeConstructor
=
typeConstructor
)
new_op
=
Scan
(
op
.
inputs
,
op
.
outputs
,
op
.
info
,
op
.
mode
,
typeConstructor
=
typeConstructor
)
destroy_map
=
op
.
destroy_map
.
copy
()
for
out_idx
in
output_indices
:
destroy_map
[
out_idx
]
=
[
out_idx
+
1
+
op
.
info
.
n_seqs
]
new_op
.
destroy_map
=
destroy_map
# Do not call make_node for test_value
# Do not call make_node for test_value
new_outs
=
new_op
(
*
inputs
,
**
dict
(
return_list
=
True
))
new_outs
=
new_op
(
*
inputs
,
**
dict
(
return_list
=
True
))
...
@@ -1070,7 +1073,7 @@ class ScanInplaceOptimizer(GlobalOptimizer):
...
@@ -1070,7 +1073,7 @@ class ScanInplaceOptimizer(GlobalOptimizer):
scan_nodes
=
[
scan_nodes
=
[
x
x
for
x
in
nodes
for
x
in
nodes
if
(
isinstance
(
x
.
op
,
Scan
)
and
x
.
op
.
info
[
"gpua"
]
==
self
.
gpua_flag
)
if
(
isinstance
(
x
.
op
,
Scan
)
and
x
.
op
.
info
.
gpua
==
self
.
gpua_flag
)
]
]
for
scan_idx
in
range
(
len
(
scan_nodes
)):
for
scan_idx
in
range
(
len
(
scan_nodes
)):
...
@@ -1080,7 +1083,7 @@ class ScanInplaceOptimizer(GlobalOptimizer):
...
@@ -1080,7 +1083,7 @@ class ScanInplaceOptimizer(GlobalOptimizer):
# them.
# them.
original_node
=
scan_nodes
[
scan_idx
]
original_node
=
scan_nodes
[
scan_idx
]
op
=
original_node
.
op
op
=
original_node
.
op
n_outs
=
op
.
info
[
"n_mit_mot"
]
+
op
.
info
[
"n_mit_sot"
]
+
op
.
info
[
"n_sit_sot"
]
n_outs
=
op
.
info
.
n_mit_mot
+
op
.
info
.
n_mit_sot
+
op
.
info
.
n_sit_sot
# Generate a list of outputs on which the node could potentially
# Generate a list of outputs on which the node could potentially
# operate inplace.
# operate inplace.
...
@@ -1171,7 +1174,7 @@ class ScanSaveMem(GlobalOptimizer):
...
@@ -1171,7 +1174,7 @@ class ScanSaveMem(GlobalOptimizer):
# Each access to shape_of is in a try..except block in order to
# Each access to shape_of is in a try..except block in order to
# use a default version when the variable is not in the shape_of
# use a default version when the variable is not in the shape_of
# dictionary.
# dictionary.
shape_of
=
OrderedDict
()
shape_of
=
{}
# 1. Initialization of variables
# 1. Initialization of variables
# Note 1) We do not actually care about outputs representing shared
# Note 1) We do not actually care about outputs representing shared
# variables (those have no intermediate values) so it is safer to
# variables (those have no intermediate values) so it is safer to
...
@@ -1548,7 +1551,7 @@ class ScanSaveMem(GlobalOptimizer):
...
@@ -1548,7 +1551,7 @@ class ScanSaveMem(GlobalOptimizer):
(
inps
,
outs
,
info
,
node_ins
,
compress_map
)
=
compress_outs
(
(
inps
,
outs
,
info
,
node_ins
,
compress_map
)
=
compress_outs
(
op
,
not_required
,
nw_inputs
op
,
not_required
,
nw_inputs
)
)
inv_compress_map
=
OrderedDict
()
inv_compress_map
=
{}
for
k
,
v
in
compress_map
.
items
():
for
k
,
v
in
compress_map
.
items
():
inv_compress_map
[
v
]
=
k
inv_compress_map
[
v
]
=
k
...
@@ -1559,7 +1562,9 @@ class ScanSaveMem(GlobalOptimizer):
...
@@ -1559,7 +1562,9 @@ class ScanSaveMem(GlobalOptimizer):
return
return
# Do not call make_node for test_value
# Do not call make_node for test_value
new_outs
=
Scan
(
inps
,
outs
,
info
)(
*
node_ins
,
**
dict
(
return_list
=
True
))
new_outs
=
Scan
(
inps
,
outs
,
info
,
op
.
mode
)(
*
node_ins
,
**
dict
(
return_list
=
True
)
)
old_new
=
[]
old_new
=
[]
# 3.7 Get replace pairs for those outputs that do not change
# 3.7 Get replace pairs for those outputs that do not change
...
@@ -1688,24 +1693,6 @@ class ScanMerge(GlobalOptimizer):
...
@@ -1688,24 +1693,6 @@ class ScanMerge(GlobalOptimizer):
else
:
else
:
as_while
=
False
as_while
=
False
info
=
OrderedDict
()
info
[
"tap_array"
]
=
[]
info
[
"n_seqs"
]
=
sum
([
nd
.
op
.
n_seqs
for
nd
in
nodes
])
info
[
"n_mit_mot"
]
=
sum
([
nd
.
op
.
n_mit_mot
for
nd
in
nodes
])
info
[
"n_mit_mot_outs"
]
=
sum
([
nd
.
op
.
n_mit_mot_outs
for
nd
in
nodes
])
info
[
"mit_mot_out_slices"
]
=
[]
info
[
"n_mit_sot"
]
=
sum
([
nd
.
op
.
n_mit_sot
for
nd
in
nodes
])
info
[
"n_sit_sot"
]
=
sum
([
nd
.
op
.
n_sit_sot
for
nd
in
nodes
])
info
[
"n_shared_outs"
]
=
sum
([
nd
.
op
.
n_shared_outs
for
nd
in
nodes
])
info
[
"n_nit_sot"
]
=
sum
([
nd
.
op
.
n_nit_sot
for
nd
in
nodes
])
info
[
"truncate_gradient"
]
=
nodes
[
0
]
.
op
.
truncate_gradient
info
[
"name"
]
=
"&"
.
join
([
nd
.
op
.
name
for
nd
in
nodes
])
info
[
"mode"
]
=
nodes
[
0
]
.
op
.
mode
info
[
"gpua"
]
=
False
info
[
"as_while"
]
=
as_while
info
[
"profile"
]
=
nodes
[
0
]
.
op
.
profile
info
[
"allow_gc"
]
=
nodes
[
0
]
.
op
.
allow_gc
# We keep the inner_ins and inner_outs of each original node separated.
# We keep the inner_ins and inner_outs of each original node separated.
# To be able to recombine them in the right order after the clone,
# To be able to recombine them in the right order after the clone,
# we also need to split them by types (seq, mitmot, ...).
# we also need to split them by types (seq, mitmot, ...).
...
@@ -1725,12 +1712,15 @@ class ScanMerge(GlobalOptimizer):
...
@@ -1725,12 +1712,15 @@ class ScanMerge(GlobalOptimizer):
inner_ins
[
idx
]
.
append
(
rename
(
nd
.
op
.
inner_seqs
(
nd
.
op
.
inputs
),
idx
))
inner_ins
[
idx
]
.
append
(
rename
(
nd
.
op
.
inner_seqs
(
nd
.
op
.
inputs
),
idx
))
outer_ins
+=
rename
(
nd
.
op
.
outer_seqs
(
nd
.
inputs
),
idx
)
outer_ins
+=
rename
(
nd
.
op
.
outer_seqs
(
nd
.
inputs
),
idx
)
tap_array
=
()
mit_mot_out_slices
=
()
for
idx
,
nd
in
enumerate
(
nodes
):
for
idx
,
nd
in
enumerate
(
nodes
):
# MitMot
# MitMot
inner_ins
[
idx
]
.
append
(
rename
(
nd
.
op
.
inner_mitmot
(
nd
.
op
.
inputs
),
idx
))
inner_ins
[
idx
]
.
append
(
rename
(
nd
.
op
.
inner_mitmot
(
nd
.
op
.
inputs
),
idx
))
inner_outs
[
idx
]
.
append
(
nd
.
op
.
inner_mitmot_outs
(
nd
.
op
.
outputs
))
inner_outs
[
idx
]
.
append
(
nd
.
op
.
inner_mitmot_outs
(
nd
.
op
.
outputs
))
info
[
"tap_array"
]
+=
nd
.
op
.
mitmot_taps
()
tap_array
+=
nd
.
op
.
mitmot_taps
()
info
[
"mit_mot_out_slices"
]
+=
nd
.
op
.
mitmot_out_taps
()
mit_mot_out_slices
+=
nd
.
op
.
mitmot_out_taps
()
outer_ins
+=
rename
(
nd
.
op
.
outer_mitmot
(
nd
.
inputs
),
idx
)
outer_ins
+=
rename
(
nd
.
op
.
outer_mitmot
(
nd
.
inputs
),
idx
)
outer_outs
+=
nd
.
op
.
outer_mitmot_outs
(
nd
.
outputs
)
outer_outs
+=
nd
.
op
.
outer_mitmot_outs
(
nd
.
outputs
)
...
@@ -1738,14 +1728,14 @@ class ScanMerge(GlobalOptimizer):
...
@@ -1738,14 +1728,14 @@ class ScanMerge(GlobalOptimizer):
# MitSot
# MitSot
inner_ins
[
idx
]
.
append
(
rename
(
nd
.
op
.
inner_mitsot
(
nd
.
op
.
inputs
),
idx
))
inner_ins
[
idx
]
.
append
(
rename
(
nd
.
op
.
inner_mitsot
(
nd
.
op
.
inputs
),
idx
))
inner_outs
[
idx
]
.
append
(
nd
.
op
.
inner_mitsot_outs
(
nd
.
op
.
outputs
))
inner_outs
[
idx
]
.
append
(
nd
.
op
.
inner_mitsot_outs
(
nd
.
op
.
outputs
))
info
[
"tap_array"
]
+=
nd
.
op
.
mitsot_taps
()
tap_array
+=
nd
.
op
.
mitsot_taps
()
outer_ins
+=
rename
(
nd
.
op
.
outer_mitsot
(
nd
.
inputs
),
idx
)
outer_ins
+=
rename
(
nd
.
op
.
outer_mitsot
(
nd
.
inputs
),
idx
)
outer_outs
+=
nd
.
op
.
outer_mitsot_outs
(
nd
.
outputs
)
outer_outs
+=
nd
.
op
.
outer_mitsot_outs
(
nd
.
outputs
)
for
idx
,
nd
in
enumerate
(
nodes
):
for
idx
,
nd
in
enumerate
(
nodes
):
# SitSot
# SitSot
inner_ins
[
idx
]
.
append
(
rename
(
nd
.
op
.
inner_sitsot
(
nd
.
op
.
inputs
),
idx
))
inner_ins
[
idx
]
.
append
(
rename
(
nd
.
op
.
inner_sitsot
(
nd
.
op
.
inputs
),
idx
))
info
[
"tap_array"
]
+=
[[
-
1
]
for
x
in
range
(
nd
.
op
.
n_sit_sot
)]
tap_array
+=
tuple
((
-
1
,)
for
x
in
range
(
nd
.
op
.
n_sit_sot
))
inner_outs
[
idx
]
.
append
(
nd
.
op
.
inner_sitsot_outs
(
nd
.
op
.
outputs
))
inner_outs
[
idx
]
.
append
(
nd
.
op
.
inner_sitsot_outs
(
nd
.
op
.
outputs
))
outer_ins
+=
rename
(
nd
.
op
.
outer_sitsot
(
nd
.
inputs
),
idx
)
outer_ins
+=
rename
(
nd
.
op
.
outer_sitsot
(
nd
.
inputs
),
idx
)
outer_outs
+=
nd
.
op
.
outer_sitsot_outs
(
nd
.
outputs
)
outer_outs
+=
nd
.
op
.
outer_sitsot_outs
(
nd
.
outputs
)
...
@@ -1834,7 +1824,25 @@ class ScanMerge(GlobalOptimizer):
...
@@ -1834,7 +1824,25 @@ class ScanMerge(GlobalOptimizer):
else
:
else
:
new_inner_outs
+=
inner_outs
[
idx
][
gr_idx
]
new_inner_outs
+=
inner_outs
[
idx
][
gr_idx
]
new_op
=
Scan
(
new_inner_ins
,
new_inner_outs
,
info
)
info
=
ScanInfo
(
tap_array
=
tap_array
,
n_seqs
=
sum
([
nd
.
op
.
n_seqs
for
nd
in
nodes
]),
n_mit_mot
=
sum
([
nd
.
op
.
n_mit_mot
for
nd
in
nodes
]),
n_mit_mot_outs
=
sum
([
nd
.
op
.
n_mit_mot_outs
for
nd
in
nodes
]),
mit_mot_out_slices
=
mit_mot_out_slices
,
n_mit_sot
=
sum
([
nd
.
op
.
n_mit_sot
for
nd
in
nodes
]),
n_sit_sot
=
sum
([
nd
.
op
.
n_sit_sot
for
nd
in
nodes
]),
n_shared_outs
=
sum
([
nd
.
op
.
n_shared_outs
for
nd
in
nodes
]),
n_nit_sot
=
sum
([
nd
.
op
.
n_nit_sot
for
nd
in
nodes
]),
truncate_gradient
=
nodes
[
0
]
.
op
.
truncate_gradient
,
name
=
"&"
.
join
([
nd
.
op
.
name
for
nd
in
nodes
]),
gpua
=
False
,
as_while
=
as_while
,
profile
=
nodes
[
0
]
.
op
.
profile
,
allow_gc
=
nodes
[
0
]
.
op
.
allow_gc
,
)
new_op
=
Scan
(
new_inner_ins
,
new_inner_outs
,
info
,
nodes
[
0
]
.
op
.
mode
)
new_outs
=
new_op
(
*
outer_ins
)
new_outs
=
new_op
(
*
outer_ins
)
if
not
isinstance
(
new_outs
,
(
list
,
tuple
)):
if
not
isinstance
(
new_outs
,
(
list
,
tuple
)):
...
@@ -1932,7 +1940,7 @@ def make_equiv(lo, li):
...
@@ -1932,7 +1940,7 @@ def make_equiv(lo, li):
the equivalence of their corresponding outer inputs.
the equivalence of their corresponding outer inputs.
"""
"""
seeno
=
OrderedDict
()
seeno
=
{}
left
=
[]
left
=
[]
right
=
[]
right
=
[]
for
o
,
i
in
zip
(
lo
,
li
):
for
o
,
i
in
zip
(
lo
,
li
):
...
@@ -1956,7 +1964,7 @@ def scan_merge_inouts(fgraph, node):
...
@@ -1956,7 +1964,7 @@ def scan_merge_inouts(fgraph, node):
node
.
inputs
,
node
.
outputs
,
node
.
op
.
inputs
,
node
.
op
.
outputs
,
node
.
op
.
info
node
.
inputs
,
node
.
outputs
,
node
.
op
.
inputs
,
node
.
op
.
outputs
,
node
.
op
.
info
)
)
inp_equiv
=
OrderedDict
()
inp_equiv
=
{}
if
has_duplicates
(
a
.
outer_in_seqs
):
if
has_duplicates
(
a
.
outer_in_seqs
):
new_outer_seqs
=
[]
new_outer_seqs
=
[]
...
@@ -1992,7 +2000,7 @@ def scan_merge_inouts(fgraph, node):
...
@@ -1992,7 +2000,7 @@ def scan_merge_inouts(fgraph, node):
a_inner_outs
=
a
.
inner_outputs
a_inner_outs
=
a
.
inner_outputs
inner_outputs
=
clone_replace
(
a_inner_outs
,
replace
=
inp_equiv
)
inner_outputs
=
clone_replace
(
a_inner_outs
,
replace
=
inp_equiv
)
op
=
Scan
(
inner_inputs
,
inner_outputs
,
info
)
op
=
Scan
(
inner_inputs
,
inner_outputs
,
info
,
node
.
op
.
mode
)
outputs
=
op
(
*
outer_inputs
)
outputs
=
op
(
*
outer_inputs
)
if
not
isinstance
(
outputs
,
(
list
,
tuple
)):
if
not
isinstance
(
outputs
,
(
list
,
tuple
)):
...
@@ -2019,7 +2027,7 @@ def scan_merge_inouts(fgraph, node):
...
@@ -2019,7 +2027,7 @@ def scan_merge_inouts(fgraph, node):
left
+=
_left
left
+=
_left
right
+=
_right
right
+=
_right
if
has_duplicates
(
na
.
outer_in_mit_mot
):
if
has_duplicates
(
na
.
outer_in_mit_mot
):
seen
=
OrderedDict
()
seen
=
{}
for
omm
,
imm
,
_sl
in
zip
(
for
omm
,
imm
,
_sl
in
zip
(
na
.
outer_in_mit_mot
,
na
.
inner_in_mit_mot
,
na
.
mit_mot_in_slices
na
.
outer_in_mit_mot
,
na
.
inner_in_mit_mot
,
na
.
mit_mot_in_slices
):
):
...
@@ -2032,7 +2040,7 @@ def scan_merge_inouts(fgraph, node):
...
@@ -2032,7 +2040,7 @@ def scan_merge_inouts(fgraph, node):
seen
[(
omm
,
sl
)]
=
imm
seen
[(
omm
,
sl
)]
=
imm
if
has_duplicates
(
na
.
outer_in_mit_sot
):
if
has_duplicates
(
na
.
outer_in_mit_sot
):
seen
=
OrderedDict
()
seen
=
{}
for
oms
,
ims
,
_sl
in
zip
(
for
oms
,
ims
,
_sl
in
zip
(
na
.
outer_in_mit_sot
,
na
.
inner_in_mit_sot
,
na
.
mit_sot_in_slices
na
.
outer_in_mit_sot
,
na
.
inner_in_mit_sot
,
na
.
mit_sot_in_slices
):
):
...
@@ -2117,9 +2125,7 @@ def scan_merge_inouts(fgraph, node):
...
@@ -2117,9 +2125,7 @@ def scan_merge_inouts(fgraph, node):
new_outer_out_mit_mot
.
append
(
outer_omm
)
new_outer_out_mit_mot
.
append
(
outer_omm
)
na
.
outer_out_mit_mot
=
new_outer_out_mit_mot
na
.
outer_out_mit_mot
=
new_outer_out_mit_mot
if
remove
:
if
remove
:
return
OrderedDict
(
return
dict
([(
"remove"
,
remove
)]
+
list
(
zip
(
node
.
outputs
,
na
.
outer_outputs
)))
[(
"remove"
,
remove
)]
+
list
(
zip
(
node
.
outputs
,
na
.
outer_outputs
))
)
return
na
.
outer_outputs
return
na
.
outer_outputs
...
@@ -2214,15 +2220,17 @@ class PushOutDot1(GlobalOptimizer):
...
@@ -2214,15 +2220,17 @@ class PushOutDot1(GlobalOptimizer):
inner_non_seqs
=
op
.
inner_non_seqs
(
op
.
inputs
)
inner_non_seqs
=
op
.
inner_non_seqs
(
op
.
inputs
)
outer_non_seqs
=
op
.
outer_non_seqs
(
node
)
outer_non_seqs
=
op
.
outer_non_seqs
(
node
)
new_info
=
op
.
info
.
copy
()
st
=
len
(
op
.
mitmot_taps
())
+
len
(
op
.
mitsot_taps
())
st
=
len
(
op
.
mitmot_taps
())
+
len
(
op
.
mitsot_taps
())
new_info
[
"tap_array"
]
=
(
new_info
=
dataclasses
.
replace
(
new_info
[
"tap_array"
][:
st
+
idx
]
op
.
info
,
+
new_info
[
"tap_array"
][
st
+
idx
+
1
:]
tap_array
=
(
op
.
info
.
tap_array
[:
st
+
idx
]
+
op
.
info
.
tap_array
[
st
+
idx
+
1
:]
),
n_sit_sot
=
op
.
info
.
n_sit_sot
-
1
,
n_nit_sot
=
op
.
info
.
n_nit_sot
+
1
,
)
)
new_info
[
"n_sit_sot"
]
-=
1
new_info
[
"n_nit_sot"
]
+=
1
inner_sitsot
=
inner_sitsot
[:
idx
]
+
inner_sitsot
[
idx
+
1
:]
inner_sitsot
=
inner_sitsot
[:
idx
]
+
inner_sitsot
[
idx
+
1
:]
outer_sitsot
=
outer_sitsot
[:
idx
]
+
outer_sitsot
[
idx
+
1
:]
outer_sitsot
=
outer_sitsot
[:
idx
]
+
outer_sitsot
[
idx
+
1
:]
inner_sitsot_outs
=
(
inner_sitsot_outs
=
(
...
@@ -2249,7 +2257,7 @@ class PushOutDot1(GlobalOptimizer):
...
@@ -2249,7 +2257,7 @@ class PushOutDot1(GlobalOptimizer):
new_inner_inps
,
new_inner_outs
=
reconstruct_graph
(
new_inner_inps
,
new_inner_outs
=
reconstruct_graph
(
_new_inner_inps
,
_new_inner_outs
_new_inner_inps
,
_new_inner_outs
)
)
new_op
=
Scan
(
new_inner_inps
,
new_inner_outs
,
new_info
)
new_op
=
Scan
(
new_inner_inps
,
new_inner_outs
,
new_info
,
op
.
mode
)
_scan_inputs
=
(
_scan_inputs
=
(
[
node
.
inputs
[
0
]]
[
node
.
inputs
[
0
]]
+
outer_seqs
+
outer_seqs
...
...
aesara/scan/utils.py
浏览文件 @
e85c7fd0
"""This module provides utility functions for the `Scan` `Op`."""
"""This module provides utility functions for the `Scan` `Op`."""
import
copy
import
copy
import
dataclasses
import
logging
import
logging
import
warnings
import
warnings
from
collections
import
OrderedDict
,
namedtuple
from
collections
import
OrderedDict
,
namedtuple
...
@@ -157,21 +158,6 @@ def traverse(out, x, x_copy, d, visited=None):
...
@@ -157,21 +158,6 @@ def traverse(out, x, x_copy, d, visited=None):
return
d
return
d
# Hashing a dictionary/list/tuple by xoring the hash of each element
def
hash_listsDictsTuples
(
x
):
hash_value
=
0
if
isinstance
(
x
,
dict
):
for
k
,
v
in
x
.
items
():
hash_value
^=
hash_listsDictsTuples
(
k
)
hash_value
^=
hash_listsDictsTuples
(
v
)
elif
isinstance
(
x
,
(
list
,
tuple
)):
for
v
in
x
:
hash_value
^=
hash_listsDictsTuples
(
v
)
else
:
hash_value
^=
hash
(
x
)
return
hash_value
def
map_variables
(
replacer
,
graphs
,
additional_inputs
=
None
):
def
map_variables
(
replacer
,
graphs
,
additional_inputs
=
None
):
"""Construct new graphs based on 'graphs' with some variables replaced
"""Construct new graphs based on 'graphs' with some variables replaced
according to 'replacer'.
according to 'replacer'.
...
@@ -264,6 +250,7 @@ def map_variables(replacer, graphs, additional_inputs=None):
...
@@ -264,6 +250,7 @@ def map_variables(replacer, graphs, additional_inputs=None):
new_inner_inputs
,
new_inner_inputs
,
new_inner_outputs
,
new_inner_outputs
,
node
.
op
.
info
,
node
.
op
.
info
,
node
.
op
.
mode
,
# FIXME: infer this someday?
# FIXME: infer this someday?
typeConstructor
=
None
,
typeConstructor
=
None
,
)
)
...
@@ -669,7 +656,7 @@ def scan_can_remove_outs(op, out_idxs):
...
@@ -669,7 +656,7 @@ def scan_can_remove_outs(op, out_idxs):
offset
=
op
.
n_seqs
offset
=
op
.
n_seqs
lim
=
op
.
n_mit_mot
+
op
.
n_mit_sot
+
op
.
n_sit_sot
lim
=
op
.
n_mit_mot
+
op
.
n_mit_sot
+
op
.
n_sit_sot
for
idx
in
range
(
lim
):
for
idx
in
range
(
lim
):
n_ins
=
len
(
op
.
info
[
"tap_array"
]
[
idx
])
n_ins
=
len
(
op
.
info
.
tap_array
[
idx
])
out_ins
+=
[
op
.
inputs
[
offset
:
offset
+
n_ins
]]
out_ins
+=
[
op
.
inputs
[
offset
:
offset
+
n_ins
]]
offset
+=
n_ins
offset
+=
n_ins
out_ins
+=
[[]
for
k
in
range
(
op
.
n_nit_sot
)]
out_ins
+=
[[]
for
k
in
range
(
op
.
n_nit_sot
)]
...
@@ -702,23 +689,25 @@ def compress_outs(op, not_required, inputs):
...
@@ -702,23 +689,25 @@ def compress_outs(op, not_required, inputs):
node inputs, and changing the dictionary.
node inputs, and changing the dictionary.
"""
"""
info
=
OrderedDict
()
from
aesara.scan.op
import
ScanInfo
info
[
"tap_array"
]
=
[]
info
[
"n_seqs"
]
=
op
.
info
[
"n_seqs"
]
info
=
ScanInfo
(
info
[
"n_mit_mot"
]
=
0
tap_array
=
(),
info
[
"n_mit_mot_outs"
]
=
0
n_seqs
=
op
.
info
.
n_seqs
,
info
[
"mit_mot_out_slices"
]
=
[]
n_mit_mot
=
0
,
info
[
"n_mit_sot"
]
=
0
n_mit_mot_outs
=
0
,
info
[
"n_sit_sot"
]
=
0
mit_mot_out_slices
=
(),
info
[
"n_shared_outs"
]
=
0
n_mit_sot
=
0
,
info
[
"n_nit_sot"
]
=
0
n_sit_sot
=
0
,
info
[
"truncate_gradient"
]
=
op
.
info
[
"truncate_gradient"
]
n_shared_outs
=
0
,
info
[
"name"
]
=
op
.
info
[
"name"
]
n_nit_sot
=
0
,
info
[
"gpua"
]
=
op
.
info
[
"gpua"
]
truncate_gradient
=
op
.
info
.
truncate_gradient
,
info
[
"mode"
]
=
op
.
info
[
"mode"
]
name
=
op
.
info
.
name
,
info
[
"as_while"
]
=
op
.
info
[
"as_while"
]
gpua
=
op
.
info
.
gpua
,
info
[
"profile"
]
=
op
.
info
[
"profile"
]
as_while
=
op
.
info
.
as_while
,
info
[
"allow_gc"
]
=
op
.
info
[
"allow_gc"
]
profile
=
op
.
info
.
profile
,
allow_gc
=
op
.
info
.
allow_gc
,
)
op_inputs
=
op
.
inputs
[:
op
.
n_seqs
]
op_inputs
=
op
.
inputs
[:
op
.
n_seqs
]
op_outputs
=
[]
op_outputs
=
[]
...
@@ -730,13 +719,17 @@ def compress_outs(op, not_required, inputs):
...
@@ -730,13 +719,17 @@ def compress_outs(op, not_required, inputs):
i_offset
=
op
.
n_seqs
i_offset
=
op
.
n_seqs
o_offset
=
0
o_offset
=
0
curr_pos
=
0
curr_pos
=
0
for
idx
in
range
(
op
.
info
[
"n_mit_mot"
]
):
for
idx
in
range
(
op
.
info
.
n_mit_mot
):
if
offset
+
idx
not
in
not_required
:
if
offset
+
idx
not
in
not_required
:
map_old_new
[
offset
+
idx
]
=
curr_pos
map_old_new
[
offset
+
idx
]
=
curr_pos
curr_pos
+=
1
curr_pos
+=
1
info
[
"n_mit_mot"
]
+=
1
info
=
dataclasses
.
replace
(
info
[
"tap_array"
]
+=
[
op
.
tap_array
[
offset
+
idx
]]
info
,
info
[
"mit_mot_out_slices"
]
+=
[
op
.
mit_mot_out_slices
[
offset
+
idx
]]
n_mit_mot
=
info
.
n_mit_mot
+
1
,
tap_array
=
info
.
tap_array
+
(
tuple
(
op
.
tap_array
[
offset
+
idx
]),),
mit_mot_out_slices
=
info
.
mit_mot_out_slices
+
(
tuple
(
op
.
mit_mot_out_slices
[
offset
+
idx
]),),
)
# input taps
# input taps
for
jdx
in
op
.
tap_array
[
offset
+
idx
]:
for
jdx
in
op
.
tap_array
[
offset
+
idx
]:
op_inputs
+=
[
op
.
inputs
[
i_offset
]]
op_inputs
+=
[
op
.
inputs
[
i_offset
]]
...
@@ -750,16 +743,19 @@ def compress_outs(op, not_required, inputs):
...
@@ -750,16 +743,19 @@ def compress_outs(op, not_required, inputs):
else
:
else
:
o_offset
+=
len
(
op
.
mit_mot_out_slices
[
offset
+
idx
])
o_offset
+=
len
(
op
.
mit_mot_out_slices
[
offset
+
idx
])
i_offset
+=
len
(
op
.
tap_array
[
offset
+
idx
])
i_offset
+=
len
(
op
.
tap_array
[
offset
+
idx
])
info
[
"n_mit_mot_outs"
]
=
len
(
op_outputs
)
info
=
dataclasses
.
replace
(
info
,
n_mit_mot_outs
=
len
(
op_outputs
)
)
offset
+=
op
.
n_mit_mot
offset
+=
op
.
n_mit_mot
ni_offset
+=
op
.
n_mit_mot
ni_offset
+=
op
.
n_mit_mot
for
idx
in
range
(
op
.
info
[
"n_mit_sot"
]
):
for
idx
in
range
(
op
.
info
.
n_mit_sot
):
if
offset
+
idx
not
in
not_required
:
if
offset
+
idx
not
in
not_required
:
map_old_new
[
offset
+
idx
]
=
curr_pos
map_old_new
[
offset
+
idx
]
=
curr_pos
curr_pos
+=
1
curr_pos
+=
1
info
[
"n_mit_sot"
]
+=
1
info
=
dataclasses
.
replace
(
info
[
"tap_array"
]
+=
[
op
.
tap_array
[
offset
+
idx
]]
info
,
n_mit_sot
=
info
.
n_mit_sot
+
1
,
tap_array
=
info
.
tap_array
+
(
tuple
(
op
.
tap_array
[
offset
+
idx
]),),
)
# input taps
# input taps
for
jdx
in
op
.
tap_array
[
offset
+
idx
]:
for
jdx
in
op
.
tap_array
[
offset
+
idx
]:
op_inputs
+=
[
op
.
inputs
[
i_offset
]]
op_inputs
+=
[
op
.
inputs
[
i_offset
]]
...
@@ -775,12 +771,15 @@ def compress_outs(op, not_required, inputs):
...
@@ -775,12 +771,15 @@ def compress_outs(op, not_required, inputs):
offset
+=
op
.
n_mit_sot
offset
+=
op
.
n_mit_sot
ni_offset
+=
op
.
n_mit_sot
ni_offset
+=
op
.
n_mit_sot
for
idx
in
range
(
op
.
info
[
"n_sit_sot"
]
):
for
idx
in
range
(
op
.
info
.
n_sit_sot
):
if
offset
+
idx
not
in
not_required
:
if
offset
+
idx
not
in
not_required
:
map_old_new
[
offset
+
idx
]
=
curr_pos
map_old_new
[
offset
+
idx
]
=
curr_pos
curr_pos
+=
1
curr_pos
+=
1
info
[
"n_sit_sot"
]
+=
1
info
=
dataclasses
.
replace
(
info
[
"tap_array"
]
+=
[
op
.
tap_array
[
offset
+
idx
]]
info
,
n_sit_sot
=
info
.
n_sit_sot
+
1
,
tap_array
=
info
.
tap_array
+
(
tuple
(
op
.
tap_array
[
offset
+
idx
]),),
)
# input taps
# input taps
op_inputs
+=
[
op
.
inputs
[
i_offset
]]
op_inputs
+=
[
op
.
inputs
[
i_offset
]]
i_offset
+=
1
i_offset
+=
1
...
@@ -796,11 +795,11 @@ def compress_outs(op, not_required, inputs):
...
@@ -796,11 +795,11 @@ def compress_outs(op, not_required, inputs):
offset
+=
op
.
n_sit_sot
offset
+=
op
.
n_sit_sot
ni_offset
+=
op
.
n_sit_sot
ni_offset
+=
op
.
n_sit_sot
nit_sot_ins
=
[]
nit_sot_ins
=
[]
for
idx
in
range
(
op
.
info
[
"n_nit_sot"
]
):
for
idx
in
range
(
op
.
info
.
n_nit_sot
):
if
offset
+
idx
not
in
not_required
:
if
offset
+
idx
not
in
not_required
:
map_old_new
[
offset
+
idx
]
=
curr_pos
map_old_new
[
offset
+
idx
]
=
curr_pos
curr_pos
+=
1
curr_pos
+=
1
info
[
"n_nit_sot"
]
+=
1
info
=
dataclasses
.
replace
(
info
,
n_nit_sot
=
info
.
n_nit_sot
+
1
)
op_outputs
+=
[
op
.
outputs
[
o_offset
]]
op_outputs
+=
[
op
.
outputs
[
o_offset
]]
o_offset
+=
1
o_offset
+=
1
nit_sot_ins
+=
[
inputs
[
ni_offset
+
idx
+
op
.
n_shared_outs
]]
nit_sot_ins
+=
[
inputs
[
ni_offset
+
idx
+
op
.
n_shared_outs
]]
...
@@ -809,11 +808,11 @@ def compress_outs(op, not_required, inputs):
...
@@ -809,11 +808,11 @@ def compress_outs(op, not_required, inputs):
offset
+=
op
.
n_nit_sot
offset
+=
op
.
n_nit_sot
shared_ins
=
[]
shared_ins
=
[]
for
idx
in
range
(
op
.
info
[
"n_shared_outs"
]
):
for
idx
in
range
(
op
.
info
.
n_shared_outs
):
if
offset
+
idx
not
in
not_required
:
if
offset
+
idx
not
in
not_required
:
map_old_new
[
offset
+
idx
]
=
curr_pos
map_old_new
[
offset
+
idx
]
=
curr_pos
curr_pos
+=
1
curr_pos
+=
1
info
[
"n_shared_outs"
]
+=
1
info
=
dataclasses
.
replace
(
info
,
n_shared_outs
=
info
.
n_shared_outs
+
1
)
op_outputs
+=
[
op
.
outputs
[
o_offset
]]
op_outputs
+=
[
op
.
outputs
[
o_offset
]]
o_offset
+=
1
o_offset
+=
1
op_inputs
+=
[
op
.
inputs
[
i_offset
]]
op_inputs
+=
[
op
.
inputs
[
i_offset
]]
...
@@ -896,7 +895,7 @@ class ScanArgs:
...
@@ -896,7 +895,7 @@ class ScanArgs:
else
:
else
:
rval
=
(
_inner_inputs
,
_inner_outputs
)
rval
=
(
_inner_inputs
,
_inner_outputs
)
if
info
[
"as_while"
]
:
if
info
.
as_while
:
self
.
cond
=
[
rval
[
1
][
-
1
]]
self
.
cond
=
[
rval
[
1
][
-
1
]]
inner_outputs
=
rval
[
1
][:
-
1
]
inner_outputs
=
rval
[
1
][:
-
1
]
else
:
else
:
...
@@ -907,17 +906,17 @@ class ScanArgs:
...
@@ -907,17 +906,17 @@ class ScanArgs:
p
=
1
p
=
1
q
=
0
q
=
0
n_seqs
=
info
[
"n_seqs"
]
n_seqs
=
info
.
n_seqs
self
.
outer_in_seqs
=
outer_inputs
[
p
:
p
+
n_seqs
]
self
.
outer_in_seqs
=
outer_inputs
[
p
:
p
+
n_seqs
]
self
.
inner_in_seqs
=
inner_inputs
[
q
:
q
+
n_seqs
]
self
.
inner_in_seqs
=
inner_inputs
[
q
:
q
+
n_seqs
]
p
+=
n_seqs
p
+=
n_seqs
q
+=
n_seqs
q
+=
n_seqs
n_mit_mot
=
info
[
"n_mit_mot"
]
n_mit_mot
=
info
.
n_mit_mot
n_mit_sot
=
info
[
"n_mit_sot"
]
n_mit_sot
=
info
.
n_mit_sot
self
.
mit_mot_in_slices
=
info
[
"tap_array"
]
[:
n_mit_mot
]
self
.
mit_mot_in_slices
=
info
.
tap_array
[:
n_mit_mot
]
self
.
mit_sot_in_slices
=
info
[
"tap_array"
]
[
n_mit_mot
:
n_mit_mot
+
n_mit_sot
]
self
.
mit_sot_in_slices
=
info
.
tap_array
[
n_mit_mot
:
n_mit_mot
+
n_mit_sot
]
n_mit_mot_ins
=
sum
(
len
(
s
)
for
s
in
self
.
mit_mot_in_slices
)
n_mit_mot_ins
=
sum
(
len
(
s
)
for
s
in
self
.
mit_mot_in_slices
)
n_mit_sot_ins
=
sum
(
len
(
s
)
for
s
in
self
.
mit_sot_in_slices
)
n_mit_sot_ins
=
sum
(
len
(
s
)
for
s
in
self
.
mit_sot_in_slices
)
...
@@ -943,19 +942,19 @@ class ScanArgs:
...
@@ -943,19 +942,19 @@ class ScanArgs:
self
.
outer_in_mit_sot
=
outer_inputs
[
p
:
p
+
n_mit_sot
]
self
.
outer_in_mit_sot
=
outer_inputs
[
p
:
p
+
n_mit_sot
]
p
+=
n_mit_sot
p
+=
n_mit_sot
n_sit_sot
=
info
[
"n_sit_sot"
]
n_sit_sot
=
info
.
n_sit_sot
self
.
outer_in_sit_sot
=
outer_inputs
[
p
:
p
+
n_sit_sot
]
self
.
outer_in_sit_sot
=
outer_inputs
[
p
:
p
+
n_sit_sot
]
self
.
inner_in_sit_sot
=
inner_inputs
[
q
:
q
+
n_sit_sot
]
self
.
inner_in_sit_sot
=
inner_inputs
[
q
:
q
+
n_sit_sot
]
p
+=
n_sit_sot
p
+=
n_sit_sot
q
+=
n_sit_sot
q
+=
n_sit_sot
n_shared_outs
=
info
[
"n_shared_outs"
]
n_shared_outs
=
info
.
n_shared_outs
self
.
outer_in_shared
=
outer_inputs
[
p
:
p
+
n_shared_outs
]
self
.
outer_in_shared
=
outer_inputs
[
p
:
p
+
n_shared_outs
]
self
.
inner_in_shared
=
inner_inputs
[
q
:
q
+
n_shared_outs
]
self
.
inner_in_shared
=
inner_inputs
[
q
:
q
+
n_shared_outs
]
p
+=
n_shared_outs
p
+=
n_shared_outs
q
+=
n_shared_outs
q
+=
n_shared_outs
n_nit_sot
=
info
[
"n_nit_sot"
]
n_nit_sot
=
info
.
n_nit_sot
self
.
outer_in_nit_sot
=
outer_inputs
[
p
:
p
+
n_nit_sot
]
self
.
outer_in_nit_sot
=
outer_inputs
[
p
:
p
+
n_nit_sot
]
p
+=
n_nit_sot
p
+=
n_nit_sot
...
@@ -966,14 +965,14 @@ class ScanArgs:
...
@@ -966,14 +965,14 @@ class ScanArgs:
p
=
0
p
=
0
q
=
0
q
=
0
self
.
mit_mot_out_slices
=
info
[
"mit_mot_out_slices"
]
self
.
mit_mot_out_slices
=
info
.
mit_mot_out_slices
n_mit_mot_outs
=
info
[
"n_mit_mot_outs"
]
n_mit_mot_outs
=
info
.
n_mit_mot_outs
self
.
outer_out_mit_mot
=
outer_outputs
[
p
:
p
+
n_mit_mot
]
self
.
outer_out_mit_mot
=
outer_outputs
[
p
:
p
+
n_mit_mot
]
iomm
=
inner_outputs
[
q
:
q
+
n_mit_mot_outs
]
iomm
=
inner_outputs
[
q
:
q
+
n_mit_mot_outs
]
self
.
inner_out_mit_mot
=
[]
self
.
inner_out_mit_mot
=
()
qq
=
0
qq
=
0
for
sl
in
self
.
mit_mot_out_slices
:
for
sl
in
self
.
mit_mot_out_slices
:
self
.
inner_out_mit_mot
.
append
(
iomm
[
qq
:
qq
+
len
(
sl
)]
)
self
.
inner_out_mit_mot
+=
(
iomm
[
qq
:
qq
+
len
(
sl
)],
)
qq
+=
len
(
sl
)
qq
+=
len
(
sl
)
p
+=
n_mit_mot
p
+=
n_mit_mot
q
+=
n_mit_mot_outs
q
+=
n_mit_mot_outs
...
@@ -1001,19 +1000,17 @@ class ScanArgs:
...
@@ -1001,19 +1000,17 @@ class ScanArgs:
assert
p
==
len
(
outer_outputs
)
assert
p
==
len
(
outer_outputs
)
assert
q
==
len
(
inner_outputs
)
assert
q
==
len
(
inner_outputs
)
self
.
other_info
=
OrderedDict
()
self
.
other_info
=
{
k
:
getattr
(
info
,
k
)
for
k
in
(
for
k
in
(
"truncate_gradient"
,
"truncate_gradient"
,
"name"
,
"name"
,
"mode"
,
"destroy_map"
,
"gpua"
,
"gpua"
,
"as_while"
,
"as_while"
,
"profile"
,
"profile"
,
"allow_gc"
,
"allow_gc"
,
):
)
if
k
in
info
:
}
self
.
other_info
[
k
]
=
info
[
k
]
@staticmethod
@staticmethod
def
from_node
(
node
,
clone
=
False
):
def
from_node
(
node
,
clone
=
False
):
...
@@ -1032,26 +1029,24 @@ class ScanArgs:
...
@@ -1032,26 +1029,24 @@ class ScanArgs:
@classmethod
@classmethod
def
create_empty
(
cls
):
def
create_empty
(
cls
):
info
=
OrderedDict
(
from
aesara.scan.op
import
ScanInfo
[
(
"n_seqs"
,
0
),
info
=
ScanInfo
(
(
"n_mit_mot"
,
0
),
n_seqs
=
0
,
(
"n_mit_sot"
,
0
),
n_mit_mot
=
0
,
(
"tap_array"
,
[]),
n_mit_sot
=
0
,
(
"n_sit_sot"
,
0
),
tap_array
=
(),
(
"n_nit_sot"
,
0
),
n_sit_sot
=
0
,
(
"n_shared_outs"
,
0
),
n_nit_sot
=
0
,
(
"n_mit_mot_outs"
,
0
),
n_shared_outs
=
0
,
(
"mit_mot_out_slices"
,
[]),
n_mit_mot_outs
=
0
,
(
"truncate_gradient"
,
-
1
),
mit_mot_out_slices
=
(),
(
"name"
,
None
),
truncate_gradient
=-
1
,
(
"mode"
,
None
),
name
=
None
,
(
"destroy_map"
,
OrderedDict
()),
gpua
=
False
,
(
"gpua"
,
False
),
as_while
=
False
,
(
"as_while"
,
False
),
profile
=
False
,
(
"profile"
,
False
),
allow_gc
=
False
,
(
"allow_gc"
,
False
),
]
)
)
res
=
cls
([
1
],
[],
[],
[],
info
)
res
=
cls
([
1
],
[],
[],
[],
info
)
res
.
n_steps
=
None
res
.
n_steps
=
None
...
@@ -1060,7 +1055,7 @@ class ScanArgs:
...
@@ -1060,7 +1055,7 @@ class ScanArgs:
@property
@property
def
n_nit_sot
(
self
):
def
n_nit_sot
(
self
):
# This is just a hack that allows us to use `Scan.get_oinp_iinp_iout_oout_mappings`
# This is just a hack that allows us to use `Scan.get_oinp_iinp_iout_oout_mappings`
return
self
.
info
[
"n_nit_sot"
]
return
self
.
info
.
n_nit_sot
@property
@property
def
inputs
(
self
):
def
inputs
(
self
):
...
@@ -1070,7 +1065,7 @@ class ScanArgs:
...
@@ -1070,7 +1065,7 @@ class ScanArgs:
@property
@property
def
n_mit_mot
(
self
):
def
n_mit_mot
(
self
):
# This is just a hack that allows us to use `Scan.get_oinp_iinp_iout_oout_mappings`
# This is just a hack that allows us to use `Scan.get_oinp_iinp_iout_oout_mappings`
return
self
.
info
[
"n_mit_mot"
]
return
self
.
info
.
n_mit_mot
@property
@property
def
var_mappings
(
self
):
def
var_mappings
(
self
):
...
@@ -1141,20 +1136,22 @@ class ScanArgs:
...
@@ -1141,20 +1136,22 @@ class ScanArgs:
@property
@property
def
info
(
self
):
def
info
(
self
):
return
OrderedDict
(
from
aesara.scan.op
import
ScanInfo
return
ScanInfo
(
n_seqs
=
len
(
self
.
outer_in_seqs
),
n_seqs
=
len
(
self
.
outer_in_seqs
),
n_mit_mot
=
len
(
self
.
outer_in_mit_mot
),
n_mit_mot
=
len
(
self
.
outer_in_mit_mot
),
n_mit_sot
=
len
(
self
.
outer_in_mit_sot
),
n_mit_sot
=
len
(
self
.
outer_in_mit_sot
),
tap_array
=
(
tap_array
=
(
self
.
mit_mot_in_slices
tuple
(
tuple
(
v
)
for
v
in
self
.
mit_mot_in_slices
)
+
self
.
mit_sot_in_slices
+
tuple
(
tuple
(
v
)
for
v
in
self
.
mit_sot_in_slices
)
+
[[
-
1
]]
*
len
(
self
.
inner_in_sit_sot
)
+
((
-
1
,),)
*
len
(
self
.
inner_in_sit_sot
)
),
),
n_sit_sot
=
len
(
self
.
outer_in_sit_sot
),
n_sit_sot
=
len
(
self
.
outer_in_sit_sot
),
n_nit_sot
=
len
(
self
.
outer_in_nit_sot
),
n_nit_sot
=
len
(
self
.
outer_in_nit_sot
),
n_shared_outs
=
len
(
self
.
outer_in_shared
),
n_shared_outs
=
len
(
self
.
outer_in_shared
),
n_mit_mot_outs
=
sum
(
len
(
s
)
for
s
in
self
.
mit_mot_out_slices
),
n_mit_mot_outs
=
sum
(
len
(
s
)
for
s
in
self
.
mit_mot_out_slices
),
mit_mot_out_slices
=
self
.
mit_mot_out_slices
,
mit_mot_out_slices
=
tuple
(
self
.
mit_mot_out_slices
)
,
**
self
.
other_info
,
**
self
.
other_info
,
)
)
...
...
requirements.txt
浏览文件 @
e85c7fd0
-e ./
-e ./
dataclasses
>=0.7; python_version < '3.7'
filelock
filelock
flake8
==3.8.4
flake8
==3.8.4
pep8
pep8
...
...
setup.py
浏览文件 @
e85c7fd0
#!/usr/bin/env python
#!/usr/bin/env python
import
sys
from
setuptools
import
find_packages
,
setup
from
setuptools
import
find_packages
,
setup
import
versioneer
import
versioneer
...
@@ -43,6 +45,11 @@ Programming Language :: Python :: 3.9
...
@@ -43,6 +45,11 @@ Programming Language :: Python :: 3.9
"""
"""
CLASSIFIERS
=
[
_f
for
_f
in
CLASSIFIERS
.
split
(
"
\n
"
)
if
_f
]
CLASSIFIERS
=
[
_f
for
_f
in
CLASSIFIERS
.
split
(
"
\n
"
)
if
_f
]
install_requires
=
[
"numpy>=1.17.0"
,
"scipy>=0.14"
,
"filelock"
]
if
sys
.
version_info
[
0
:
2
]
<
(
3
,
7
):
install_requires
+=
[
"dataclasses"
]
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
setup
(
setup
(
name
=
NAME
,
name
=
NAME
,
...
@@ -57,7 +64,7 @@ if __name__ == "__main__":
...
@@ -57,7 +64,7 @@ if __name__ == "__main__":
license
=
LICENSE
,
license
=
LICENSE
,
platforms
=
PLATFORMS
,
platforms
=
PLATFORMS
,
packages
=
find_packages
(
exclude
=
[
"tests"
,
"tests.*"
]),
packages
=
find_packages
(
exclude
=
[
"tests"
,
"tests.*"
]),
install_requires
=
[
"numpy>=1.17.0"
,
"scipy>=0.14"
,
"filelock"
]
,
install_requires
=
install_requires
,
package_data
=
{
package_data
=
{
""
:
[
""
:
[
"*.txt"
,
"*.txt"
,
...
...
tests/scan/test_utils.py
浏览文件 @
e85c7fd0
...
@@ -252,10 +252,7 @@ def test_ScanArgs():
...
@@ -252,10 +252,7 @@ def test_ScanArgs():
# The `scan_args` base class always clones the inner-graph;
# The `scan_args` base class always clones the inner-graph;
# here we make sure it doesn't (and that all the inputs are the same)
# here we make sure it doesn't (and that all the inputs are the same)
assert
scan_args
.
inputs
==
scan_op
.
inputs
assert
scan_args
.
inputs
==
scan_op
.
inputs
scan_op_info
=
dict
(
scan_op
.
info
)
assert
scan_args
.
info
==
scan_op
.
info
# The `ScanInfo` dictionary has the wrong order and an extra entry
del
scan_op_info
[
"strict"
]
assert
dict
(
scan_args
.
info
)
==
scan_op_info
assert
scan_args
.
var_mappings
==
scan_op
.
var_mappings
assert
scan_args
.
var_mappings
==
scan_op
.
var_mappings
# Check that `ScanArgs.find_among_fields` works
# Check that `ScanArgs.find_among_fields` works
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论