Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
c8dc3dbe
提交
c8dc3dbe
authored
10月 21, 2015
作者:
Frédéric Bastien
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #3524 from carriepl/scan_inplace_opt
Scan inplace opt
上级
8f64c64f
87d55476
显示空白字符变更
内嵌
并排
正在显示
4 个修改的文件
包含
119 行增加
和
95 行删除
+119
-95
__init__.py
theano/sandbox/cuda/__init__.py
+1
-1
__init__.py
theano/sandbox/gpuarray/__init__.py
+3
-2
scan_opt.py
theano/scan_module/scan_opt.py
+115
-91
test_flake8.py
theano/tests/test_flake8.py
+0
-1
没有找到文件。
theano/sandbox/cuda/__init__.py
浏览文件 @
c8dc3dbe
...
...
@@ -316,7 +316,7 @@ if cuda_available:
GpuDimShuffle
,
GpuCAReduce
,
GpuReshape
,
GpuContiguous
,
GpuSubtensor
,
GpuIncSubtensor
,
GpuAdvancedSubtensor1
,
GpuAdvancedIncSubtensor1
,
GpuFlatten
,
GpuShape
,
GpuAlloc
,
GpuSplit
,
GpuFlatten
,
GpuShape
,
GpuAlloc
,
Gpu
AllocEmpty
,
Gpu
Split
,
GpuJoin
,
fscalar
,
fvector
,
fmatrix
,
frow
,
fcol
,
ftensor3
,
ftensor4
,
scalar
,
vector
,
matrix
,
row
,
col
,
...
...
theano/sandbox/gpuarray/__init__.py
浏览文件 @
c8dc3dbe
...
...
@@ -68,8 +68,9 @@ if pygpu:
theano
.
compile
.
shared_constructor
(
gpuarray_shared_constructor
)
optdb
.
add_tags
(
'gpuarray_opt'
,
'fast_run'
,
'fast_compile'
)
from
.basic_ops
import
(
GpuAlloc
,
GpuContiguous
,
GpuEye
,
GpuFromHost
,
GpuJoin
,
GpuReshape
,
GpuSplit
,
HostFromGpu
)
from
.basic_ops
import
(
GpuAlloc
,
GpuAllocEmpty
,
GpuContiguous
,
GpuEye
,
GpuFromHost
,
GpuJoin
,
GpuReshape
,
GpuSplit
,
HostFromGpu
)
from
.basic_ops
import
host_from_gpu
,
GpuFromHost
from
.elemwise
import
GpuElemwise
from
.subtensor
import
(
GpuSubtensor
,
GpuIncSubtensor
,
...
...
theano/scan_module/scan_opt.py
浏览文件 @
c8dc3dbe
...
...
@@ -48,16 +48,7 @@ scan_eqopt2 -> They are all global optimizer. (in2out convert local to global).
in2out(scan_merge_inouts),
ScanSaveMem,
in2out(remove_constants_and_unused_inputs_scan3)
"""
__docformat__
=
'restructedtext en'
__authors__
=
(
"Razvan Pascanu "
"Frederic Bastien "
"James Bergstra "
"Pascal Lamblin "
"Arnaud Bergeron "
)
__copyright__
=
"(c) 2010, Universite de Montreal"
__contact__
=
"Razvan Pascanu <r.pascanu@gmail>"
import
logging
import
copy
...
...
@@ -66,21 +57,29 @@ import numpy
import
theano
from
theano
import
tensor
from
theano.tensor
import
opt
,
get_scalar_constant_value
from
theano.tensor
import
opt
,
get_scalar_constant_value
,
Alloc
,
AllocEmpty
from
theano
import
gof
from
theano.compat
import
OrderedDict
from
six
import
integer_types
,
iteritems
from
six.moves
import
xrange
from
theano.gof.opt
import
Optimizer
from
theano.gof.opt
import
pre_constant_merge
,
pre_greedy_local_optimizer
from
theano.gof
import
toolbox
,
DestroyHandler
,
InconsistencyError
from
theano.compile
import
optdb
from
theano.compile.function_module
import
deep_copy_op
from
theano.gof
import
toolbox
,
DestroyHandler
,
InconsistencyError
from
theano.gof.opt
import
Optimizer
from
theano.gof.opt
import
pre_constant_merge
,
pre_greedy_local_optimizer
from
theano.scan_module
import
scan_op
from
theano.scan_module
import
scan_utils
from
theano.scan_module.scan_utils
import
equal_computations
,
find_up
,
\
scan_args
from
theano.scan_module.scan_utils
import
equal_computations
,
find_up
,
scan_args
__docformat__
=
'restructedtext en'
__authors__
=
(
"Razvan Pascanu "
"Frederic Bastien "
"James Bergstra "
"Pascal Lamblin "
"Arnaud Bergeron "
)
__copyright__
=
"(c) 2010, Universite de Montreal"
__contact__
=
"Razvan Pascanu <r.pascanu@gmail>"
# Logging function for sending warning or info
...
...
@@ -243,18 +242,17 @@ class PushOutNonSeqScan(gof.Optimizer):
local_fgraph_topo
=
theano
.
gof
.
graph
.
io_toposort
(
clean_inputs
,
clean_outputs
)
local_fgraph_outs_set
=
set
(
clean_outputs
)
local_fgraph_outs_map
=
dict
([(
v
,
k
)
for
k
,
v
in
\
local_fgraph_outs_map
=
dict
([(
v
,
k
)
for
k
,
v
in
enumerate
(
clean_outputs
)])
to_remove_set
=
set
()
to_replace_set
=
set
()
to_replace_map
=
OrderedDict
()
nto_replace
=
0
def
add_to_replace
(
y
):
to_replace_set
.
add
(
y
)
to_replace_map
[
y
]
=
add_to_replace
.
n
add_to_replace
.
n
+=
1
add_to_replace
.
n
+=
1
add_to_replace
.
n
=
0
replace_with_in
=
[]
...
...
@@ -264,7 +262,8 @@ class PushOutNonSeqScan(gof.Optimizer):
# Construct the list of non_sequences to simplify a few things
inner_non_seqs
=
op
.
inner_non_seqs
(
clean_inputs
)
inner_non_seqs_set
=
set
(
inner_non_seqs
)
inner_non_seqs_map
=
dict
([(
v
,
k
)
for
k
,
v
in
enumerate
(
inner_non_seqs
)])
inner_non_seqs_map
=
dict
([(
v
,
k
)
for
k
,
v
in
enumerate
(
inner_non_seqs
)])
outer_non_seqs
=
op
.
outer_non_seqs
(
node
.
inputs
)
...
...
@@ -275,7 +274,7 @@ class PushOutNonSeqScan(gof.Optimizer):
assert
len
(
inner_seqs
)
==
len
(
outer_seqs
)
for
nd
in
local_fgraph_topo
:
if
(
# we haven't already looked at this node
if
(
# we haven't already looked at this node
nd
not
in
to_remove_set
and
all
([((
x
in
inner_non_seqs_set
)
or
(
x
.
owner
in
to_remove_set
)
or
...
...
@@ -337,7 +336,7 @@ class PushOutNonSeqScan(gof.Optimizer):
to_keep_set
.
update
(
nd
.
inputs
)
for
out
,
idx
in
to_replace_map
.
items
():
if
(
# If types are different, conversion Op will be inserted,
if
(
# If types are different, conversion Op will be inserted,
# and it may trigger an infinite loop.
replace_with_in
[
idx
]
.
type
==
out
.
type
and
out
in
to_keep_set
and
...
...
@@ -450,13 +449,12 @@ class PushOutSeqScan(gof.Optimizer):
local_fgraph_topo
=
theano
.
gof
.
graph
.
io_toposort
(
clean_inputs
,
clean_outputs
)
local_fgraph_outs_set
=
set
(
clean_outputs
)
local_fgraph_outs_map
=
dict
([(
v
,
k
)
for
k
,
v
in
\
local_fgraph_outs_map
=
dict
([(
v
,
k
)
for
k
,
v
in
enumerate
(
clean_outputs
)])
to_remove_set
=
set
()
to_replace_set
=
set
()
to_replace_map
=
OrderedDict
()
nto_replace
=
0
def
add_to_replace
(
y
):
to_replace_set
.
add
(
y
)
...
...
@@ -471,12 +469,14 @@ class PushOutSeqScan(gof.Optimizer):
# Construct the list of non_sequences to simplify a few things
inner_non_seqs
=
op
.
inner_non_seqs
(
clean_inputs
)
inner_non_seqs_set
=
set
(
inner_non_seqs
)
inner_non_seqs_map
=
dict
([(
v
,
k
)
for
k
,
v
in
enumerate
(
inner_non_seqs
)])
inner_non_seqs_map
=
dict
([(
v
,
k
)
for
k
,
v
in
enumerate
(
inner_non_seqs
)])
outer_non_seqs
=
op
.
outer_non_seqs
(
node
.
inputs
)
inner_seqs
=
op
.
inner_seqs
(
clean_inputs
)
inner_seqs_set
=
set
(
inner_seqs
)
inner_seqs_map
=
dict
([(
v
,
k
)
for
k
,
v
in
enumerate
(
inner_seqs
)])
inner_seqs_map
=
dict
([(
v
,
k
)
for
k
,
v
in
enumerate
(
inner_seqs
)])
outer_seqs
=
op
.
outer_seqs
(
node
.
inputs
)
assert
len
(
inner_non_seqs
)
==
len
(
outer_non_seqs
)
...
...
@@ -582,11 +582,10 @@ class PushOutSeqScan(gof.Optimizer):
to_keep_set
.
update
(
nd
.
inputs
)
for
out
,
idx
in
to_replace_map
.
items
():
if
(
out
in
to_keep_set
and
out
.
owner
not
in
existent_nodes_set
if
(
out
in
to_keep_set
and
out
.
owner
not
in
existent_nodes_set
and
# If types are different, conversion Op will be inserted,
# and it may trigger an infinite loop.
and
replace_with_in
[
idx
]
.
type
==
out
.
type
):
replace_with_in
[
idx
]
.
type
==
out
.
type
):
clean_to_replace
.
append
(
out
)
clean_replace_with_in
.
append
(
replace_with_in
[
idx
])
...
...
@@ -682,7 +681,7 @@ class PushOutScanOutput(gof.Optimizer):
not
x
.
op
.
as_while
)]
for
node
in
nodelist
:
# Process the node as long as something gets optimized
while
node
!=
None
:
while
node
is
not
None
:
node
=
self
.
process_node
(
fgraph
,
node
)
def
process_node
(
self
,
fgraph
,
node
):
...
...
@@ -778,9 +777,8 @@ class PushOutScanOutput(gof.Optimizer):
outer_dot_output
=
theano
.
tensor
.
dot
(
*
outer_dot_inputs
)
# Modify the outer graph to add the outer Dot
fgraph
.
replace_all
([
(
new_scan_args
.
outer_out_nit_sot
[
dot_out_nitsot_idx
],
fgraph
.
replace_all
(
[(
new_scan_args
.
outer_out_nit_sot
[
dot_out_nitsot_idx
],
outer_dot_output
)],
reason
=
"scanOp_pushout_output"
)
...
...
@@ -807,8 +805,9 @@ class PushOutScanOutput(gof.Optimizer):
sitsot_in_idx
=
nd
.
inputs
.
index
(
args
.
inner_in_sit_sot
[
sitsot_idx
])
dot_in_idx
=
1
-
sitsot_in_idx
# 0 if sitsot_in_idx==1,
# 1 if sitsot_in_idx==0
# 0 if sitsot_in_idx==1, 1 if sitsot_in_idx==0
dot_in_idx
=
1
-
sitsot_in_idx
dot_input
=
nd
.
inputs
[
dot_in_idx
]
if
(
dot_input
.
owner
is
not
None
and
...
...
@@ -816,10 +815,8 @@ class PushOutScanOutput(gof.Optimizer):
len
(
dot_input
.
clients
)
==
1
and
dot_input
.
owner
.
inputs
[
0
]
.
ndim
==
2
and
dot_input
.
owner
.
inputs
[
1
]
.
ndim
==
2
and
self
.
get_outer_ndim
(
dot_input
.
owner
.
inputs
[
0
],
args
)
\
==
3
and
self
.
get_outer_ndim
(
dot_input
.
owner
.
inputs
[
1
],
args
)
\
==
3
):
self
.
get_outer_ndim
(
dot_input
.
owner
.
inputs
[
0
],
args
)
==
3
and
self
.
get_outer_ndim
(
dot_input
.
owner
.
inputs
[
1
],
args
)
==
3
):
# The optimization can be be applied in this case.
...
...
@@ -829,8 +826,7 @@ class PushOutScanOutput(gof.Optimizer):
(
outer_dot_inputs
,
new_scan_node
,
new_scan_args
)
=
\
self
.
push_out_inner_vars
(
fgraph
,
inner_dot_inputs
,
self
.
push_out_inner_vars
(
fgraph
,
inner_dot_inputs
,
node
,
args
)
# Collapse some of the dimensions of the tensors
...
...
@@ -838,8 +834,7 @@ class PushOutScanOutput(gof.Optimizer):
# dot is usually faster on two large matrices than
# a bunch of small ones
outer_dot_inputs
[
0
]
=
theano
.
tensor
.
flatten
(
outer_dot_inputs
[
0
]
.
dimshuffle
(
1
,
0
,
2
),
outdim
=
2
)
outer_dot_inputs
[
0
]
.
dimshuffle
(
1
,
0
,
2
),
outdim
=
2
)
shape_input1
=
theano
.
tensor
.
shape
(
outer_dot_inputs
[
1
])
outer_dot_inputs
[
1
]
=
\
...
...
@@ -850,15 +845,13 @@ class PushOutScanOutput(gof.Optimizer):
# Perform the dot on the newly obtained matrices and
# add the initial value
outer_dot_output
=
theano
.
tensor
.
dot
(
*
outer_dot_inputs
)
init_value
=
\
new_scan_args
.
outer_in_sit_sot
[
sitsot_idx
][
0
]
init_value
=
new_scan_args
.
outer_in_sit_sot
[
sitsot_idx
][
0
]
replacement
=
outer_dot_output
+
init_value
# Alter the outer graph to use the output of the
# external Dot instead of the output of scan
# Modify the outer graph to add the outer Dot
outer_sitsot
=
\
new_scan_args
.
outer_out_sit_sot
[
sitsot_idx
]
outer_sitsot
=
new_scan_args
.
outer_out_sit_sot
[
sitsot_idx
]
subtensor_node
=
outer_sitsot
.
clients
[
0
][
0
]
outer_sitsot_last_step
=
subtensor_node
.
outputs
[
0
]
...
...
@@ -883,8 +876,8 @@ class PushOutScanOutput(gof.Optimizer):
if
len
(
outer_var
.
clients
)
==
1
:
client
=
outer_var
.
clients
[
0
][
0
]
if
(
client
!=
'output'
and
isinstance
(
client
.
op
,
theano
.
tensor
.
Subtensor
)):
if
(
client
!=
'output'
and
isinstance
(
client
.
op
,
theano
.
tensor
.
Subtensor
)):
lst
=
theano
.
tensor
.
subtensor
.
get_idx_list
(
client
.
inputs
,
client
.
op
.
idx_list
)
if
(
len
(
lst
)
==
1
and
...
...
@@ -991,7 +984,7 @@ class PushOutScanOutput(gof.Optimizer):
new_node_old_outputs
=
(
new_scan_node
.
outputs
[:
new_node_new_outputs_idx
]
+
new_scan_node
.
outputs
[
new_node_new_outputs_idx
+
nb_new_outs
:])
new_scan_node
.
outputs
[
new_node_new_outputs_idx
+
nb_new_outs
:])
fgraph
.
replace_all_validate_remove
(
list
(
zip
(
old_scan_node
.
outputs
,
new_node_old_outputs
)),
...
...
@@ -1017,7 +1010,7 @@ class ScanInplaceOptimizer(Optimizer):
fgraph
.
attach_feature
(
toolbox
.
ReplaceValidate
())
fgraph
.
attach_feature
(
DestroyHandler
())
def
attempt_scan_inplace
(
self
,
fgraph
,
node
,
output_indices
):
def
attempt_scan_inplace
(
self
,
fgraph
,
node
,
output_indices
,
alloc_ops
):
"""Attempts to replace a Scan node by one which computes the specified
outputs inplace.
...
...
@@ -1029,6 +1022,10 @@ class ScanInplaceOptimizer(Optimizer):
Scan node to replace by an inplace version
output_indices : list of integers
Indices of the outputs to attempt to compute inplace
alloc_ops : list of Op classes
Classes that represent operation that allocate new memory and
that the optimization should duplicate so it can operate inplace
on them.
"""
op
=
node
.
op
...
...
@@ -1049,6 +1046,14 @@ class ScanInplaceOptimizer(Optimizer):
ls_end
+=
op
.
outer_nitsot
(
node
.
inputs
)
ls_end
+=
op
.
outer_non_seqs
(
node
.
inputs
)
# In `ls`, duplicate any input which has more then one client and is
# the output of an eligible allocation op
for
i
in
range
(
len
(
ls
)):
inp
=
ls
[
i
]
if
(
len
(
inp
.
clients
)
>
1
and
inp
.
owner
and
isinstance
(
inp
.
owner
.
op
,
alloc_ops
)):
ls
[
i
]
=
inp
.
owner
.
op
(
*
inp
.
owner
.
inputs
)
n_outs
=
len
(
ls
)
for
idx
in
xrange
(
n_outs
):
if
ls
[
idx
]
in
ls
[:
idx
]:
...
...
@@ -1079,6 +1084,21 @@ class ScanInplaceOptimizer(Optimizer):
def
apply
(
self
,
fgraph
):
# Depending on the values of gpu_flag and gpua_flag, get the list of
# memory allocation ops that the optimization should be able to handle
alloc_ops
=
(
Alloc
,
AllocEmpty
)
if
self
.
gpu_flag
:
alloc_ops
+=
(
theano
.
sandbox
.
cuda
.
GpuAlloc
,
theano
.
sandbox
.
cuda
.
GpuAllocEmpty
)
if
self
.
gpua_flag
:
# gpuarray might be imported but not its GpuAlloc and
# GpuAllopEmpty ops.
try
:
alloc_ops
+=
(
theano
.
sandbox
.
gpuarray
.
GpuAlloc
,
theano
.
sandbox
.
gpuarray
.
GpuAllocEmpty
)
except
:
pass
nodes
=
fgraph
.
toposort
()[::
-
1
]
scan_nodes
=
[
x
for
x
in
nodes
if
(
isinstance
(
x
.
op
,
scan_op
.
Scan
)
and
...
...
@@ -1101,7 +1121,18 @@ class ScanInplaceOptimizer(Optimizer):
out_indices
=
[]
for
out_idx
in
range
(
n_outs
):
inp_idx
=
1
+
op
.
n_seqs
+
out_idx
inp
=
original_node
.
inputs
[
inp_idx
]
# If the input is from an eligible allocation node, attempt to
# be inplace on it, even if other nodes are modifying it
# inplace.
if
inp
.
owner
and
isinstance
(
inp
.
owner
.
op
,
alloc_ops
):
out_indices
.
append
(
out_idx
)
continue
# If the input is not from an eligible allocation node, only
# attempt to be inplace on it if nothing else is currently
# inplace on it.
input_used_inplace
=
False
for
c
in
original_node
.
inputs
[
inp_idx
]
.
clients
:
client
=
c
[
0
]
...
...
@@ -1122,14 +1153,15 @@ class ScanInplaceOptimizer(Optimizer):
out_indices
.
append
(
out_idx
)
node
=
self
.
attempt_scan_inplace
(
fgraph
,
scan_nodes
[
scan_idx
],
out_indices
)
out_indices
,
alloc_ops
)
if
node
is
original_node
:
# Making the scan compute all plausible recurrent outputs
# inplace has failed. Attempt all plausible recurrent output
# individually.
for
pos
in
out_indices
:
node
=
self
.
attempt_scan_inplace
(
fgraph
,
node
,
[
pos
])
node
=
self
.
attempt_scan_inplace
(
fgraph
,
node
,
[
pos
],
alloc_ops
)
class
ScanSaveMem
(
gof
.
Optimizer
):
...
...
@@ -1242,7 +1274,7 @@ class ScanSaveMem(gof.Optimizer):
for
cl
,
_
in
out
.
clients
:
# 2.1 outputs of the function
#=> output needs all its intermediate values
#
=> output needs all its intermediate values
if
type
(
cl
)
==
str
:
# if the node is actually an output, then
# we need to store the entire thing
...
...
@@ -1250,20 +1282,20 @@ class ScanSaveMem(gof.Optimizer):
slices
[
i
]
=
None
break
# 2.2 non-subtensor nodes
#=> output needs all its intermediate values
#
=> output needs all its intermediate values
elif
not
isinstance
(
cl
.
op
,
tensor
.
Subtensor
):
global_nsteps
=
None
slices
[
i
]
=
None
break
# 2.3 subtensor nodes
#=> output might need to store just a subset of its values
#
=> output might need to store just a subset of its values
else
:
# 2.3.1 extract idx list of subtensor
this_slice
=
tensor
.
get_idx_list
(
cl
.
inputs
,
cl
.
op
.
idx_list
)
if
this_slice
is
None
:
# if unable to extract idx_list
#=> outputs needs all its intermediate values
#
=> outputs needs all its intermediate values
global_nsteps
=
None
slices
[
i
]
=
None
break
...
...
@@ -1406,8 +1438,7 @@ class ScanSaveMem(gof.Optimizer):
# for mitsots and sitsots (because mitmots are not
# currently supported by the mechanism) and only if
# the pre-allocation mechanism is activated.
prealloc_outs
=
\
theano
.
config
.
scan
.
allow_output_prealloc
prealloc_outs
=
theano
.
config
.
scan
.
allow_output_prealloc
first_mitsot_idx
=
node
.
op
.
n_mit_mot
last_sitsot_idx
=
(
node
.
op
.
n_mit_mot
+
...
...
@@ -1433,7 +1464,7 @@ class ScanSaveMem(gof.Optimizer):
# currently.
# pval = pre_greedy_local_optimizer(list_opt_slice,
# pval)
#pval = pre_constant_merge([pval])[0]
#
pval = pre_constant_merge([pval])[0]
# if (isinstance(pval, theano.tensor.TensorConstant)
# and
# pval.dtype.startswith('int')):
...
...
@@ -1554,7 +1585,7 @@ class ScanSaveMem(gof.Optimizer):
nw_steps
)
nw_inputs
[
in_idx
]
=
nw_input
else
:
nw_input
=
nw_inputs
[
in_idx
][:(
initl
+
nw_steps
)]
nw_input
=
nw_inputs
[
in_idx
][:(
initl
+
nw_steps
)]
elif
idx
<
op
.
n_mit_sot
+
op
.
n_sit_sot
+
op
.
n_nit_sot
:
in_idx
=
offset
+
idx
+
op
.
n_shared_outs
...
...
@@ -1640,8 +1671,8 @@ class ScanSaveMem(gof.Optimizer):
stop
=
None
nw_slice
=
((
slice
(
sanitize
(
start
),
sanitize
(
stop
),
sanitize
(
cnf_slice
[
0
]
.
step
)),)
+
tuple
(
old_slices
[
1
:]))
sanitize
(
cnf_slice
[
0
]
.
step
)),)
+
tuple
(
old_slices
[
1
:]))
else
:
position
=
(
cnf_slice
[
0
]
-
nw_steps
-
...
...
@@ -1662,8 +1693,7 @@ class ScanSaveMem(gof.Optimizer):
# 3.9. Get replace pairs for all other nodes
if
flag_store
or
global_nsteps
is
not
None
:
for
idx
,
o
in
enumerate
(
node
.
outputs
):
if
not
(
idx
in
replaced_outs
)
and
\
not
idx
in
not_required
:
if
not
(
idx
in
replaced_outs
)
and
idx
not
in
not_required
:
nw_pos
=
compress_map
[
idx
]
old_new
+=
[(
o
,
new_outs
[
nw_pos
])]
# Check if the new outputs depend on the old scan node
...
...
@@ -2072,8 +2102,8 @@ def scan_merge_inouts(node):
# because they could have different sizes, and the corresponding
# outer outputs cannot be merged in that case.
for
s_outer_i
,
s_inner_o
,
s_outer_o
in
seen
:
if
(
equal_computations
([
inner_o
],
[
s_inner_o
],
left
,
right
)
and
outer_i
==
s_outer_i
):
if
(
equal_computations
([
inner_o
],
[
s_inner_o
],
left
,
right
)
and
outer_i
==
s_outer_i
):
return
s_outer_o
seen
.
append
((
outer_i
,
inner_o
,
outer_o
))
return
outer_o
...
...
@@ -2116,9 +2146,10 @@ def scan_merge_inouts(node):
na
.
outer_out_mit_mot
,
na
.
mit_mot_out_slices
):
for
s_outer_imm
,
s_inner_omm
,
s_outer_omm
,
sosl
in
seen
:
if
(
osl
==
sosl
and
equal_computations
(
inner_omm
,
s_inner_omm
,
left
,
right
)
and
outer_imm
==
s_outer_imm
):
if
(
osl
==
sosl
and
equal_computations
(
inner_omm
,
s_inner_omm
,
left
,
right
)
and
outer_imm
==
s_outer_imm
):
new_outer_out_mit_mot
.
append
(
s_outer_omm
)
break
else
:
...
...
@@ -2168,17 +2199,15 @@ class PushOutDot1(gof.Optimizer):
inp
in
out
.
owner
.
inputs
and
len
(
outer_out
.
clients
)
==
1
and
not
isinstance
(
outer_out
.
clients
[
0
][
0
],
str
)
and
isinstance
(
outer_out
.
clients
[
0
][
0
]
.
op
,
theano
.
tensor
.
Subtensor
)
and
outer_out
.
clients
[
0
][
0
]
.
op
.
idx_list
==
(
-
1
,)):
isinstance
(
outer_out
.
clients
[
0
][
0
]
.
op
,
theano
.
tensor
.
Subtensor
)
and
outer_out
.
clients
[
0
][
0
]
.
op
.
idx_list
==
(
-
1
,)):
x
=
out
.
owner
.
inputs
[
0
]
if
x
==
inp
:
x
=
out
.
owner
.
inputs
[
1
]
# We need to check if x is the result of an outer product
if
(
x
.
owner
and
isinstance
(
x
.
owner
.
op
,
theano
.
tensor
.
Dot
)
and
x
.
owner
.
inputs
[
0
]
.
ndim
==
2
and
x
.
owner
.
inputs
[
1
]
.
ndim
==
2
):
if
(
x
.
owner
and
isinstance
(
x
.
owner
.
op
,
theano
.
tensor
.
Dot
)
and
x
.
owner
.
inputs
[
0
]
.
ndim
==
2
and
x
.
owner
.
inputs
[
1
]
.
ndim
==
2
):
# We need to check if any of the inputs are a sequence
inp1
=
x
.
owner
.
inputs
[
0
]
...
...
@@ -2219,18 +2248,17 @@ class PushOutDot1(gof.Optimizer):
new_info
=
op
.
info
.
copy
()
st
=
len
(
op
.
mitmot_taps
())
+
len
(
op
.
mitsot_taps
())
new_info
[
'tap_array'
]
=
(
\
new_info
[
'tap_array'
]
=
(
new_info
[
'tap_array'
][:
st
+
idx
]
+
new_info
[
'tap_array'
][
st
+
idx
+
1
:])
new_info
[
'tap_array'
][
st
+
idx
+
1
:])
new_info
[
'n_sit_sot'
]
-=
1
new_info
[
'n_nit_sot'
]
+=
1
inner_sitsot
=
inner_sitsot
[:
idx
]
+
\
inner_sitsot
[
idx
+
1
:]
outer_sitsot
=
outer_sitsot
[:
idx
]
+
\
outer_sitsot
[
idx
+
1
:]
inner_sitsot_outs
=
inner_sitsot_outs
[:
idx
]
+
\
inner_sitsot_outs
[
idx
+
1
:]
inner_sitsot
=
(
inner_sitsot
[:
idx
]
+
inner_sitsot
[
idx
+
1
:])
outer_sitsot
=
(
outer_sitsot
[:
idx
]
+
outer_sitsot
[
idx
+
1
:])
inner_sitsot_outs
=
(
inner_sitsot_outs
[:
idx
]
+
inner_sitsot_outs
[
idx
+
1
:])
# add n_steps as the length
inner_nitsot_outs
.
append
(
new_scan_out
)
...
...
@@ -2246,8 +2274,8 @@ class PushOutDot1(gof.Optimizer):
inner_nitsot_outs
+
inner_shared_outs
)
new_inner_inps
,
new_inner_outs
=
\
scan_utils
.
reconstruct_graph
(
_new_inner_inps
,
_new_inner_outs
)
scan_utils
.
reconstruct_graph
(
_new_inner_inps
,
_new_inner_outs
)
new_op
=
scan_op
.
Scan
(
new_inner_inps
,
new_inner_outs
,
new_info
)
_scan_inputs
=
([
node
.
inputs
[
0
]]
+
...
...
@@ -2267,11 +2295,7 @@ class PushOutDot1(gof.Optimizer):
# We need now to pair correctly the new outputs
# with the old ones
outer_mitmot_outs
=
new_op
.
outer_mitmot_outs
(
new_outs
)
outer_mitsot_outs
=
new_op
.
outer_mitsot_outs
(
new_outs
)
outer_sitsot_outs
=
new_op
.
outer_sitsot_outs
(
new_outs
)
outer_nitsot_outs
=
new_op
.
outer_nitsot_outs
(
new_outs
)
outer_shared_outs
=
new_op
.
outer_shared_outs
(
new_outs
)
_val
=
outer_nitsot_outs
[
-
1
]
outer_nitsot_outs
=
outer_nitsot_outs
[:
-
1
]
...
...
@@ -2305,7 +2329,7 @@ class PushOutDot1(gof.Optimizer):
old_new
=
list
(
zip
(
node
.
outputs
[:
pos
],
new_outs
[:
pos
]))
old
=
node
.
outputs
[
pos
]
.
clients
[
0
][
0
]
.
outputs
[
0
]
old_new
.
append
((
old
,
new_out
))
old_new
+=
list
(
zip
(
node
.
outputs
[
pos
+
1
:],
old_new
+=
list
(
zip
(
node
.
outputs
[
pos
+
1
:],
new_outs
[
pos
:]))
fgraph
.
replace_all_validate_remove
(
old_new
,
remove
=
[
node
],
reason
=
'scan_pushout_dot1'
)
...
...
theano/tests/test_flake8.py
浏览文件 @
c8dc3dbe
...
...
@@ -164,7 +164,6 @@ whitelist_flake8 = [
"scan_module/scan_op.py"
,
"scan_module/scan_perform_ext.py"
,
"scan_module/__init__.py"
,
"scan_module/scan_opt.py"
,
"scan_module/tests/test_scan.py"
,
"scan_module/tests/test_scan_opt.py"
,
"misc/tests/test_may_share_memory.py"
,
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论