Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
c8dc3dbe
提交
c8dc3dbe
authored
10月 21, 2015
作者:
Frédéric Bastien
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #3524 from carriepl/scan_inplace_opt
Scan inplace opt
上级
8f64c64f
87d55476
隐藏空白字符变更
内嵌
并排
正在显示
4 个修改的文件
包含
195 行增加
和
171 行删除
+195
-171
__init__.py
theano/sandbox/cuda/__init__.py
+2
-2
__init__.py
theano/sandbox/gpuarray/__init__.py
+3
-2
scan_opt.py
theano/scan_module/scan_opt.py
+190
-166
test_flake8.py
theano/tests/test_flake8.py
+0
-1
没有找到文件。
theano/sandbox/cuda/__init__.py
浏览文件 @
c8dc3dbe
...
@@ -316,7 +316,7 @@ if cuda_available:
...
@@ -316,7 +316,7 @@ if cuda_available:
GpuDimShuffle
,
GpuCAReduce
,
GpuReshape
,
GpuContiguous
,
GpuDimShuffle
,
GpuCAReduce
,
GpuReshape
,
GpuContiguous
,
GpuSubtensor
,
GpuIncSubtensor
,
GpuSubtensor
,
GpuIncSubtensor
,
GpuAdvancedSubtensor1
,
GpuAdvancedIncSubtensor1
,
GpuAdvancedSubtensor1
,
GpuAdvancedIncSubtensor1
,
GpuFlatten
,
GpuShape
,
GpuAlloc
,
GpuSplit
,
GpuFlatten
,
GpuShape
,
GpuAlloc
,
Gpu
AllocEmpty
,
Gpu
Split
,
GpuJoin
,
fscalar
,
fvector
,
fmatrix
,
frow
,
fcol
,
GpuJoin
,
fscalar
,
fvector
,
fmatrix
,
frow
,
fcol
,
ftensor3
,
ftensor4
,
ftensor3
,
ftensor4
,
scalar
,
vector
,
matrix
,
row
,
col
,
scalar
,
vector
,
matrix
,
row
,
col
,
...
@@ -341,7 +341,7 @@ def use(device,
...
@@ -341,7 +341,7 @@ def use(device,
Parameters
Parameters
----------
----------
device : string
device : string
"cpu", "gpu", "gpuN" (N is the device number to use).
"cpu", "gpu", "gpuN" (N is the device number to use).
force
force
Will always raise an exception if we can't use the gpu.
Will always raise an exception if we can't use the gpu.
...
...
theano/sandbox/gpuarray/__init__.py
浏览文件 @
c8dc3dbe
...
@@ -68,8 +68,9 @@ if pygpu:
...
@@ -68,8 +68,9 @@ if pygpu:
theano
.
compile
.
shared_constructor
(
gpuarray_shared_constructor
)
theano
.
compile
.
shared_constructor
(
gpuarray_shared_constructor
)
optdb
.
add_tags
(
'gpuarray_opt'
,
'fast_run'
,
'fast_compile'
)
optdb
.
add_tags
(
'gpuarray_opt'
,
'fast_run'
,
'fast_compile'
)
from
.basic_ops
import
(
GpuAlloc
,
GpuContiguous
,
GpuEye
,
GpuFromHost
,
from
.basic_ops
import
(
GpuAlloc
,
GpuAllocEmpty
,
GpuContiguous
,
GpuEye
,
GpuJoin
,
GpuReshape
,
GpuSplit
,
HostFromGpu
)
GpuFromHost
,
GpuJoin
,
GpuReshape
,
GpuSplit
,
HostFromGpu
)
from
.basic_ops
import
host_from_gpu
,
GpuFromHost
from
.basic_ops
import
host_from_gpu
,
GpuFromHost
from
.elemwise
import
GpuElemwise
from
.elemwise
import
GpuElemwise
from
.subtensor
import
(
GpuSubtensor
,
GpuIncSubtensor
,
from
.subtensor
import
(
GpuSubtensor
,
GpuIncSubtensor
,
...
...
theano/scan_module/scan_opt.py
浏览文件 @
c8dc3dbe
...
@@ -48,16 +48,7 @@ scan_eqopt2 -> They are all global optimizer. (in2out convert local to global).
...
@@ -48,16 +48,7 @@ scan_eqopt2 -> They are all global optimizer. (in2out convert local to global).
in2out(scan_merge_inouts),
in2out(scan_merge_inouts),
ScanSaveMem,
ScanSaveMem,
in2out(remove_constants_and_unused_inputs_scan3)
in2out(remove_constants_and_unused_inputs_scan3)
"""
"""
__docformat__
=
'restructedtext en'
__authors__
=
(
"Razvan Pascanu "
"Frederic Bastien "
"James Bergstra "
"Pascal Lamblin "
"Arnaud Bergeron "
)
__copyright__
=
"(c) 2010, Universite de Montreal"
__contact__
=
"Razvan Pascanu <r.pascanu@gmail>"
import
logging
import
logging
import
copy
import
copy
...
@@ -66,21 +57,29 @@ import numpy
...
@@ -66,21 +57,29 @@ import numpy
import
theano
import
theano
from
theano
import
tensor
from
theano
import
tensor
from
theano.tensor
import
opt
,
get_scalar_constant_value
from
theano.tensor
import
opt
,
get_scalar_constant_value
,
Alloc
,
AllocEmpty
from
theano
import
gof
from
theano
import
gof
from
theano.compat
import
OrderedDict
from
theano.compat
import
OrderedDict
from
six
import
integer_types
,
iteritems
from
six
import
integer_types
,
iteritems
from
six.moves
import
xrange
from
six.moves
import
xrange
from
theano.gof.opt
import
Optimizer
from
theano.gof.opt
import
pre_constant_merge
,
pre_greedy_local_optimizer
from
theano.gof
import
toolbox
,
DestroyHandler
,
InconsistencyError
from
theano.compile
import
optdb
from
theano.compile
import
optdb
from
theano.compile.function_module
import
deep_copy_op
from
theano.compile.function_module
import
deep_copy_op
from
theano.gof
import
toolbox
,
DestroyHandler
,
InconsistencyError
from
theano.gof.opt
import
Optimizer
from
theano.gof.opt
import
pre_constant_merge
,
pre_greedy_local_optimizer
from
theano.scan_module
import
scan_op
from
theano.scan_module
import
scan_op
from
theano.scan_module
import
scan_utils
from
theano.scan_module
import
scan_utils
from
theano.scan_module.scan_utils
import
equal_computations
,
find_up
,
\
from
theano.scan_module.scan_utils
import
equal_computations
,
find_up
,
scan_args
scan_args
__docformat__
=
'restructedtext en'
__authors__
=
(
"Razvan Pascanu "
"Frederic Bastien "
"James Bergstra "
"Pascal Lamblin "
"Arnaud Bergeron "
)
__copyright__
=
"(c) 2010, Universite de Montreal"
__contact__
=
"Razvan Pascanu <r.pascanu@gmail>"
# Logging function for sending warning or info
# Logging function for sending warning or info
...
@@ -151,7 +150,7 @@ def remove_constants_and_unused_inputs_scan(node):
...
@@ -151,7 +150,7 @@ def remove_constants_and_unused_inputs_scan(node):
for
idx
in
xrange
(
op
.
n_seqs
):
for
idx
in
xrange
(
op
.
n_seqs
):
node_inp
=
node
.
inputs
[
idx
+
1
]
node_inp
=
node
.
inputs
[
idx
+
1
]
if
(
isinstance
(
node_inp
,
tensor
.
TensorConstant
)
and
if
(
isinstance
(
node_inp
,
tensor
.
TensorConstant
)
and
node_inp
.
tag
.
unique_value
is
not
None
):
node_inp
.
tag
.
unique_value
is
not
None
):
try
:
try
:
# This works if input is a constant that has all entries
# This works if input is a constant that has all entries
# equal
# equal
...
@@ -243,18 +242,17 @@ class PushOutNonSeqScan(gof.Optimizer):
...
@@ -243,18 +242,17 @@ class PushOutNonSeqScan(gof.Optimizer):
local_fgraph_topo
=
theano
.
gof
.
graph
.
io_toposort
(
clean_inputs
,
local_fgraph_topo
=
theano
.
gof
.
graph
.
io_toposort
(
clean_inputs
,
clean_outputs
)
clean_outputs
)
local_fgraph_outs_set
=
set
(
clean_outputs
)
local_fgraph_outs_set
=
set
(
clean_outputs
)
local_fgraph_outs_map
=
dict
([(
v
,
k
)
for
k
,
v
in
\
local_fgraph_outs_map
=
dict
([(
v
,
k
)
for
k
,
v
in
enumerate
(
clean_outputs
)])
enumerate
(
clean_outputs
)])
to_remove_set
=
set
()
to_remove_set
=
set
()
to_replace_set
=
set
()
to_replace_set
=
set
()
to_replace_map
=
OrderedDict
()
to_replace_map
=
OrderedDict
()
nto_replace
=
0
def
add_to_replace
(
y
):
def
add_to_replace
(
y
):
to_replace_set
.
add
(
y
)
to_replace_set
.
add
(
y
)
to_replace_map
[
y
]
=
add_to_replace
.
n
to_replace_map
[
y
]
=
add_to_replace
.
n
add_to_replace
.
n
+=
1
add_to_replace
.
n
+=
1
add_to_replace
.
n
=
0
add_to_replace
.
n
=
0
replace_with_in
=
[]
replace_with_in
=
[]
...
@@ -264,7 +262,8 @@ class PushOutNonSeqScan(gof.Optimizer):
...
@@ -264,7 +262,8 @@ class PushOutNonSeqScan(gof.Optimizer):
# Construct the list of non_sequences to simplify a few things
# Construct the list of non_sequences to simplify a few things
inner_non_seqs
=
op
.
inner_non_seqs
(
clean_inputs
)
inner_non_seqs
=
op
.
inner_non_seqs
(
clean_inputs
)
inner_non_seqs_set
=
set
(
inner_non_seqs
)
inner_non_seqs_set
=
set
(
inner_non_seqs
)
inner_non_seqs_map
=
dict
([(
v
,
k
)
for
k
,
v
in
enumerate
(
inner_non_seqs
)])
inner_non_seqs_map
=
dict
([(
v
,
k
)
for
k
,
v
in
enumerate
(
inner_non_seqs
)])
outer_non_seqs
=
op
.
outer_non_seqs
(
node
.
inputs
)
outer_non_seqs
=
op
.
outer_non_seqs
(
node
.
inputs
)
...
@@ -275,17 +274,17 @@ class PushOutNonSeqScan(gof.Optimizer):
...
@@ -275,17 +274,17 @@ class PushOutNonSeqScan(gof.Optimizer):
assert
len
(
inner_seqs
)
==
len
(
outer_seqs
)
assert
len
(
inner_seqs
)
==
len
(
outer_seqs
)
for
nd
in
local_fgraph_topo
:
for
nd
in
local_fgraph_topo
:
if
(
# we haven't already looked at this node
if
(
# we haven't already looked at this node
nd
not
in
to_remove_set
and
nd
not
in
to_remove_set
and
all
([((
x
in
inner_non_seqs_set
)
or
all
([((
x
in
inner_non_seqs_set
)
or
(
x
.
owner
in
to_remove_set
)
or
(
x
.
owner
in
to_remove_set
)
or
isinstance
(
x
,
tensor
.
Constant
))
isinstance
(
x
,
tensor
.
Constant
))
for
x
in
nd
.
inputs
])
and
for
x
in
nd
.
inputs
])
and
# we can do this because the assumption is that a
# we can do this because the assumption is that a
# viewOp or deepCopyOp will be just at the end of the
# viewOp or deepCopyOp will be just at the end of the
# function and not somewhere in the middle ..
# function and not somewhere in the middle ..
not
isinstance
(
nd
.
op
,
theano
.
compile
.
ViewOp
)
and
not
isinstance
(
nd
.
op
,
theano
.
compile
.
ViewOp
)
and
not
isinstance
(
nd
.
op
,
theano
.
compile
.
DeepCopyOp
)):
not
isinstance
(
nd
.
op
,
theano
.
compile
.
DeepCopyOp
)):
# We have a candidate node to removable
# We have a candidate node to removable
# Step 1. Reconstruct it on outside
# Step 1. Reconstruct it on outside
...
@@ -337,11 +336,11 @@ class PushOutNonSeqScan(gof.Optimizer):
...
@@ -337,11 +336,11 @@ class PushOutNonSeqScan(gof.Optimizer):
to_keep_set
.
update
(
nd
.
inputs
)
to_keep_set
.
update
(
nd
.
inputs
)
for
out
,
idx
in
to_replace_map
.
items
():
for
out
,
idx
in
to_replace_map
.
items
():
if
(
# If types are different, conversion Op will be inserted,
if
(
# If types are different, conversion Op will be inserted,
# and it may trigger an infinite loop.
# and it may trigger an infinite loop.
replace_with_in
[
idx
]
.
type
==
out
.
type
and
replace_with_in
[
idx
]
.
type
==
out
.
type
and
out
in
to_keep_set
and
out
in
to_keep_set
and
out
.
owner
not
in
existent_nodes_set
):
out
.
owner
not
in
existent_nodes_set
):
clean_to_replace
.
append
(
out
)
clean_to_replace
.
append
(
out
)
clean_replace_with_in
.
append
(
replace_with_in
[
idx
])
clean_replace_with_in
.
append
(
replace_with_in
[
idx
])
clean_replace_with_out
.
append
(
replace_with_out
[
idx
])
clean_replace_with_out
.
append
(
replace_with_out
[
idx
])
...
@@ -450,13 +449,12 @@ class PushOutSeqScan(gof.Optimizer):
...
@@ -450,13 +449,12 @@ class PushOutSeqScan(gof.Optimizer):
local_fgraph_topo
=
theano
.
gof
.
graph
.
io_toposort
(
clean_inputs
,
local_fgraph_topo
=
theano
.
gof
.
graph
.
io_toposort
(
clean_inputs
,
clean_outputs
)
clean_outputs
)
local_fgraph_outs_set
=
set
(
clean_outputs
)
local_fgraph_outs_set
=
set
(
clean_outputs
)
local_fgraph_outs_map
=
dict
([(
v
,
k
)
for
k
,
v
in
\
local_fgraph_outs_map
=
dict
([(
v
,
k
)
for
k
,
v
in
enumerate
(
clean_outputs
)])
enumerate
(
clean_outputs
)])
to_remove_set
=
set
()
to_remove_set
=
set
()
to_replace_set
=
set
()
to_replace_set
=
set
()
to_replace_map
=
OrderedDict
()
to_replace_map
=
OrderedDict
()
nto_replace
=
0
def
add_to_replace
(
y
):
def
add_to_replace
(
y
):
to_replace_set
.
add
(
y
)
to_replace_set
.
add
(
y
)
...
@@ -471,12 +469,14 @@ class PushOutSeqScan(gof.Optimizer):
...
@@ -471,12 +469,14 @@ class PushOutSeqScan(gof.Optimizer):
# Construct the list of non_sequences to simplify a few things
# Construct the list of non_sequences to simplify a few things
inner_non_seqs
=
op
.
inner_non_seqs
(
clean_inputs
)
inner_non_seqs
=
op
.
inner_non_seqs
(
clean_inputs
)
inner_non_seqs_set
=
set
(
inner_non_seqs
)
inner_non_seqs_set
=
set
(
inner_non_seqs
)
inner_non_seqs_map
=
dict
([(
v
,
k
)
for
k
,
v
in
enumerate
(
inner_non_seqs
)])
inner_non_seqs_map
=
dict
([(
v
,
k
)
for
k
,
v
in
enumerate
(
inner_non_seqs
)])
outer_non_seqs
=
op
.
outer_non_seqs
(
node
.
inputs
)
outer_non_seqs
=
op
.
outer_non_seqs
(
node
.
inputs
)
inner_seqs
=
op
.
inner_seqs
(
clean_inputs
)
inner_seqs
=
op
.
inner_seqs
(
clean_inputs
)
inner_seqs_set
=
set
(
inner_seqs
)
inner_seqs_set
=
set
(
inner_seqs
)
inner_seqs_map
=
dict
([(
v
,
k
)
for
k
,
v
in
enumerate
(
inner_seqs
)])
inner_seqs_map
=
dict
([(
v
,
k
)
for
k
,
v
in
enumerate
(
inner_seqs
)])
outer_seqs
=
op
.
outer_seqs
(
node
.
inputs
)
outer_seqs
=
op
.
outer_seqs
(
node
.
inputs
)
assert
len
(
inner_non_seqs
)
==
len
(
outer_non_seqs
)
assert
len
(
inner_non_seqs
)
==
len
(
outer_non_seqs
)
...
@@ -488,7 +488,7 @@ class PushOutSeqScan(gof.Optimizer):
...
@@ -488,7 +488,7 @@ class PushOutSeqScan(gof.Optimizer):
(
x
.
owner
in
to_remove_set
)
or
(
x
.
owner
in
to_remove_set
)
or
isinstance
(
x
,
tensor
.
Constant
)
or
isinstance
(
x
,
tensor
.
Constant
)
or
(
x
in
inner_seqs_set
)
for
x
in
nd
.
inputs
])
and
(
x
in
inner_seqs_set
)
for
x
in
nd
.
inputs
])
and
isinstance
(
nd
.
op
,
theano
.
tensor
.
Elemwise
)):
isinstance
(
nd
.
op
,
theano
.
tensor
.
Elemwise
)):
outside_ins
=
[]
outside_ins
=
[]
depends_on_seqs
=
False
depends_on_seqs
=
False
...
@@ -538,7 +538,7 @@ class PushOutSeqScan(gof.Optimizer):
...
@@ -538,7 +538,7 @@ class PushOutSeqScan(gof.Optimizer):
elif
(
nd
not
in
to_remove_set
and
elif
(
nd
not
in
to_remove_set
and
isinstance
(
nd
.
op
,
theano
.
tensor
.
DimShuffle
)
and
isinstance
(
nd
.
op
,
theano
.
tensor
.
DimShuffle
)
and
(
nd
.
inputs
[
0
]
in
inner_seqs_set
or
(
nd
.
inputs
[
0
]
in
inner_seqs_set
or
nd
.
inputs
[
0
]
.
owner
in
to_remove_set
)):
nd
.
inputs
[
0
]
.
owner
in
to_remove_set
)):
to_remove_set
.
add
(
nd
)
to_remove_set
.
add
(
nd
)
x
=
nd
.
inputs
[
0
]
x
=
nd
.
inputs
[
0
]
...
@@ -582,11 +582,10 @@ class PushOutSeqScan(gof.Optimizer):
...
@@ -582,11 +582,10 @@ class PushOutSeqScan(gof.Optimizer):
to_keep_set
.
update
(
nd
.
inputs
)
to_keep_set
.
update
(
nd
.
inputs
)
for
out
,
idx
in
to_replace_map
.
items
():
for
out
,
idx
in
to_replace_map
.
items
():
if
(
out
in
to_keep_set
if
(
out
in
to_keep_set
and
out
.
owner
not
in
existent_nodes_set
and
and
out
.
owner
not
in
existent_nodes_set
# If types are different, conversion Op will be inserted,
# If types are different, conversion Op will be inserted,
# and it may trigger an infinite loop.
# and it may trigger an infinite loop.
replace_with_in
[
idx
]
.
type
==
out
.
type
):
and
replace_with_in
[
idx
]
.
type
==
out
.
type
):
clean_to_replace
.
append
(
out
)
clean_to_replace
.
append
(
out
)
clean_replace_with_in
.
append
(
replace_with_in
[
idx
])
clean_replace_with_in
.
append
(
replace_with_in
[
idx
])
...
@@ -682,7 +681,7 @@ class PushOutScanOutput(gof.Optimizer):
...
@@ -682,7 +681,7 @@ class PushOutScanOutput(gof.Optimizer):
not
x
.
op
.
as_while
)]
not
x
.
op
.
as_while
)]
for
node
in
nodelist
:
for
node
in
nodelist
:
# Process the node as long as something gets optimized
# Process the node as long as something gets optimized
while
node
!=
None
:
while
node
is
not
None
:
node
=
self
.
process_node
(
fgraph
,
node
)
node
=
self
.
process_node
(
fgraph
,
node
)
def
process_node
(
self
,
fgraph
,
node
):
def
process_node
(
self
,
fgraph
,
node
):
...
@@ -702,7 +701,7 @@ class PushOutScanOutput(gof.Optimizer):
...
@@ -702,7 +701,7 @@ class PushOutScanOutput(gof.Optimizer):
local_fgraph_topo
=
local_fgraph
.
toposort
()
local_fgraph_topo
=
local_fgraph
.
toposort
()
for
nd
in
local_fgraph_topo
:
for
nd
in
local_fgraph_topo
:
if
(
isinstance
(
nd
.
op
,
theano
.
tensor
.
Dot
)
and
if
(
isinstance
(
nd
.
op
,
theano
.
tensor
.
Dot
)
and
nd
.
out
in
args
.
inner_out_nit_sot
):
nd
.
out
in
args
.
inner_out_nit_sot
):
"""
"""
The following optimization involves pushing out, after the
The following optimization involves pushing out, after the
scan, a Dot whose output is nitsot (not feed back to the inner
scan, a Dot whose output is nitsot (not feed back to the inner
...
@@ -737,8 +736,8 @@ class PushOutScanOutput(gof.Optimizer):
...
@@ -737,8 +736,8 @@ class PushOutScanOutput(gof.Optimizer):
(
nd
.
inputs
[
0
]
in
args
.
inner_in_non_seqs
or
(
nd
.
inputs
[
0
]
in
args
.
inner_in_non_seqs
or
isinstance
(
nd
.
inputs
[
0
],
tensor
.
Constant
))
and
isinstance
(
nd
.
inputs
[
0
],
tensor
.
Constant
))
and
nd
.
inputs
[
1
]
.
ndim
==
1
and
nd
.
inputs
[
1
]
.
ndim
==
1
and
(
nd
.
inputs
[
1
]
in
args
.
inner_in_seqs
or
(
nd
.
inputs
[
1
]
in
args
.
inner_in_seqs
or
nd
.
inputs
[
1
]
not
in
args
.
inner_inputs
)):
nd
.
inputs
[
1
]
not
in
args
.
inner_inputs
)):
valid_inputs
=
True
valid_inputs
=
True
idx_matrix_input
=
0
idx_matrix_input
=
0
...
@@ -778,11 +777,10 @@ class PushOutScanOutput(gof.Optimizer):
...
@@ -778,11 +777,10 @@ class PushOutScanOutput(gof.Optimizer):
outer_dot_output
=
theano
.
tensor
.
dot
(
*
outer_dot_inputs
)
outer_dot_output
=
theano
.
tensor
.
dot
(
*
outer_dot_inputs
)
# Modify the outer graph to add the outer Dot
# Modify the outer graph to add the outer Dot
fgraph
.
replace_all
([
fgraph
.
replace_all
(
(
new_scan_args
.
outer_out_nit_sot
[
[(
new_scan_args
.
outer_out_nit_sot
[
dot_out_nitsot_idx
],
dot_out_nitsot_idx
],
outer_dot_output
)],
outer_dot_output
)],
reason
=
"scanOp_pushout_output"
)
reason
=
"scanOp_pushout_output"
)
break
break
...
@@ -807,8 +805,9 @@ class PushOutScanOutput(gof.Optimizer):
...
@@ -807,8 +805,9 @@ class PushOutScanOutput(gof.Optimizer):
sitsot_in_idx
=
nd
.
inputs
.
index
(
args
.
inner_in_sit_sot
[
sitsot_in_idx
=
nd
.
inputs
.
index
(
args
.
inner_in_sit_sot
[
sitsot_idx
])
sitsot_idx
])
dot_in_idx
=
1
-
sitsot_in_idx
# 0 if sitsot_in_idx==1,
# 0 if sitsot_in_idx==1, 1 if sitsot_in_idx==0
# 1 if sitsot_in_idx==0
dot_in_idx
=
1
-
sitsot_in_idx
dot_input
=
nd
.
inputs
[
dot_in_idx
]
dot_input
=
nd
.
inputs
[
dot_in_idx
]
if
(
dot_input
.
owner
is
not
None
and
if
(
dot_input
.
owner
is
not
None
and
...
@@ -816,10 +815,8 @@ class PushOutScanOutput(gof.Optimizer):
...
@@ -816,10 +815,8 @@ class PushOutScanOutput(gof.Optimizer):
len
(
dot_input
.
clients
)
==
1
and
len
(
dot_input
.
clients
)
==
1
and
dot_input
.
owner
.
inputs
[
0
]
.
ndim
==
2
and
dot_input
.
owner
.
inputs
[
0
]
.
ndim
==
2
and
dot_input
.
owner
.
inputs
[
1
]
.
ndim
==
2
and
dot_input
.
owner
.
inputs
[
1
]
.
ndim
==
2
and
self
.
get_outer_ndim
(
dot_input
.
owner
.
inputs
[
0
],
args
)
\
self
.
get_outer_ndim
(
dot_input
.
owner
.
inputs
[
0
],
args
)
==
3
and
==
3
and
self
.
get_outer_ndim
(
dot_input
.
owner
.
inputs
[
1
],
args
)
==
3
):
self
.
get_outer_ndim
(
dot_input
.
owner
.
inputs
[
1
],
args
)
\
==
3
):
# The optimization can be be applied in this case.
# The optimization can be be applied in this case.
...
@@ -829,36 +826,32 @@ class PushOutScanOutput(gof.Optimizer):
...
@@ -829,36 +826,32 @@ class PushOutScanOutput(gof.Optimizer):
(
outer_dot_inputs
,
(
outer_dot_inputs
,
new_scan_node
,
new_scan_node
,
new_scan_args
)
=
\
new_scan_args
)
=
\
self
.
push_out_inner_vars
(
fgraph
,
self
.
push_out_inner_vars
(
fgraph
,
inner_dot_inputs
,
inner_dot_inputs
,
node
,
args
)
node
,
args
)
# Collapse some of the dimensions of the tensors
# Collapse some of the dimensions of the tensors
# so that they become matrices. This is because a
# so that they become matrices. This is because a
# dot is usually faster on two large matrices than
# dot is usually faster on two large matrices than
# a bunch of small ones
# a bunch of small ones
outer_dot_inputs
[
0
]
=
theano
.
tensor
.
flatten
(
outer_dot_inputs
[
0
]
=
theano
.
tensor
.
flatten
(
outer_dot_inputs
[
0
]
.
dimshuffle
(
1
,
0
,
2
),
outer_dot_inputs
[
0
]
.
dimshuffle
(
1
,
0
,
2
),
outdim
=
2
)
outdim
=
2
)
shape_input1
=
theano
.
tensor
.
shape
(
outer_dot_inputs
[
1
])
shape_input1
=
theano
.
tensor
.
shape
(
outer_dot_inputs
[
1
])
outer_dot_inputs
[
1
]
=
\
outer_dot_inputs
[
1
]
=
\
outer_dot_inputs
[
1
]
.
reshape
((
shape_input1
[
0
]
*
outer_dot_inputs
[
1
]
.
reshape
((
shape_input1
[
0
]
*
shape_input1
[
1
],
shape_input1
[
1
],
shape_input1
[
2
]))
shape_input1
[
2
]))
# Perform the dot on the newly obtained matrices and
# Perform the dot on the newly obtained matrices and
# add the initial value
# add the initial value
outer_dot_output
=
theano
.
tensor
.
dot
(
*
outer_dot_inputs
)
outer_dot_output
=
theano
.
tensor
.
dot
(
*
outer_dot_inputs
)
init_value
=
\
init_value
=
new_scan_args
.
outer_in_sit_sot
[
sitsot_idx
][
0
]
new_scan_args
.
outer_in_sit_sot
[
sitsot_idx
][
0
]
replacement
=
outer_dot_output
+
init_value
replacement
=
outer_dot_output
+
init_value
# Alter the outer graph to use the output of the
# Alter the outer graph to use the output of the
# external Dot instead of the output of scan
# external Dot instead of the output of scan
# Modify the outer graph to add the outer Dot
# Modify the outer graph to add the outer Dot
outer_sitsot
=
\
outer_sitsot
=
new_scan_args
.
outer_out_sit_sot
[
sitsot_idx
]
new_scan_args
.
outer_out_sit_sot
[
sitsot_idx
]
subtensor_node
=
outer_sitsot
.
clients
[
0
][
0
]
subtensor_node
=
outer_sitsot
.
clients
[
0
][
0
]
outer_sitsot_last_step
=
subtensor_node
.
outputs
[
0
]
outer_sitsot_last_step
=
subtensor_node
.
outputs
[
0
]
...
@@ -883,12 +876,12 @@ class PushOutScanOutput(gof.Optimizer):
...
@@ -883,12 +876,12 @@ class PushOutScanOutput(gof.Optimizer):
if
len
(
outer_var
.
clients
)
==
1
:
if
len
(
outer_var
.
clients
)
==
1
:
client
=
outer_var
.
clients
[
0
][
0
]
client
=
outer_var
.
clients
[
0
][
0
]
if
(
client
!=
'output'
and
if
(
client
!=
'output'
and
isinstance
(
client
.
op
,
isinstance
(
client
.
op
,
theano
.
tensor
.
Subtensor
)):
theano
.
tensor
.
Subtensor
)):
lst
=
theano
.
tensor
.
subtensor
.
get_idx_list
(
lst
=
theano
.
tensor
.
subtensor
.
get_idx_list
(
client
.
inputs
,
client
.
op
.
idx_list
)
client
.
inputs
,
client
.
op
.
idx_list
)
if
(
len
(
lst
)
==
1
and
if
(
len
(
lst
)
==
1
and
theano
.
tensor
.
extract_constant
(
lst
[
0
])
==
-
1
):
theano
.
tensor
.
extract_constant
(
lst
[
0
])
==
-
1
):
return
True
return
True
return
False
return
False
...
@@ -898,7 +891,7 @@ class PushOutScanOutput(gof.Optimizer):
...
@@ -898,7 +891,7 @@ class PushOutScanOutput(gof.Optimizer):
# Given a variable, determine the number of dimension it would have if
# Given a variable, determine the number of dimension it would have if
# it was pushed out of scan
# it was pushed out of scan
if
(
var
in
scan_args
.
inner_in_non_seqs
or
if
(
var
in
scan_args
.
inner_in_non_seqs
or
isinstance
(
var
,
theano
.
Constant
)):
isinstance
(
var
,
theano
.
Constant
)):
outer_ndim
=
var
.
ndim
outer_ndim
=
var
.
ndim
else
:
else
:
...
@@ -990,8 +983,8 @@ class PushOutScanOutput(gof.Optimizer):
...
@@ -990,8 +983,8 @@ class PushOutScanOutput(gof.Optimizer):
len
(
old_scan_args
.
outer_out_shared
))
len
(
old_scan_args
.
outer_out_shared
))
new_node_old_outputs
=
(
new_node_old_outputs
=
(
new_scan_node
.
outputs
[:
new_node_new_outputs_idx
]
+
new_scan_node
.
outputs
[:
new_node_new_outputs_idx
]
+
new_scan_node
.
outputs
[
new_node_new_outputs_idx
+
nb_new_outs
:])
new_scan_node
.
outputs
[
new_node_new_outputs_idx
+
nb_new_outs
:])
fgraph
.
replace_all_validate_remove
(
fgraph
.
replace_all_validate_remove
(
list
(
zip
(
old_scan_node
.
outputs
,
new_node_old_outputs
)),
list
(
zip
(
old_scan_node
.
outputs
,
new_node_old_outputs
)),
...
@@ -1017,7 +1010,7 @@ class ScanInplaceOptimizer(Optimizer):
...
@@ -1017,7 +1010,7 @@ class ScanInplaceOptimizer(Optimizer):
fgraph
.
attach_feature
(
toolbox
.
ReplaceValidate
())
fgraph
.
attach_feature
(
toolbox
.
ReplaceValidate
())
fgraph
.
attach_feature
(
DestroyHandler
())
fgraph
.
attach_feature
(
DestroyHandler
())
def
attempt_scan_inplace
(
self
,
fgraph
,
node
,
output_indices
):
def
attempt_scan_inplace
(
self
,
fgraph
,
node
,
output_indices
,
alloc_ops
):
"""Attempts to replace a Scan node by one which computes the specified
"""Attempts to replace a Scan node by one which computes the specified
outputs inplace.
outputs inplace.
...
@@ -1029,6 +1022,10 @@ class ScanInplaceOptimizer(Optimizer):
...
@@ -1029,6 +1022,10 @@ class ScanInplaceOptimizer(Optimizer):
Scan node to replace by an inplace version
Scan node to replace by an inplace version
output_indices : list of integers
output_indices : list of integers
Indices of the outputs to attempt to compute inplace
Indices of the outputs to attempt to compute inplace
alloc_ops : list of Op classes
Classes that represent operation that allocate new memory and
that the optimization should duplicate so it can operate inplace
on them.
"""
"""
op
=
node
.
op
op
=
node
.
op
...
@@ -1049,6 +1046,14 @@ class ScanInplaceOptimizer(Optimizer):
...
@@ -1049,6 +1046,14 @@ class ScanInplaceOptimizer(Optimizer):
ls_end
+=
op
.
outer_nitsot
(
node
.
inputs
)
ls_end
+=
op
.
outer_nitsot
(
node
.
inputs
)
ls_end
+=
op
.
outer_non_seqs
(
node
.
inputs
)
ls_end
+=
op
.
outer_non_seqs
(
node
.
inputs
)
# In `ls`, duplicate any input which has more then one client and is
# the output of an eligible allocation op
for
i
in
range
(
len
(
ls
)):
inp
=
ls
[
i
]
if
(
len
(
inp
.
clients
)
>
1
and
inp
.
owner
and
isinstance
(
inp
.
owner
.
op
,
alloc_ops
)):
ls
[
i
]
=
inp
.
owner
.
op
(
*
inp
.
owner
.
inputs
)
n_outs
=
len
(
ls
)
n_outs
=
len
(
ls
)
for
idx
in
xrange
(
n_outs
):
for
idx
in
xrange
(
n_outs
):
if
ls
[
idx
]
in
ls
[:
idx
]:
if
ls
[
idx
]
in
ls
[:
idx
]:
...
@@ -1079,6 +1084,21 @@ class ScanInplaceOptimizer(Optimizer):
...
@@ -1079,6 +1084,21 @@ class ScanInplaceOptimizer(Optimizer):
def
apply
(
self
,
fgraph
):
def
apply
(
self
,
fgraph
):
# Depending on the values of gpu_flag and gpua_flag, get the list of
# memory allocation ops that the optimization should be able to handle
alloc_ops
=
(
Alloc
,
AllocEmpty
)
if
self
.
gpu_flag
:
alloc_ops
+=
(
theano
.
sandbox
.
cuda
.
GpuAlloc
,
theano
.
sandbox
.
cuda
.
GpuAllocEmpty
)
if
self
.
gpua_flag
:
# gpuarray might be imported but not its GpuAlloc and
# GpuAllopEmpty ops.
try
:
alloc_ops
+=
(
theano
.
sandbox
.
gpuarray
.
GpuAlloc
,
theano
.
sandbox
.
gpuarray
.
GpuAllocEmpty
)
except
:
pass
nodes
=
fgraph
.
toposort
()[::
-
1
]
nodes
=
fgraph
.
toposort
()[::
-
1
]
scan_nodes
=
[
x
for
x
in
nodes
scan_nodes
=
[
x
for
x
in
nodes
if
(
isinstance
(
x
.
op
,
scan_op
.
Scan
)
and
if
(
isinstance
(
x
.
op
,
scan_op
.
Scan
)
and
...
@@ -1101,7 +1121,18 @@ class ScanInplaceOptimizer(Optimizer):
...
@@ -1101,7 +1121,18 @@ class ScanInplaceOptimizer(Optimizer):
out_indices
=
[]
out_indices
=
[]
for
out_idx
in
range
(
n_outs
):
for
out_idx
in
range
(
n_outs
):
inp_idx
=
1
+
op
.
n_seqs
+
out_idx
inp_idx
=
1
+
op
.
n_seqs
+
out_idx
inp
=
original_node
.
inputs
[
inp_idx
]
# If the input is from an eligible allocation node, attempt to
# be inplace on it, even if other nodes are modifying it
# inplace.
if
inp
.
owner
and
isinstance
(
inp
.
owner
.
op
,
alloc_ops
):
out_indices
.
append
(
out_idx
)
continue
# If the input is not from an eligible allocation node, only
# attempt to be inplace on it if nothing else is currently
# inplace on it.
input_used_inplace
=
False
input_used_inplace
=
False
for
c
in
original_node
.
inputs
[
inp_idx
]
.
clients
:
for
c
in
original_node
.
inputs
[
inp_idx
]
.
clients
:
client
=
c
[
0
]
client
=
c
[
0
]
...
@@ -1122,14 +1153,15 @@ class ScanInplaceOptimizer(Optimizer):
...
@@ -1122,14 +1153,15 @@ class ScanInplaceOptimizer(Optimizer):
out_indices
.
append
(
out_idx
)
out_indices
.
append
(
out_idx
)
node
=
self
.
attempt_scan_inplace
(
fgraph
,
scan_nodes
[
scan_idx
],
node
=
self
.
attempt_scan_inplace
(
fgraph
,
scan_nodes
[
scan_idx
],
out_indices
)
out_indices
,
alloc_ops
)
if
node
is
original_node
:
if
node
is
original_node
:
# Making the scan compute all plausible recurrent outputs
# Making the scan compute all plausible recurrent outputs
# inplace has failed. Attempt all plausible recurrent output
# inplace has failed. Attempt all plausible recurrent output
# individually.
# individually.
for
pos
in
out_indices
:
for
pos
in
out_indices
:
node
=
self
.
attempt_scan_inplace
(
fgraph
,
node
,
[
pos
])
node
=
self
.
attempt_scan_inplace
(
fgraph
,
node
,
[
pos
],
alloc_ops
)
class
ScanSaveMem
(
gof
.
Optimizer
):
class
ScanSaveMem
(
gof
.
Optimizer
):
...
@@ -1242,7 +1274,7 @@ class ScanSaveMem(gof.Optimizer):
...
@@ -1242,7 +1274,7 @@ class ScanSaveMem(gof.Optimizer):
for
cl
,
_
in
out
.
clients
:
for
cl
,
_
in
out
.
clients
:
# 2.1 outputs of the function
# 2.1 outputs of the function
#=> output needs all its intermediate values
#
=> output needs all its intermediate values
if
type
(
cl
)
==
str
:
if
type
(
cl
)
==
str
:
# if the node is actually an output, then
# if the node is actually an output, then
# we need to store the entire thing
# we need to store the entire thing
...
@@ -1250,20 +1282,20 @@ class ScanSaveMem(gof.Optimizer):
...
@@ -1250,20 +1282,20 @@ class ScanSaveMem(gof.Optimizer):
slices
[
i
]
=
None
slices
[
i
]
=
None
break
break
# 2.2 non-subtensor nodes
# 2.2 non-subtensor nodes
#=> output needs all its intermediate values
#
=> output needs all its intermediate values
elif
not
isinstance
(
cl
.
op
,
tensor
.
Subtensor
):
elif
not
isinstance
(
cl
.
op
,
tensor
.
Subtensor
):
global_nsteps
=
None
global_nsteps
=
None
slices
[
i
]
=
None
slices
[
i
]
=
None
break
break
# 2.3 subtensor nodes
# 2.3 subtensor nodes
#=> output might need to store just a subset of its values
#
=> output might need to store just a subset of its values
else
:
else
:
# 2.3.1 extract idx list of subtensor
# 2.3.1 extract idx list of subtensor
this_slice
=
tensor
.
get_idx_list
(
cl
.
inputs
,
this_slice
=
tensor
.
get_idx_list
(
cl
.
inputs
,
cl
.
op
.
idx_list
)
cl
.
op
.
idx_list
)
if
this_slice
is
None
:
if
this_slice
is
None
:
# if unable to extract idx_list
# if unable to extract idx_list
#=> outputs needs all its intermediate values
#
=> outputs needs all its intermediate values
global_nsteps
=
None
global_nsteps
=
None
slices
[
i
]
=
None
slices
[
i
]
=
None
break
break
...
@@ -1284,7 +1316,7 @@ class ScanSaveMem(gof.Optimizer):
...
@@ -1284,7 +1316,7 @@ class ScanSaveMem(gof.Optimizer):
slices
[
i
]
+=
[(
cf_slice
,
this_slice
)]
slices
[
i
]
+=
[(
cf_slice
,
this_slice
)]
if
(
isinstance
(
this_slice
[
0
],
slice
)
and
if
(
isinstance
(
this_slice
[
0
],
slice
)
and
this_slice
[
0
]
.
stop
is
None
):
this_slice
[
0
]
.
stop
is
None
):
global_nsteps
=
None
global_nsteps
=
None
if
isinstance
(
cf_slice
[
0
],
slice
):
if
isinstance
(
cf_slice
[
0
],
slice
):
stop
=
tensor
.
basic
.
extract_constant
(
cf_slice
[
0
]
.
stop
)
stop
=
tensor
.
basic
.
extract_constant
(
cf_slice
[
0
]
.
stop
)
...
@@ -1374,7 +1406,7 @@ class ScanSaveMem(gof.Optimizer):
...
@@ -1374,7 +1406,7 @@ class ScanSaveMem(gof.Optimizer):
break
break
if
(
isinstance
(
this_slice
[
0
],
slice
)
and
if
(
isinstance
(
this_slice
[
0
],
slice
)
and
this_slice
[
0
]
.
start
is
None
):
this_slice
[
0
]
.
start
is
None
):
store_steps
[
i
]
=
0
store_steps
[
i
]
=
0
break
break
...
@@ -1406,8 +1438,7 @@ class ScanSaveMem(gof.Optimizer):
...
@@ -1406,8 +1438,7 @@ class ScanSaveMem(gof.Optimizer):
# for mitsots and sitsots (because mitmots are not
# for mitsots and sitsots (because mitmots are not
# currently supported by the mechanism) and only if
# currently supported by the mechanism) and only if
# the pre-allocation mechanism is activated.
# the pre-allocation mechanism is activated.
prealloc_outs
=
\
prealloc_outs
=
theano
.
config
.
scan
.
allow_output_prealloc
theano
.
config
.
scan
.
allow_output_prealloc
first_mitsot_idx
=
node
.
op
.
n_mit_mot
first_mitsot_idx
=
node
.
op
.
n_mit_mot
last_sitsot_idx
=
(
node
.
op
.
n_mit_mot
+
last_sitsot_idx
=
(
node
.
op
.
n_mit_mot
+
...
@@ -1433,7 +1464,7 @@ class ScanSaveMem(gof.Optimizer):
...
@@ -1433,7 +1464,7 @@ class ScanSaveMem(gof.Optimizer):
# currently.
# currently.
# pval = pre_greedy_local_optimizer(list_opt_slice,
# pval = pre_greedy_local_optimizer(list_opt_slice,
# pval)
# pval)
#pval = pre_constant_merge([pval])[0]
#
pval = pre_constant_merge([pval])[0]
# if (isinstance(pval, theano.tensor.TensorConstant)
# if (isinstance(pval, theano.tensor.TensorConstant)
# and
# and
# pval.dtype.startswith('int')):
# pval.dtype.startswith('int')):
...
@@ -1554,7 +1585,7 @@ class ScanSaveMem(gof.Optimizer):
...
@@ -1554,7 +1585,7 @@ class ScanSaveMem(gof.Optimizer):
nw_steps
)
nw_steps
)
nw_inputs
[
in_idx
]
=
nw_input
nw_inputs
[
in_idx
]
=
nw_input
else
:
else
:
nw_input
=
nw_inputs
[
in_idx
][:(
initl
+
nw_steps
)]
nw_input
=
nw_inputs
[
in_idx
][:(
initl
+
nw_steps
)]
elif
idx
<
op
.
n_mit_sot
+
op
.
n_sit_sot
+
op
.
n_nit_sot
:
elif
idx
<
op
.
n_mit_sot
+
op
.
n_sit_sot
+
op
.
n_nit_sot
:
in_idx
=
offset
+
idx
+
op
.
n_shared_outs
in_idx
=
offset
+
idx
+
op
.
n_shared_outs
...
@@ -1563,7 +1594,7 @@ class ScanSaveMem(gof.Optimizer):
...
@@ -1563,7 +1594,7 @@ class ScanSaveMem(gof.Optimizer):
# 3.5 Remove unwanted orphane outputs
# 3.5 Remove unwanted orphane outputs
(
inps
,
outs
,
info
,
node_ins
,
compress_map
)
=
\
(
inps
,
outs
,
info
,
node_ins
,
compress_map
)
=
\
scan_utils
.
compress_outs
(
op
,
not_required
,
nw_inputs
)
scan_utils
.
compress_outs
(
op
,
not_required
,
nw_inputs
)
inv_compress_map
=
OrderedDict
()
inv_compress_map
=
OrderedDict
()
for
k
,
v
in
iteritems
(
compress_map
):
for
k
,
v
in
iteritems
(
compress_map
):
inv_compress_map
[
v
]
=
k
inv_compress_map
[
v
]
=
k
...
@@ -1633,15 +1664,15 @@ class ScanSaveMem(gof.Optimizer):
...
@@ -1633,15 +1664,15 @@ class ScanSaveMem(gof.Optimizer):
start
=
(
cnf_slice
[
0
]
.
start
-
nw_steps
-
start
=
(
cnf_slice
[
0
]
.
start
-
nw_steps
-
init_l
[
pos
]
+
store_steps
[
pos
])
init_l
[
pos
]
+
store_steps
[
pos
])
if
(
cnf_slice
[
0
]
.
stop
is
not
None
and
if
(
cnf_slice
[
0
]
.
stop
is
not
None
and
cnf_slice
[
0
]
.
stop
!=
maxsize
):
cnf_slice
[
0
]
.
stop
!=
maxsize
):
stop
=
(
cnf_slice
[
0
]
.
stop
-
nw_steps
-
stop
=
(
cnf_slice
[
0
]
.
stop
-
nw_steps
-
init_l
[
pos
]
+
store_steps
[
pos
])
init_l
[
pos
]
+
store_steps
[
pos
])
else
:
else
:
stop
=
None
stop
=
None
nw_slice
=
((
slice
(
sanitize
(
start
),
nw_slice
=
((
slice
(
sanitize
(
start
),
sanitize
(
stop
),
sanitize
(
stop
),
sanitize
(
cnf_slice
[
0
]
.
step
)),)
sanitize
(
cnf_slice
[
0
]
.
step
)),)
+
+
tuple
(
old_slices
[
1
:]))
tuple
(
old_slices
[
1
:]))
else
:
else
:
position
=
(
cnf_slice
[
0
]
-
nw_steps
-
position
=
(
cnf_slice
[
0
]
-
nw_steps
-
...
@@ -1662,8 +1693,7 @@ class ScanSaveMem(gof.Optimizer):
...
@@ -1662,8 +1693,7 @@ class ScanSaveMem(gof.Optimizer):
# 3.9. Get replace pairs for all other nodes
# 3.9. Get replace pairs for all other nodes
if
flag_store
or
global_nsteps
is
not
None
:
if
flag_store
or
global_nsteps
is
not
None
:
for
idx
,
o
in
enumerate
(
node
.
outputs
):
for
idx
,
o
in
enumerate
(
node
.
outputs
):
if
not
(
idx
in
replaced_outs
)
and
\
if
not
(
idx
in
replaced_outs
)
and
idx
not
in
not_required
:
not
idx
in
not_required
:
nw_pos
=
compress_map
[
idx
]
nw_pos
=
compress_map
[
idx
]
old_new
+=
[(
o
,
new_outs
[
nw_pos
])]
old_new
+=
[(
o
,
new_outs
[
nw_pos
])]
# Check if the new outputs depend on the old scan node
# Check if the new outputs depend on the old scan node
...
@@ -1808,7 +1838,7 @@ class ScanMerge(gof.Optimizer):
...
@@ -1808,7 +1838,7 @@ class ScanMerge(gof.Optimizer):
flat_inner_outs
=
sum
(
inner_outs
[
idx
],
[])
flat_inner_outs
=
sum
(
inner_outs
[
idx
],
[])
# clone
# clone
flat_inner_ins
,
flat_inner_outs
=
scan_utils
.
reconstruct_graph
(
flat_inner_ins
,
flat_inner_outs
=
scan_utils
.
reconstruct_graph
(
flat_inner_ins
,
flat_inner_outs
)
flat_inner_ins
,
flat_inner_outs
)
# split the new inner variables again in seq, mitmot, etc.
# split the new inner variables again in seq, mitmot, etc.
new_inner_ins
=
[]
new_inner_ins
=
[]
count
=
0
count
=
0
...
@@ -2072,8 +2102,8 @@ def scan_merge_inouts(node):
...
@@ -2072,8 +2102,8 @@ def scan_merge_inouts(node):
# because they could have different sizes, and the corresponding
# because they could have different sizes, and the corresponding
# outer outputs cannot be merged in that case.
# outer outputs cannot be merged in that case.
for
s_outer_i
,
s_inner_o
,
s_outer_o
in
seen
:
for
s_outer_i
,
s_inner_o
,
s_outer_o
in
seen
:
if
(
equal_computations
([
inner_o
],
[
s_inner_o
],
left
,
right
)
if
(
equal_computations
([
inner_o
],
[
s_inner_o
],
left
,
right
)
and
and
outer_i
==
s_outer_i
):
outer_i
==
s_outer_i
):
return
s_outer_o
return
s_outer_o
seen
.
append
((
outer_i
,
inner_o
,
outer_o
))
seen
.
append
((
outer_i
,
inner_o
,
outer_o
))
return
outer_o
return
outer_o
...
@@ -2116,9 +2146,10 @@ def scan_merge_inouts(node):
...
@@ -2116,9 +2146,10 @@ def scan_merge_inouts(node):
na
.
outer_out_mit_mot
,
na
.
outer_out_mit_mot
,
na
.
mit_mot_out_slices
):
na
.
mit_mot_out_slices
):
for
s_outer_imm
,
s_inner_omm
,
s_outer_omm
,
sosl
in
seen
:
for
s_outer_imm
,
s_inner_omm
,
s_outer_omm
,
sosl
in
seen
:
if
(
osl
==
sosl
if
(
osl
==
sosl
and
and
equal_computations
(
inner_omm
,
s_inner_omm
,
left
,
right
)
equal_computations
(
inner_omm
,
s_inner_omm
,
left
,
right
)
and
and
outer_imm
==
s_outer_imm
):
outer_imm
==
s_outer_imm
):
new_outer_out_mit_mot
.
append
(
s_outer_omm
)
new_outer_out_mit_mot
.
append
(
s_outer_omm
)
break
break
else
:
else
:
...
@@ -2168,17 +2199,15 @@ class PushOutDot1(gof.Optimizer):
...
@@ -2168,17 +2199,15 @@ class PushOutDot1(gof.Optimizer):
inp
in
out
.
owner
.
inputs
and
inp
in
out
.
owner
.
inputs
and
len
(
outer_out
.
clients
)
==
1
and
len
(
outer_out
.
clients
)
==
1
and
not
isinstance
(
outer_out
.
clients
[
0
][
0
],
str
)
and
not
isinstance
(
outer_out
.
clients
[
0
][
0
],
str
)
and
isinstance
(
outer_out
.
clients
[
0
][
0
]
.
op
,
theano
.
tensor
.
Subtensor
)
isinstance
(
outer_out
.
clients
[
0
][
0
]
.
op
,
theano
.
tensor
.
Subtensor
)
and
and
outer_out
.
clients
[
0
][
0
]
.
op
.
idx_list
==
(
-
1
,)):
outer_out
.
clients
[
0
][
0
]
.
op
.
idx_list
==
(
-
1
,)):
x
=
out
.
owner
.
inputs
[
0
]
x
=
out
.
owner
.
inputs
[
0
]
if
x
==
inp
:
if
x
==
inp
:
x
=
out
.
owner
.
inputs
[
1
]
x
=
out
.
owner
.
inputs
[
1
]
# We need to check if x is the result of an outer product
# We need to check if x is the result of an outer product
if
(
x
.
owner
and
if
(
x
.
owner
and
isinstance
(
x
.
owner
.
op
,
theano
.
tensor
.
Dot
)
and
isinstance
(
x
.
owner
.
op
,
theano
.
tensor
.
Dot
)
and
x
.
owner
.
inputs
[
0
]
.
ndim
==
2
and
x
.
owner
.
inputs
[
1
]
.
ndim
==
2
):
x
.
owner
.
inputs
[
0
]
.
ndim
==
2
and
x
.
owner
.
inputs
[
1
]
.
ndim
==
2
):
# We need to check if any of the inputs are a sequence
# We need to check if any of the inputs are a sequence
inp1
=
x
.
owner
.
inputs
[
0
]
inp1
=
x
.
owner
.
inputs
[
0
]
...
@@ -2219,18 +2248,17 @@ class PushOutDot1(gof.Optimizer):
...
@@ -2219,18 +2248,17 @@ class PushOutDot1(gof.Optimizer):
new_info
=
op
.
info
.
copy
()
new_info
=
op
.
info
.
copy
()
st
=
len
(
op
.
mitmot_taps
())
+
len
(
op
.
mitsot_taps
())
st
=
len
(
op
.
mitmot_taps
())
+
len
(
op
.
mitsot_taps
())
new_info
[
'tap_array'
]
=
(
\
new_info
[
'tap_array'
]
=
(
new_info
[
'tap_array'
][:
st
+
idx
]
+
new_info
[
'tap_array'
][:
st
+
idx
]
+
new_info
[
'tap_array'
][
st
+
new_info
[
'tap_array'
][
st
+
idx
+
1
:])
idx
+
1
:])
new_info
[
'n_sit_sot'
]
-=
1
new_info
[
'n_sit_sot'
]
-=
1
new_info
[
'n_nit_sot'
]
+=
1
new_info
[
'n_nit_sot'
]
+=
1
inner_sitsot
=
inner_sitsot
[:
idx
]
+
\
inner_sitsot
=
(
inner_sitsot
[:
idx
]
+
inner_sitsot
[
idx
+
1
:]
inner_sitsot
[
idx
+
1
:])
outer_sitsot
=
outer_sitsot
[:
idx
]
+
\
outer_sitsot
=
(
outer_sitsot
[:
idx
]
+
outer_sitsot
[
idx
+
1
:]
outer_sitsot
[
idx
+
1
:])
inner_sitsot_outs
=
inner_sitsot_outs
[:
idx
]
+
\
inner_sitsot_outs
=
(
inner_sitsot_outs
[:
idx
]
+
inner_sitsot_outs
[
idx
+
1
:]
inner_sitsot_outs
[
idx
+
1
:])
# add n_steps as the length
# add n_steps as the length
inner_nitsot_outs
.
append
(
new_scan_out
)
inner_nitsot_outs
.
append
(
new_scan_out
)
...
@@ -2246,8 +2274,8 @@ class PushOutDot1(gof.Optimizer):
...
@@ -2246,8 +2274,8 @@ class PushOutDot1(gof.Optimizer):
inner_nitsot_outs
+
inner_nitsot_outs
+
inner_shared_outs
)
inner_shared_outs
)
new_inner_inps
,
new_inner_outs
=
\
new_inner_inps
,
new_inner_outs
=
\
scan_utils
.
reconstruct_graph
(
scan_utils
.
reconstruct_graph
(
_new_inner_inps
,
_new_inner_inps
,
_new_inner_outs
)
_new_inner_outs
)
new_op
=
scan_op
.
Scan
(
new_inner_inps
,
new_inner_outs
,
new_op
=
scan_op
.
Scan
(
new_inner_inps
,
new_inner_outs
,
new_info
)
new_info
)
_scan_inputs
=
([
node
.
inputs
[
0
]]
+
_scan_inputs
=
([
node
.
inputs
[
0
]]
+
...
@@ -2267,11 +2295,7 @@ class PushOutDot1(gof.Optimizer):
...
@@ -2267,11 +2295,7 @@ class PushOutDot1(gof.Optimizer):
# We need now to pair correctly the new outputs
# We need now to pair correctly the new outputs
# with the old ones
# with the old ones
outer_mitmot_outs
=
new_op
.
outer_mitmot_outs
(
new_outs
)
outer_mitsot_outs
=
new_op
.
outer_mitsot_outs
(
new_outs
)
outer_sitsot_outs
=
new_op
.
outer_sitsot_outs
(
new_outs
)
outer_nitsot_outs
=
new_op
.
outer_nitsot_outs
(
new_outs
)
outer_nitsot_outs
=
new_op
.
outer_nitsot_outs
(
new_outs
)
outer_shared_outs
=
new_op
.
outer_shared_outs
(
new_outs
)
_val
=
outer_nitsot_outs
[
-
1
]
_val
=
outer_nitsot_outs
[
-
1
]
outer_nitsot_outs
=
outer_nitsot_outs
[:
-
1
]
outer_nitsot_outs
=
outer_nitsot_outs
[:
-
1
]
...
@@ -2305,7 +2329,7 @@ class PushOutDot1(gof.Optimizer):
...
@@ -2305,7 +2329,7 @@ class PushOutDot1(gof.Optimizer):
old_new
=
list
(
zip
(
node
.
outputs
[:
pos
],
new_outs
[:
pos
]))
old_new
=
list
(
zip
(
node
.
outputs
[:
pos
],
new_outs
[:
pos
]))
old
=
node
.
outputs
[
pos
]
.
clients
[
0
][
0
]
.
outputs
[
0
]
old
=
node
.
outputs
[
pos
]
.
clients
[
0
][
0
]
.
outputs
[
0
]
old_new
.
append
((
old
,
new_out
))
old_new
.
append
((
old
,
new_out
))
old_new
+=
list
(
zip
(
node
.
outputs
[
pos
+
1
:],
old_new
+=
list
(
zip
(
node
.
outputs
[
pos
+
1
:],
new_outs
[
pos
:]))
new_outs
[
pos
:]))
fgraph
.
replace_all_validate_remove
(
fgraph
.
replace_all_validate_remove
(
old_new
,
remove
=
[
node
],
reason
=
'scan_pushout_dot1'
)
old_new
,
remove
=
[
node
],
reason
=
'scan_pushout_dot1'
)
...
@@ -2374,61 +2398,61 @@ scan_seqopt1.register('scanOp_pushout_output',
...
@@ -2374,61 +2398,61 @@ scan_seqopt1.register('scanOp_pushout_output',
scan_eqopt2
.
register
(
'constant_folding_for_scan2'
,
scan_eqopt2
.
register
(
'constant_folding_for_scan2'
,
opt
.
in2out
(
tensor
.
opt
.
constant_folding
,
opt
.
in2out
(
tensor
.
opt
.
constant_folding
,
ignore_newtrees
=
True
),
ignore_newtrees
=
True
),
1
,
1
,
'fast_run'
,
'fast_run'
,
'scan'
)
'scan'
)
scan_eqopt2
.
register
(
'scanOp_remove_constants_and_unused_inputs1'
,
scan_eqopt2
.
register
(
'scanOp_remove_constants_and_unused_inputs1'
,
opt
.
in2out
(
remove_constants_and_unused_inputs_scan
,
opt
.
in2out
(
remove_constants_and_unused_inputs_scan
,
ignore_newtrees
=
True
),
ignore_newtrees
=
True
),
2
,
2
,
'remove_constants_and_unused_inputs_scan'
,
'remove_constants_and_unused_inputs_scan'
,
'fast_run'
,
'fast_run'
,
'scan'
)
'scan'
)
# after const merge but before stabilize so that we can have identity
# after const merge but before stabilize so that we can have identity
# for equivalent nodes but we still have the chance to hoist stuff out
# for equivalent nodes but we still have the chance to hoist stuff out
# of the scan later.
# of the scan later.
scan_eqopt2
.
register
(
'scanOp_merge'
,
scan_eqopt2
.
register
(
'scanOp_merge'
,
ScanMerge
(),
ScanMerge
(),
4
,
4
,
'fast_run'
,
'fast_run'
,
'scan'
)
'scan'
)
# After Merge optimization
# After Merge optimization
scan_eqopt2
.
register
(
'scanop_remove_constants_and_unused_inputs2'
,
scan_eqopt2
.
register
(
'scanop_remove_constants_and_unused_inputs2'
,
opt
.
in2out
(
remove_constants_and_unused_inputs_scan
,
opt
.
in2out
(
remove_constants_and_unused_inputs_scan
,
ignore_newtrees
=
True
),
ignore_newtrees
=
True
),
5
,
5
,
'remove_constants_and_unused_inputs_scan'
,
'remove_constants_and_unused_inputs_scan'
,
'fast_run'
,
'fast_run'
,
'scan'
)
'scan'
)
scan_eqopt2
.
register
(
'scanOp_merge_inouts'
,
scan_eqopt2
.
register
(
'scanOp_merge_inouts'
,
opt
.
in2out
(
scan_merge_inouts
,
ignore_newtrees
=
True
),
opt
.
in2out
(
scan_merge_inouts
,
ignore_newtrees
=
True
),
6
,
6
,
'scan_merge_inouts'
,
'scan_merge_inouts'
,
'fast_run'
,
'fast_run'
,
'scan'
)
'scan'
)
# Just before specialize to have the other optimization
# Just before specialize to have the other optimization
# like constant folding being applied
# like constant folding being applied
# This don't introduce inplace.
# This don't introduce inplace.
scan_eqopt2
.
register
(
'scanOp_save_mem'
,
scan_eqopt2
.
register
(
'scanOp_save_mem'
,
ScanSaveMem
(),
ScanSaveMem
(),
7
,
7
,
'fast_run'
,
'fast_run'
,
'scan'
)
'scan'
)
# After everything else
# After everything else
scan_eqopt2
.
register
(
'scanOp_remove_constants_and_unused_inputs3'
,
scan_eqopt2
.
register
(
'scanOp_remove_constants_and_unused_inputs3'
,
opt
.
in2out
(
remove_constants_and_unused_inputs_scan
,
opt
.
in2out
(
remove_constants_and_unused_inputs_scan
,
ignore_newtrees
=
True
),
ignore_newtrees
=
True
),
8
,
8
,
'remove_constants_and_unused_inputs_scan'
,
'remove_constants_and_unused_inputs_scan'
,
'fast_run'
,
'fast_run'
,
'scan'
)
'scan'
)
theano/tests/test_flake8.py
浏览文件 @
c8dc3dbe
...
@@ -164,7 +164,6 @@ whitelist_flake8 = [
...
@@ -164,7 +164,6 @@ whitelist_flake8 = [
"scan_module/scan_op.py"
,
"scan_module/scan_op.py"
,
"scan_module/scan_perform_ext.py"
,
"scan_module/scan_perform_ext.py"
,
"scan_module/__init__.py"
,
"scan_module/__init__.py"
,
"scan_module/scan_opt.py"
,
"scan_module/tests/test_scan.py"
,
"scan_module/tests/test_scan.py"
,
"scan_module/tests/test_scan_opt.py"
,
"scan_module/tests/test_scan_opt.py"
,
"misc/tests/test_may_share_memory.py"
,
"misc/tests/test_may_share_memory.py"
,
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论