Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
d4dfbf2a
提交
d4dfbf2a
authored
4月 19, 2012
作者:
nouiz
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #612 from lamblin/merge_feature_rebased
Merge feature (rebased)
上级
a96d5716
c42d1808
隐藏空白字符变更
内嵌
并排
正在显示
3 个修改的文件
包含
531 行增加
和
289 行删除
+531
-289
mode.py
theano/compile/mode.py
+132
-79
opt.py
theano/gof/opt.py
+399
-209
opt.py
theano/tensor/opt.py
+0
-1
没有找到文件。
theano/compile/mode.py
浏览文件 @
d4dfbf2a
"""WRITEME
"""WRITEME
"""
"""
import
os
,
logging
,
warnings
import
logging
import
numpy
,
theano
import
numpy
import
theano
from
theano
import
gof
from
theano
import
gof
import
theano.gof.vm
import
theano.gof.vm
from
theano.configparser
import
config
,
AddConfigVar
,
StrParam
,
EnumStr
from
theano.configparser
import
config
,
AddConfigVar
,
StrParam
_logger
=
logging
.
getLogger
(
'theano.compile.mode'
)
_logger
=
logging
.
getLogger
(
'theano.compile.mode'
)
AddConfigVar
(
'optimizer_excluding'
,
AddConfigVar
(
'optimizer_excluding'
,
"When using the default mode, we will remove optimizer with that tag. Separate many tags with ':'."
,
(
"When using the default mode, we will remove optimizer with these "
"tags. Separate tags with ':'."
),
StrParam
(
""
,
allow_override
=
False
),
StrParam
(
""
,
allow_override
=
False
),
in_c_key
=
False
)
in_c_key
=
False
)
AddConfigVar
(
'optimizer_including'
,
AddConfigVar
(
'optimizer_including'
,
"When using the default mode, we will add optimizer with that tag. Separate many tags with ':'."
,
(
"When using the default mode, we will add optimizer with these tags. "
"Separate tags with ':'."
),
StrParam
(
""
,
allow_override
=
False
),
StrParam
(
""
,
allow_override
=
False
),
in_c_key
=
False
)
in_c_key
=
False
)
AddConfigVar
(
'optimizer_requiring'
,
AddConfigVar
(
'optimizer_requiring'
,
"When using the default mode, we will require optimizer with that tag. Separate many tags with ':'."
,
(
"When using the default mode, we will require optimizer with these "
"tags. Separate tags with ':'."
),
StrParam
(
""
,
allow_override
=
False
),
StrParam
(
""
,
allow_override
=
False
),
in_c_key
=
False
)
in_c_key
=
False
)
def
check_equal
(
x
,
y
):
def
check_equal
(
x
,
y
):
"""
"""
Returns True iff x[0] and y[0] are equal (checks the dtype and
Returns True iff x[0] and y[0] are equal (checks the dtype and
...
@@ -32,35 +38,37 @@ def check_equal(x, y):
...
@@ -32,35 +38,37 @@ def check_equal(x, y):
import
scipy.sparse
as
sp
import
scipy.sparse
as
sp
x
,
y
=
x
[
0
],
y
[
0
]
x
,
y
=
x
[
0
],
y
[
0
]
# TODO: bug in current scipy, two sparse matrices are never equal, remove when moving to 0.7
# TODO: bug in current scipy, two sparse matrices are never equal,
# remove when moving to 0.7
if
sp
.
issparse
(
x
):
if
sp
.
issparse
(
x
):
x
=
x
.
todense
()
x
=
x
.
todense
()
if
sp
.
issparse
(
y
):
if
sp
.
issparse
(
y
):
y
=
y
.
todense
()
y
=
y
.
todense
()
if
isinstance
(
x
,
numpy
.
ndarray
)
and
isinstance
(
y
,
numpy
.
ndarray
):
if
isinstance
(
x
,
numpy
.
ndarray
)
and
isinstance
(
y
,
numpy
.
ndarray
):
if
x
.
dtype
!=
y
.
dtype
or
x
.
shape
!=
y
.
shape
or
numpy
.
any
(
abs
(
x
-
y
)
>
1e-10
):
if
(
x
.
dtype
!=
y
.
dtype
raise
Exception
(
"Output mismatch."
,
{
'performlinker'
:
x
,
'clinker'
:
y
})
or
x
.
shape
!=
y
.
shape
or
numpy
.
any
(
abs
(
x
-
y
)
>
1e-10
)):
raise
Exception
(
"Output mismatch."
,
{
'performlinker'
:
x
,
'clinker'
:
y
})
else
:
else
:
if
x
!=
y
:
if
x
!=
y
:
raise
Exception
(
"Output mismatch."
,
{
'performlinker'
:
x
,
'clinker'
:
y
})
raise
Exception
(
"Output mismatch."
,
{
'performlinker'
:
x
,
'clinker'
:
y
})
# If a string is passed as the linker argument in the constructor for
# If a string is passed as the linker argument in the constructor for
# Mode, it will be used as the key to retrieve the real linker in this
# Mode, it will be used as the key to retrieve the real linker in this
# dictionary
# dictionary
predefined_linkers
=
{
predefined_linkers
=
{
'py'
:
gof
.
PerformLinker
(),
'py'
:
gof
.
PerformLinker
(),
'c'
:
gof
.
CLinker
(),
'c'
:
gof
.
CLinker
(),
'c|py'
:
gof
.
OpWiseCLinker
(
allow_gc
=
True
),
'c|py'
:
gof
.
OpWiseCLinker
(
allow_gc
=
True
),
'c|py_nogc'
:
gof
.
OpWiseCLinker
(
allow_gc
=
False
),
'c|py_nogc'
:
gof
.
OpWiseCLinker
(
allow_gc
=
False
),
'c&py'
:
gof
.
DualLinker
(
checker
=
check_equal
),
'c&py'
:
gof
.
DualLinker
(
checker
=
check_equal
),
'vm'
:
gof
.
vm
.
VM_Linker
(
allow_gc
=
True
,
use_cloop
=
False
),
'vm'
:
gof
.
vm
.
VM_Linker
(
allow_gc
=
True
,
use_cloop
=
False
),
'cvm'
:
gof
.
vm
.
VM_Linker
(
allow_gc
=
True
,
use_cloop
=
True
),
'cvm'
:
gof
.
vm
.
VM_Linker
(
allow_gc
=
True
,
use_cloop
=
True
),
'vm_nogc'
:
gof
.
vm
.
VM_Linker
(
allow_gc
=
False
,
use_cloop
=
False
),
'vm_nogc'
:
gof
.
vm
.
VM_Linker
(
allow_gc
=
False
,
use_cloop
=
False
),
'cvm_nogc'
:
gof
.
vm
.
VM_Linker
(
allow_gc
=
False
,
use_cloop
=
True
),
'cvm_nogc'
:
gof
.
vm
.
VM_Linker
(
allow_gc
=
False
,
use_cloop
=
True
),
}
}
...
@@ -72,37 +80,37 @@ def register_linker(name, linker):
...
@@ -72,37 +80,37 @@ def register_linker(name, linker):
predefined_linkers
[
name
]
=
linker
predefined_linkers
[
name
]
=
linker
# If a string is passed as the optimizer argument in the constructor
# If a string is passed as the optimizer argument in the constructor
# for Mode, it will be used as the key to retrieve the real optimizer
# for Mode, it will be used as the key to retrieve the real optimizer
# in this dictionary
# in this dictionary
OPT_FAST_RUN
=
gof
.
Query
(
include
=
[
'fast_run'
])
OPT_FAST_RUN
=
gof
.
Query
(
include
=
[
'fast_run'
])
OPT_FAST_RUN_STABLE
=
OPT_FAST_RUN
.
requiring
(
'stable'
)
OPT_FAST_RUN_STABLE
=
OPT_FAST_RUN
.
requiring
(
'stable'
)
OPT_FAST_COMPILE
=
gof
.
Query
(
include
=
[
'fast_compile'
])
OPT_FAST_COMPILE
=
gof
.
Query
(
include
=
[
'fast_compile'
])
OPT_STABILIZE
=
gof
.
Query
(
include
=
[
'fast_run'
])
OPT_STABILIZE
=
gof
.
Query
(
include
=
[
'fast_run'
])
OPT_STABILIZE
.
position_cutoff
=
1.5000001
OPT_STABILIZE
.
position_cutoff
=
1.5000001
predefined_optimizers
=
{
predefined_optimizers
=
{
None
:
lambda
env
:
None
,
None
:
(
lambda
env
:
None
)
,
'None'
:
lambda
env
:
None
,
'None'
:
(
lambda
env
:
None
)
,
'merge'
:
gof
.
MergeOptimizer
(),
'merge'
:
gof
.
MergeOptimizer
(),
'fast_run'
:
OPT_FAST_RUN
,
'fast_run'
:
OPT_FAST_RUN
,
'fast_run_stable'
:
OPT_FAST_RUN_STABLE
,
'fast_run_stable'
:
OPT_FAST_RUN_STABLE
,
'fast_compile'
:
OPT_FAST_COMPILE
,
'fast_compile'
:
OPT_FAST_COMPILE
,
'stabilize'
:
OPT_STABILIZE
'stabilize'
:
OPT_STABILIZE
}
}
def
register_optimizer
(
name
,
opt
):
def
register_optimizer
(
name
,
opt
):
"""Add a `Optimizer` which can be referred to by `name` in `Mode`."""
"""Add a `Optimizer` which can be referred to by `name` in `Mode`."""
if
name
in
predefined_optimizers
:
if
name
in
predefined_optimizers
:
raise
ValueError
(
'Optimizer name already taken:
%
s'
%
name
)
raise
ValueError
(
'Optimizer name already taken:
%
s'
%
name
)
predefined_optimizers
[
name
]
=
opt
predefined_optimizers
[
name
]
=
opt
def
register_OutputGuard_c_code
(
type
):
def
register_OutputGuard_c_code
(
type
):
OutputGuard
.
c_code_types
.
append
(
type
)
OutputGuard
.
c_code_types
.
append
(
type
)
class
OutputGuard
(
gof
.
Op
):
class
OutputGuard
(
gof
.
Op
):
"""
"""
This op is used only internally by Theano.
This op is used only internally by Theano.
...
@@ -120,20 +128,24 @@ class OutputGuard(gof.Op):
...
@@ -120,20 +128,24 @@ class OutputGuard(gof.Op):
TODO: find a current full explanation.
TODO: find a current full explanation.
"""
"""
destroy_map
=
{
0
:[
0
]}
destroy_map
=
{
0
:
[
0
]}
view_map
=
{
0
:[
0
]}
view_map
=
{
0
:
[
0
]}
c_code_types
=
[]
c_code_types
=
[]
def
make_node
(
self
,
x
):
def
make_node
(
self
,
x
):
return
gof
.
Apply
(
self
,
[
x
],
[
x
.
type
()])
return
gof
.
Apply
(
self
,
[
x
],
[
x
.
type
()])
def
__eq__
(
self
,
other
):
def
__eq__
(
self
,
other
):
return
type
(
self
)
==
type
(
other
)
return
type
(
self
)
==
type
(
other
)
def
__hash__
(
self
):
def
__hash__
(
self
):
return
hash
(
type
(
self
))
return
hash
(
type
(
self
))
def
perform
(
self
,
node
,
inp
,
out
):
def
perform
(
self
,
node
,
inp
,
out
):
x
,
=
inp
x
,
=
inp
z
,
=
out
z
,
=
out
z
[
0
]
=
x
z
[
0
]
=
x
def
__str__
(
self
):
def
__str__
(
self
):
return
'
%
s'
%
self
.
__class__
.
__name__
return
'
%
s'
%
self
.
__class__
.
__name__
...
@@ -141,7 +153,8 @@ class OutputGuard(gof.Op):
...
@@ -141,7 +153,8 @@ class OutputGuard(gof.Op):
x
,
=
inp
x
,
=
inp
z
,
=
out
z
,
=
out
if
isinstance
(
node
.
inputs
[
0
]
.
type
,
theano
.
scalar
.
Scalar
):
if
isinstance
(
node
.
inputs
[
0
]
.
type
,
theano
.
scalar
.
Scalar
):
# Scalars are C objects on the stacks, and should not be inc/decrefed
# Scalars are C objects on the stack,
# and should not be inc/decrefed
return
"""
return
"""
%(z)
s =
%(x)
s;
%(z)
s =
%(x)
s;
"""
%
locals
()
"""
%
locals
()
...
@@ -161,71 +174,99 @@ class OutputGuard(gof.Op):
...
@@ -161,71 +174,99 @@ class OutputGuard(gof.Op):
_output_guard
=
OutputGuard
()
_output_guard
=
OutputGuard
()
class
AddDestroyHandler
(
gof
.
Optimizer
):
class
AddDestroyHandler
(
gof
.
Optimizer
):
"""This optimizer performs two important functions:
"""This optimizer performs two important functions:
1) it has a 'requirement' of the destroyhandler. This means that the env will include it
1) it has a 'requirement' of the destroyhandler. This means that the env
as a feature for this optimization, and keep this feature enabled for subsequent
will include it as a feature for this optimization, and keep this feature
optimizations. All optimizations that work inplace on any of their inputs must run *after*
enabled for subsequent optimizations. All optimizations that work inplace
this optimization to ensure that the DestroyHandler has been included in the env.
on any of their inputs must run *after* this optimization to ensure that
the DestroyHandler has been included in the env.
2) It tries to replace each output with an Op that purports to destroy it
(but it won't I
2) It tries to replace each output with an Op that purports to destroy it
promise). If this replacement succeeds it means that there is a bug in theano. It should
(but it won't I promise). If this replacement succeeds it means that
not be possible to destroy outputs.
there is a bug in theano. It should
not be possible to destroy outputs.
"""
"""
def
apply
(
self
,
env
):
def
apply
(
self
,
env
):
for
o
in
env
.
outputs
:
for
o
in
env
.
outputs
:
try
:
try
:
env
.
replace_validate
(
o
,
_output_guard
(
o
),
reason
=
'output_guard'
)
env
.
replace_validate
(
o
,
_output_guard
(
o
),
_logger
.
info
(
"Output variable
%
s required output_guard,"
reason
=
'output_guard'
)
" how was this output left unprotected against destructive operations?"
_logger
.
info
(
"Output variable
%
s required output_guard, "
"how was this output left unprotected against "
"destructive operations?"
%
o
)
%
o
)
except
gof
.
InconsistencyError
:
except
gof
.
InconsistencyError
:
#this output is already impossible to destroy. no guard necessary
# This output is already impossible to destroy.
# No guard necessary
pass
pass
def
add_requirements
(
self
,
env
):
def
add_requirements
(
self
,
env
):
super
(
AddDestroyHandler
,
self
)
.
add_requirements
(
env
)
super
(
AddDestroyHandler
,
self
)
.
add_requirements
(
env
)
env
.
extend
(
gof
.
DestroyHandler
())
env
.
extend
(
gof
.
DestroyHandler
())
class
PrintCurrentEnv
(
gof
.
Optimizer
):
class
PrintCurrentEnv
(
gof
.
Optimizer
):
"""This optimizer is for debugging.
"""This optimizer is for debugging.
Toss it into the optimization pipeline to see the state of things at any given point.
Toss it into the optimization pipeline to see the state of things at any
given point.
"""
"""
def
__init__
(
self
,
header
):
def
__init__
(
self
,
header
):
self
.
header
=
header
self
.
header
=
header
def
apply
(
self
,
env
):
def
apply
(
self
,
env
):
import
theano.printing
import
theano.printing
print
"PrintCurrentEnv:"
,
self
.
header
print
"PrintCurrentEnv:"
,
self
.
header
theano
.
printing
.
debugprint
(
env
.
outputs
)
theano
.
printing
.
debugprint
(
env
.
outputs
)
optdb
=
gof
.
SequenceDB
()
optdb
=
gof
.
SequenceDB
()
optdb
.
register
(
'merge1'
,
gof
.
MergeOptimizer
(),
optdb
.
register
(
'merge1'
,
gof
.
MergeOptimizer
(),
0
,
'fast_run'
,
'fast_compile'
)
0
,
'fast_run'
,
'fast_compile'
)
optdb
.
register
(
'canonicalize'
,
gof
.
EquilibriumDB
(),
# rearranges elemwise expressions
# rearranges elemwise expressions
optdb
.
register
(
'canonicalize'
,
gof
.
EquilibriumDB
(),
1
,
'fast_run'
,
'fast_compile'
)
1
,
'fast_run'
,
'fast_compile'
)
optdb
.
register
(
'merge1.2'
,
gof
.
MergeOptimizer
(
skip_const_merge
=
False
),
optdb
.
register
(
'merge1.2'
,
gof
.
MergeOptimizer
(),
1.2
,
'fast_run'
,
'fast_compile'
)
1.2
,
'fast_run'
,
'fast_compile'
)
optdb
.
register
(
'Print1.21'
,
PrintCurrentEnv
(
'Post-canonicalize'
),
optdb
.
register
(
'Print1.21'
,
PrintCurrentEnv
(
'Post-canonicalize'
),
1.21
,)
# 'fast_run', 'fast_compile')
1.21
,)
# 'fast_run', 'fast_compile')
optdb
.
register
(
'stabilize'
,
gof
.
EquilibriumDB
(),
# replace unstable subgraphs
# replace unstable subgraphs
optdb
.
register
(
'stabilize'
,
gof
.
EquilibriumDB
(),
1.5
,
'fast_run'
)
1.5
,
'fast_run'
)
optdb
.
register
(
'Print1.51'
,
PrintCurrentEnv
(
'Post-stabilize'
),
optdb
.
register
(
'Print1.51'
,
PrintCurrentEnv
(
'Post-stabilize'
),
1.51
,)
#'fast_run', 'fast_compile')
1.51
,)
# 'fast_run', 'fast_compile')
optdb
.
register
(
'specialize'
,
gof
.
EquilibriumDB
(),
# misc special cases for speed
# misc special cases for speed
optdb
.
register
(
'specialize'
,
gof
.
EquilibriumDB
(),
2
,
'fast_run'
)
2
,
'fast_run'
)
optdb
.
register
(
'Print2.01'
,
PrintCurrentEnv
(
'Post-specialize'
),
optdb
.
register
(
'Print2.01'
,
PrintCurrentEnv
(
'Post-specialize'
),
2.01
,
)
#'fast_run', 'fast_compile')
2.01
,)
# 'fast_run', 'fast_compile')
optdb
.
register
(
'uncanonicalize'
,
gof
.
EquilibriumDB
(),
# misc special cases for speed that break canonicalization
# misc special cases for speed that break canonicalization
optdb
.
register
(
'uncanonicalize'
,
gof
.
EquilibriumDB
(),
3
,
'fast_run'
)
3
,
'fast_run'
)
optdb
.
register
(
'specialize_device'
,
gof
.
EquilibriumDB
(),
# misc special cases for speed that are dependent on the device.
48.6
,
'fast_run'
)
#must be after gpu stuff at 48.5
# misc special cases for speed that are dependent on the device.
optdb
.
register
(
'merge2'
,
gof
.
MergeOptimizer
(),
# especially constant merge
optdb
.
register
(
'specialize_device'
,
gof
.
EquilibriumDB
(),
48.6
,
'fast_run'
)
# must be after gpu stuff at 48.5
# especially constant merge
optdb
.
register
(
'merge2'
,
gof
.
MergeOptimizer
(),
49
,
'fast_run'
)
49
,
'fast_run'
)
optdb
.
register
(
'add_destroy_handler'
,
AddDestroyHandler
(),
optdb
.
register
(
'add_destroy_handler'
,
AddDestroyHandler
(),
49.5
,
'fast_run'
,
'inplace'
)
49.5
,
'fast_run'
,
'inplace'
)
optdb
.
register
(
'merge3'
,
gof
.
MergeOptimizer
(),
# final pass just to make sure
# final pass just to make sure
optdb
.
register
(
'merge3'
,
gof
.
MergeOptimizer
(),
100
,
'fast_run'
)
100
,
'fast_run'
)
...
@@ -251,12 +292,15 @@ class Mode(object):
...
@@ -251,12 +292,15 @@ class Mode(object):
if
optimizer
is
None
:
if
optimizer
is
None
:
optimizer
=
config
.
optimizer
optimizer
=
config
.
optimizer
self
.
__setstate__
((
linker
,
optimizer
))
self
.
__setstate__
((
linker
,
optimizer
))
#self.provided_optimizer - typically the `optimizer` arg. But if the `optimizer` arg is
# keyword corresponding to a predefined Query, then this stores the query
#self._optimizer - typically same as provided_optimizer??
#self.__get_optimizer - returns self._optimizer (possibly querying optdb with self._optimizer)
# self.provided_optimizer - typically the `optimizer` arg.
#self.optimizer - property that returns __get_optimizer()
# But if the `optimizer` arg is keyword corresponding to a predefined
# Query, then this stores the query
# self._optimizer - typically same as provided_optimizer??
# self.__get_optimizer - returns self._optimizer (possibly querying
# optdb with self._optimizer)
# self.optimizer - property that returns __get_optimizer()
def
__getstate__
(
self
):
def
__getstate__
(
self
):
return
(
self
.
provided_linker
,
self
.
provided_optimizer
)
return
(
self
.
provided_linker
,
self
.
provided_optimizer
)
...
@@ -275,12 +319,13 @@ class Mode(object):
...
@@ -275,12 +319,13 @@ class Mode(object):
self
.
_optimizer
=
optimizer
self
.
_optimizer
=
optimizer
self
.
call_time
=
0
self
.
call_time
=
0
self
.
fn_time
=
0
self
.
fn_time
=
0
linker
.
mode
=
self
#
TODO: WHY IS THIS HERE?
linker
.
mode
=
self
#
TODO: WHY IS THIS HERE?
self
.
optimizer_time
=
0
self
.
optimizer_time
=
0
self
.
linker_time
=
0
self
.
linker_time
=
0
def
__str__
(
self
):
def
__str__
(
self
):
return
"Mode(linker =
%
s, optimizer =
%
s)"
%
(
self
.
provided_linker
,
self
.
provided_optimizer
)
return
"Mode(linker =
%
s, optimizer =
%
s)"
%
(
self
.
provided_linker
,
self
.
provided_optimizer
)
def
__get_optimizer
(
self
):
def
__get_optimizer
(
self
):
if
isinstance
(
self
.
_optimizer
,
gof
.
Query
):
if
isinstance
(
self
.
_optimizer
,
gof
.
Query
):
...
@@ -298,17 +343,20 @@ class Mode(object):
...
@@ -298,17 +343,20 @@ class Mode(object):
return
(
linker
,
optimizer
)
return
(
linker
,
optimizer
)
def
including
(
self
,
*
tags
):
def
including
(
self
,
*
tags
):
link
,
opt
=
self
.
get_linker_optimizer
(
self
.
provided_linker
,
self
.
provided_optimizer
)
link
,
opt
=
self
.
get_linker_optimizer
(
self
.
provided_linker
,
self
.
provided_optimizer
)
#N.B. opt might be a Query instance, not sure what else it might be...
#N.B. opt might be a Query instance, not sure what else it might be...
# string? Optimizer? OptDB? who knows???
# string? Optimizer? OptDB? who knows???
return
self
.
__class__
(
linker
=
link
,
optimizer
=
opt
.
including
(
*
tags
))
return
self
.
__class__
(
linker
=
link
,
optimizer
=
opt
.
including
(
*
tags
))
def
excluding
(
self
,
*
tags
):
def
excluding
(
self
,
*
tags
):
link
,
opt
=
self
.
get_linker_optimizer
(
self
.
provided_linker
,
self
.
provided_optimizer
)
link
,
opt
=
self
.
get_linker_optimizer
(
self
.
provided_linker
,
self
.
provided_optimizer
)
return
self
.
__class__
(
linker
=
link
,
optimizer
=
opt
.
excluding
(
*
tags
))
return
self
.
__class__
(
linker
=
link
,
optimizer
=
opt
.
excluding
(
*
tags
))
def
requiring
(
self
,
*
tags
):
def
requiring
(
self
,
*
tags
):
link
,
opt
=
self
.
get_linker_optimizer
(
self
.
provided_linker
,
self
.
provided_optimizer
)
link
,
opt
=
self
.
get_linker_optimizer
(
self
.
provided_linker
,
self
.
provided_optimizer
)
return
self
.
__class__
(
linker
=
link
,
optimizer
=
opt
.
requiring
(
*
tags
))
return
self
.
__class__
(
linker
=
link
,
optimizer
=
opt
.
requiring
(
*
tags
))
# If a string is passed as the mode argument in function or
# If a string is passed as the mode argument in function or
...
@@ -321,20 +369,22 @@ predefined_modes = {'FAST_COMPILE': FAST_COMPILE,
...
@@ -321,20 +369,22 @@ predefined_modes = {'FAST_COMPILE': FAST_COMPILE,
'FAST_RUN'
:
FAST_RUN
,
'FAST_RUN'
:
FAST_RUN
,
}
}
instanciated_default_mode
=
None
instanciated_default_mode
=
None
def
get_mode
(
orig_string
):
def
get_mode
(
orig_string
):
if
orig_string
is
None
:
if
orig_string
is
None
:
string
=
config
.
mode
string
=
config
.
mode
else
:
else
:
string
=
orig_string
string
=
orig_string
if
not
isinstance
(
string
,
basestring
):
if
not
isinstance
(
string
,
basestring
):
return
string
#
it is hopefully already a mode...
return
string
#
it is hopefully already a mode...
global
instanciated_default_mode
global
instanciated_default_mode
# The default mode is cached. However, config.mode can change
# The default mode is cached. However, config.mode can change
# If instanciated_default_mode has the right class, use it.
# If instanciated_default_mode has the right class, use it.
if
orig_string
is
None
and
instanciated_default_mode
:
if
orig_string
is
None
and
instanciated_default_mode
:
if
predefined_modes
.
has_key
(
string
)
:
if
string
in
predefined_modes
:
default_mode_class
=
predefined_modes
[
string
]
.
__class__
.
__name__
default_mode_class
=
predefined_modes
[
string
]
.
__class__
.
__name__
else
:
else
:
default_mode_class
=
string
default_mode_class
=
string
...
@@ -342,7 +392,7 @@ def get_mode(orig_string):
...
@@ -342,7 +392,7 @@ def get_mode(orig_string):
default_mode_class
):
default_mode_class
):
return
instanciated_default_mode
return
instanciated_default_mode
if
string
in
[
'Mode'
,
'ProfileMode'
,
'DebugMode'
]:
if
string
in
[
'Mode'
,
'ProfileMode'
,
'DebugMode'
]:
if
string
==
'DebugMode'
:
if
string
==
'DebugMode'
:
#need to import later to break circular dependency.
#need to import later to break circular dependency.
from
debugmode
import
DebugMode
from
debugmode
import
DebugMode
...
@@ -350,12 +400,13 @@ def get_mode(orig_string):
...
@@ -350,12 +400,13 @@ def get_mode(orig_string):
ret
=
DebugMode
(
optimizer
=
config
.
optimizer
)
ret
=
DebugMode
(
optimizer
=
config
.
optimizer
)
else
:
else
:
# The import is needed in case string is ProfileMode
# The import is needed in case string is ProfileMode
from
profilemode
import
ProfileMode
,
prof_mode_instance_to_print
from
profilemode
import
ProfileMode
,
prof_mode_instance_to_print
ret
=
eval
(
string
+
'(linker=config.linker, optimizer=config.optimizer)'
)
ret
=
eval
(
string
elif
predefined_modes
.
has_key
(
string
):
+
'(linker=config.linker, optimizer=config.optimizer)'
)
elif
string
in
predefined_modes
:
ret
=
predefined_modes
[
string
]
ret
=
predefined_modes
[
string
]
else
:
else
:
raise
Exception
(
"No predefined mode exist for string:
%
s"
%
string
)
raise
Exception
(
"No predefined mode exist for string:
%
s"
%
string
)
if
orig_string
is
None
:
if
orig_string
is
None
:
# Build and cache the default mode
# Build and cache the default mode
...
@@ -374,12 +425,14 @@ def get_mode(orig_string):
...
@@ -374,12 +425,14 @@ def get_mode(orig_string):
return
ret
return
ret
def
get_default_mode
():
def
get_default_mode
():
return
get_mode
(
None
)
return
get_mode
(
None
)
# Removed: use config.mode instead.
# Removed: use config.mode instead.
#default_mode = config.mode
#default_mode = config.mode
def
register_mode
(
name
,
mode
):
def
register_mode
(
name
,
mode
):
"""Add a `Mode` which can be referred to by `name` in `function`."""
"""Add a `Mode` which can be referred to by `name` in `function`."""
if
name
in
predefined_modes
:
if
name
in
predefined_modes
:
...
...
theano/gof/opt.py
浏览文件 @
d4dfbf2a
...
@@ -3,21 +3,23 @@ Defines the base class for optimizations as well as a certain
...
@@ -3,21 +3,23 @@ Defines the base class for optimizations as well as a certain
amount of useful generic optimization tools.
amount of useful generic optimization tools.
"""
"""
import
copy
import
copy
,
logging
,
sys
,
time
import
logging
import
sys
import
time
import
numpy
import
numpy
import
graph
import
graph
from
env
import
InconsistencyError
from
env
import
InconsistencyError
import
op
import
utils
import
utils
import
unify
import
unify
import
toolbox
import
toolbox
import
op
import
theano
import
theano
from
theano
import
config
from
theano
import
config
from
theano.gof.python25
import
any
,
all
,
deque
from
theano.gof.python25
import
any
,
all
,
deque
from
theano.configparser
import
AddConfigVar
,
BoolParam
,
config
from
theano.configparser
import
AddConfigVar
,
BoolParam
#if sys.version_info[:2] >= (2,5):
#if sys.version_info[:2] >= (2,5):
# from collections import defaultdict
# from collections import defaultdict
...
@@ -39,9 +41,11 @@ import traceback
...
@@ -39,9 +41,11 @@ import traceback
_optimizer_idx
=
[
0
]
_optimizer_idx
=
[
0
]
def
_list_of_nodes
(
env
):
def
_list_of_nodes
(
env
):
return
list
(
graph
.
io_toposort
(
env
.
inputs
,
env
.
outputs
))
return
list
(
graph
.
io_toposort
(
env
.
inputs
,
env
.
outputs
))
class
Optimizer
(
object
):
class
Optimizer
(
object
):
"""WRITEME
"""WRITEME
An L{Optimizer} can be applied to an L{Env} to transform it.
An L{Optimizer} can be applied to an L{Env} to transform it.
...
@@ -91,26 +95,30 @@ class Optimizer(object):
...
@@ -91,26 +95,30 @@ class Optimizer(object):
def
print_summary
(
self
,
stream
=
sys
.
stdout
,
level
=
0
,
depth
=-
1
):
def
print_summary
(
self
,
stream
=
sys
.
stdout
,
level
=
0
,
depth
=-
1
):
name
=
getattr
(
self
,
'name'
,
None
)
name
=
getattr
(
self
,
'name'
,
None
)
print
>>
stream
,
"
%
s
%
s
%
s id=
%
i"
%
(
' '
*
level
,
self
.
__class__
.
__name__
,
print
>>
stream
,
"
%
s
%
s
%
s id=
%
i"
%
(
name
,
id
(
self
))
(
' '
*
level
),
self
.
__class__
.
__name__
,
name
,
id
(
self
))
class
FromFunctionOptimizer
(
Optimizer
):
class
FromFunctionOptimizer
(
Optimizer
):
"""WRITEME"""
"""WRITEME"""
def
__init__
(
self
,
fn
):
def
__init__
(
self
,
fn
):
self
.
apply
=
fn
self
.
apply
=
fn
def
add_requirements
(
self
,
env
):
def
add_requirements
(
self
,
env
):
# Added by default
# Added by default
#env.extend(toolbox.ReplaceValidate())
#env.extend(toolbox.ReplaceValidate())
pass
pass
def
print_summary
(
self
,
stream
=
sys
.
stdout
,
level
=
0
,
depth
=-
1
):
def
print_summary
(
self
,
stream
=
sys
.
stdout
,
level
=
0
,
depth
=-
1
):
print
>>
stream
,
"
%
s
%
s id=
%
i"
%
(
' '
*
level
,
print
>>
stream
,
"
%
s
%
s id=
%
i"
%
(
' '
*
level
,
str
(
self
.
apply
),
str
(
self
.
apply
),
id
(
self
))
id
(
self
))
def
__call__
(
self
,
*
args
,
**
kwargs
):
def
__call__
(
self
,
*
args
,
**
kwargs
):
return
self
.
fn
(
*
args
,
**
kwargs
)
return
self
.
fn
(
*
args
,
**
kwargs
)
def
optimizer
(
f
):
def
optimizer
(
f
):
"""decorator for FromFunctionOptimizer"""
"""decorator for FromFunctionOptimizer"""
rval
=
FromFunctionOptimizer
(
f
)
rval
=
FromFunctionOptimizer
(
f
)
...
@@ -118,7 +126,6 @@ def optimizer(f):
...
@@ -118,7 +126,6 @@ def optimizer(f):
return
rval
return
rval
class
SeqOptimizer
(
Optimizer
,
list
):
class
SeqOptimizer
(
Optimizer
,
list
):
#inherit from Optimizer first to get Optimizer.__hash__
#inherit from Optimizer first to get Optimizer.__hash__
"""WRITEME
"""WRITEME
...
@@ -129,7 +136,7 @@ class SeqOptimizer(Optimizer, list):
...
@@ -129,7 +136,7 @@ class SeqOptimizer(Optimizer, list):
def
warn
(
exc
,
self
,
optimizer
):
def
warn
(
exc
,
self
,
optimizer
):
"""Default failure_callback for SeqOptimizer
"""Default failure_callback for SeqOptimizer
"""
"""
_logger
.
error
(
"SeqOptimizer apply
%
s"
%
str
(
optimizer
))
_logger
.
error
(
"SeqOptimizer apply
%
s"
%
str
(
optimizer
))
_logger
.
error
(
"Traceback:"
)
_logger
.
error
(
"Traceback:"
)
_logger
.
error
(
traceback
.
format_exc
())
_logger
.
error
(
traceback
.
format_exc
())
if
config
.
on_opt_error
==
'raise'
:
if
config
.
on_opt_error
==
'raise'
:
...
@@ -146,14 +153,15 @@ class SeqOptimizer(Optimizer, list):
...
@@ -146,14 +153,15 @@ class SeqOptimizer(Optimizer, list):
"""WRITEME
"""WRITEME
Applies each L{Optimizer} in self in turn.
Applies each L{Optimizer} in self in turn.
"""
"""
l
=
[]
l
=
[]
nb_node_before
=
len
(
env
.
nodes
)
nb_node_before
=
len
(
env
.
nodes
)
for
optimizer
in
self
:
for
optimizer
in
self
:
try
:
try
:
t0
=
time
.
time
()
t0
=
time
.
time
()
optimizer
.
optimize
(
env
)
optimizer
.
optimize
(
env
)
l
.
append
(
float
(
time
.
time
()
-
t0
))
l
.
append
(
float
(
time
.
time
()
-
t0
))
except
AssertionError
:
# do not catch Assertion failures
except
AssertionError
:
# do not catch Assertion failures
raise
raise
except
Exception
,
e
:
except
Exception
,
e
:
if
self
.
failure_callback
:
if
self
.
failure_callback
:
...
@@ -192,7 +200,6 @@ class SeqOptimizer(Optimizer, list):
...
@@ -192,7 +200,6 @@ class SeqOptimizer(Optimizer, list):
#added to override the list's __neq__ implementation
#added to override the list's __neq__ implementation
return
id
(
self
)
!=
id
(
other
)
return
id
(
self
)
!=
id
(
other
)
def
__str__
(
self
):
def
__str__
(
self
):
return
"SeqOpt(
%
s)"
%
list
.
__str__
(
self
)
return
"SeqOpt(
%
s)"
%
list
.
__str__
(
self
)
...
@@ -201,14 +208,13 @@ class SeqOptimizer(Optimizer, list):
...
@@ -201,14 +208,13 @@ class SeqOptimizer(Optimizer, list):
def
print_summary
(
self
,
stream
=
sys
.
stdout
,
level
=
0
,
depth
=-
1
):
def
print_summary
(
self
,
stream
=
sys
.
stdout
,
level
=
0
,
depth
=-
1
):
name
=
getattr
(
self
,
'name'
,
None
)
name
=
getattr
(
self
,
'name'
,
None
)
print
>>
stream
,
"
%
s
%
s
%
s id=
%
i"
%
(
' '
*
level
,
self
.
__class__
.
__name__
,
name
,
id
(
self
))
print
>>
stream
,
"
%
s
%
s
%
s id=
%
i"
%
(
(
' '
*
level
),
self
.
__class__
.
__name__
,
name
,
id
(
self
))
# This way, -1 will do all depth
# This way, -1 will do all depth
if
depth
!=
0
:
if
depth
!=
0
:
depth
-=
1
depth
-=
1
for
opt
in
self
:
for
opt
in
self
:
opt
.
print_summary
(
stream
,
level
=
level
+
2
,
depth
=
depth
)
opt
.
print_summary
(
stream
,
level
=
(
level
+
2
),
depth
=
depth
)
class
_metadict
:
class
_metadict
:
...
@@ -219,17 +225,39 @@ class _metadict:
...
@@ -219,17 +225,39 @@ class _metadict:
def
__init__
(
self
):
def
__init__
(
self
):
self
.
d
=
{}
self
.
d
=
{}
self
.
l
=
[]
self
.
l
=
[]
def
__getitem__
(
self
,
item
):
def
__getitem__
(
self
,
item
):
return
self
.
get
(
item
,
None
)
return
self
.
get
(
item
,
None
)
def
__setitem__
(
self
,
item
,
value
):
def
__setitem__
(
self
,
item
,
value
):
try
:
try
:
self
.
d
[
item
]
=
value
self
.
d
[
item
]
=
value
except
Exception
:
except
Exception
:
for
i
,
(
key
,
val
)
in
enumerate
(
self
.
l
):
for
i
,
(
key
,
val
)
in
enumerate
(
self
.
l
):
if
key
==
item
:
if
key
==
item
:
self
.
l
[
i
]
=
(
item
,
value
)
self
.
l
[
i
]
=
(
item
,
value
)
return
return
self
.
l
.
append
((
item
,
value
))
self
.
l
.
append
((
item
,
value
))
def
__delitem__
(
self
,
item
):
if
item
in
self
.
d
:
del
self
.
d
[
item
]
else
:
for
i
,
(
key
,
val
)
in
enumerate
(
self
.
l
):
if
key
==
item
:
del
self
.
l
[
i
]
return
raise
KeyError
(
item
)
def
discard
(
self
,
item
):
if
item
in
self
.
d
:
del
self
.
d
[
item
]
else
:
for
i
,
(
key
,
val
)
in
enumerate
(
self
.
l
):
if
key
==
item
:
del
self
.
l
[
i
]
return
def
get
(
self
,
item
,
default
):
def
get
(
self
,
item
,
default
):
try
:
try
:
return
self
.
d
[
item
]
return
self
.
d
[
item
]
...
@@ -245,13 +273,148 @@ class _metadict:
...
@@ -245,13 +273,148 @@ class _metadict:
return
value
return
value
else
:
else
:
return
default
return
default
def
clear
(
self
):
def
clear
(
self
):
self
.
d
=
{}
self
.
d
=
{}
self
.
l
=
[]
self
.
l
=
[]
def
__str__
(
self
):
def
__str__
(
self
):
return
"(
%
s,
%
s)"
%
(
self
.
d
,
self
.
l
)
return
"(
%
s,
%
s)"
%
(
self
.
d
,
self
.
l
)
class MergeFeature(object):
    """
    Keeps track of variables in env that cannot be merged together.

    That way, the MergeOptimizer can remember the result of the last merge
    pass on the env.

    The feature does not modify the env itself: it only records, in
    `self.scheduled`, the replacements that would merge duplicated
    constants or nodes.
    """

    def on_attach(self, env):
        """
        Register this feature on `env` and process the nodes it contains.

        :param env: the env being attached to; it must not already have a
            `merge_feature` attribute.
        """
        assert not hasattr(env, 'merge_feature')
        env.merge_feature = self

        ## For constants
        # ids of the constants registered as the representative of their
        # signature
        self.seen_constants = set()
        # variable -> signature (for constants)
        self.const_sig = _metadict()
        # signature -> variable (for constants)
        self.const_sig_inv = _metadict()

        ## For all variables
        # Set of distinct (not mergeable) nodes
        self.nodes_seen = set()

        # Each element of scheduled is a list of list of (out, new_out) pairs.
        # Each list of pairs represent the substitution needed to replace all
        # the outputs of a node with the outputs of a replacement candidate.
        # Each node can have several candidates. For instance, if "node" has
        # 2 outputs, and there are 3 replacement candidates, we will have:
        # self.scheduled = [
        #    [[(node.out1, cand1.out1), (node.out2, cand1.out2)],
        #     [(node.out1, cand2.out1), (node.out2, cand2.out2)],
        #     [(node.out1, cand3.out1), (node.out2, cand3.out2)]]]
        self.scheduled = []

        # List of (node, candidate) pairs, where we tried to replace node by
        # candidate, but it failed. This is used to avoid infinite loops
        # during the replacement phase.
        self.blacklist = []

        for node in env.toposort():
            self.on_import(env, node)

    def on_change_input(self, env, node, i, r, new_r):
        """
        Re-examine `node` (and `new_r`, if it is a constant) after one of
        the inputs of `node` was changed from `r` to `new_r`.
        """
        # If inputs to node change, it is not guaranteed that it is distinct
        # from the other nodes in nodes_seen
        if node in self.nodes_seen:
            self.nodes_seen.discard(node)
            self.process_node(env, node)

        if isinstance(new_r, graph.Constant):
            self.process_constant(env, new_r)

    def on_import(self, env, node):
        """Process the constant inputs of `node`, then `node` itself."""
        for c in node.inputs:
            if isinstance(c, graph.Constant):
                self.process_constant(env, c)

        self.process_node(env, node)

    def on_prune(self, env, node):
        """
        Forget `node`, and forget each constant input of `node` that has
        at most one client left.
        """
        self.nodes_seen.discard(node)
        for c in node.inputs:
            if isinstance(c, graph.Constant) and (len(c.clients) <= 1):
                # This was the last node using this constant
                sig = self.const_sig[c]
                self.const_sig.discard(c)
                self.const_sig_inv.discard(sig)
                self.seen_constants.discard(id(c))

    def process_constant(self, env, c):
        """Check if a constant can be merged, and queue that replacement"""
        if id(c) in self.seen_constants:
            return
        sig = c.signature()
        other_c = self.const_sig_inv.get(sig, None)
        if other_c is not None:
            # multiple names will clobber each other..
            # we adopt convention to keep the last name
            if c.name:
                other_c.name = c.name
            # A constant has a single candidate, hence the single list of
            # a single pair.
            self.scheduled.append([[(c, other_c)]])
        else:
            # this is a new constant
            self.const_sig[c] = sig
            self.const_sig_inv[sig] = c
            self.seen_constants.add(id(c))

    def process_node(self, env, node):
        """
        Check if a node can be merged, and queue that replacement.

        The candidates are the nodes already in `self.nodes_seen` that are
        clients of the first input of `node`.  A candidate matches when it
        has the same inputs (compared by identity) and an equal op, and the
        (node, candidate) pair is not in `self.blacklist`.

        If at least one candidate matches, the list of substitutions (one
        list of (node_output, candidate_output) pairs per candidate) is
        appended to `self.scheduled`.  Otherwise `node` is recorded in
        `self.nodes_seen` as a distinct node.
        """
        if node in self.nodes_seen:
            return

        # These asserts ensure that the env has set the clients field properly.
        # The clients should at least contain `node` itself!
        if node.inputs:
            assert len(node.inputs[0].clients) > 0
            assert (node, 0) in node.inputs[0].clients
            merge_candidates = [c for (c, i) in node.inputs[0].clients
                                if c in self.nodes_seen]
        else:
            merge_candidates = []

        replacement_candidates = []
        for candidate in merge_candidates:
            if candidate is node:
                continue
            if len(node.inputs) != len(candidate.inputs):
                continue

            inputs_match = all(node_in is cand_in
                               for node_in, cand_in in zip(node.inputs,
                                                           candidate.inputs))
            if inputs_match and node.op == candidate.op:
                if (node, candidate) in self.blacklist:
                    # They were already tried, and there was an error
                    continue

                # Schedule transfer of clients from node to candidate.
                # The pairs are iterated below and then stored for later
                # use, so they have to be a real list and not a one-shot
                # iterator.
                pairs = list(zip(node.outputs, candidate.outputs))

                # transfer names
                for node_output, cand_output in pairs:
                    # clobber old name with new one
                    # it's arbitrary... one of the names has to go
                    if node_output.name:
                        cand_output.name = node_output.name

                replacement_candidates.append(pairs)

        if replacement_candidates:
            self.scheduled.append(replacement_candidates)
        else:
            self.nodes_seen.add(node)
class
MergeOptimizer
(
Optimizer
):
class
MergeOptimizer
(
Optimizer
):
"""
"""
Merges parts of the graph that are identical and redundant.
Merges parts of the graph that are identical and redundant.
...
@@ -264,94 +427,32 @@ class MergeOptimizer(Optimizer):
...
@@ -264,94 +427,32 @@ class MergeOptimizer(Optimizer):
The first step of merging is constant-merging, so that all clients of an
The first step of merging is constant-merging, so that all clients of an
int(1) for example, are transferred to a particular instance of int(1).
int(1) for example, are transferred to a particular instance of int(1).
"""
"""
def
__init__
(
self
,
skip_const_merge
=
False
):
self
.
skip_const_merge
=
skip_const_merge
def
add_requirements
(
self
,
env
):
def
add_requirements
(
self
,
env
):
# Added by default
# Added by default
#env.extend(toolbox.ReplaceValidate())
#env.extend(toolbox.ReplaceValidate())
pass
if
not
hasattr
(
env
,
'merge_feature'
):
env
.
extend
(
MergeFeature
())
def
apply_constant_merge
(
self
,
env
):
seen_constants
=
set
()
const_sig
=
_metadict
()
# variable -> variable.signature() (for constants)
const_sig_inv
=
_metadict
()
# signature -> variable (for constants)
for
node
in
_list_of_nodes
(
env
):
for
i
,
c
in
enumerate
([
r
for
r
in
node
.
inputs
if
isinstance
(
r
,
graph
.
Constant
)]):
if
id
(
c
)
in
seen_constants
:
continue
else
:
seen_constants
.
add
(
id
(
c
))
sig
=
c
.
signature
()
other_c
=
const_sig_inv
.
get
(
sig
,
None
)
if
other_c
is
not
None
:
# multiple names will clobber each other..
# we adopt convention to keep the last name
if
c
.
name
:
other_c
.
name
=
c
.
name
env
.
replace_validate
(
c
,
other_c
,
reason
=
'Constant Merge'
)
else
:
#this is a new constant
const_sig
[
c
]
=
sig
const_sig_inv
[
sig
]
=
c
def
apply_node_merge
(
self
,
env
):
# we clear the dicts because the Constants signatures are not necessarily hashable
# and it's more efficient to give them an integer like the other Variables
nodes_seen
=
{}
for
node_idx
,
node
in
enumerate
(
_list_of_nodes
(
env
)):
#
# these asserts ensure that the env has set the clients field properly the clients
# should at least contain `node` itself!
#
if
node
.
inputs
:
assert
len
(
node
.
inputs
[
0
]
.
clients
)
>
0
assert
(
node
,
0
)
in
node
.
inputs
[
0
]
.
clients
merge_candidates
=
[(
nodes_seen
[
c
],
c
)
for
(
c
,
i
)
in
node
.
inputs
[
0
]
.
clients
if
c
in
nodes_seen
]
else
:
merge_candidates
=
[]
merge_candidates
.
sort
()
nodes_seen
[
node
]
=
node_idx
#print 'NODE', node, merge_candidates, node.inputs[0].clients
for
candidate_idx
,
candidate
in
merge_candidates
:
if
len
(
node
.
inputs
)
!=
len
(
candidate
.
inputs
):
continue
inputs_match
=
all
(
node_in
is
cand_in
for
node_in
,
cand_in
in
zip
(
node
.
inputs
,
candidate
.
inputs
))
if
inputs_match
and
node
.
op
==
candidate
.
op
:
assert
node
is
not
candidate
#
#transfer clients from node to candidate
#
success
=
True
assert
len
(
node
.
outputs
)
==
len
(
candidate
.
outputs
)
pairs
=
zip
(
node
.
outputs
,
candidate
.
outputs
)
#transfer names
for
node_output
,
cand_output
in
pairs
:
#clobber old name with new one
#it's arbitrary... one of the names has to go
if
node_output
.
name
:
cand_output
.
name
=
node_output
.
name
try
:
env
.
replace_all_validate
(
pairs
,
reason
=
"Merge"
)
except
InconsistencyError
,
e
:
success
=
False
if
success
:
#break out of the candidate loop
break
else
:
#try the next candidate
pass
#TODO: Consider splitting this into a separate optimizer (SeqOptimizer)
def
apply
(
self
,
env
):
def
apply
(
self
,
env
):
if
not
self
.
skip_const_merge
:
# Constant and non-constant are now applied in the same phase.
self
.
apply_constant_merge
(
env
)
# I am not sure why, but it seems to be faster this way.
self
.
apply_node_merge
(
env
)
sched
=
env
.
merge_feature
.
scheduled
while
sched
:
pairs_list
=
sched
.
pop
()
success
=
True
for
pairs
in
pairs_list
:
try
:
env
.
replace_all_validate
(
pairs
,
'Merge'
)
except
InconsistencyError
:
success
=
False
env
.
merge_feature
.
blacklist
.
append
(
(
pairs
[
0
][
0
]
.
owner
,
pairs
[
0
][
1
]
.
owner
))
if
success
:
break
# clear blacklist
env
.
merge_feature
.
blacklist
=
[]
merge_optimizer
=
MergeOptimizer
()
merge_optimizer
=
MergeOptimizer
()
...
@@ -417,8 +518,9 @@ def pre_constant_merge(vars):
...
@@ -417,8 +518,9 @@ def pre_constant_merge(vars):
"""
"""
seen_var
=
set
()
seen_var
=
set
()
const_sig
=
{}
# variable -> variable.signature() (for constants)
# signature -> variable (for constants)
const_sig_inv
=
{}
# signature -> variable (for constants)
const_sig_inv
=
{}
def
recursive_merge
(
var
):
def
recursive_merge
(
var
):
if
var
in
seen_var
:
if
var
in
seen_var
:
return
var
return
var
...
@@ -434,12 +536,13 @@ def pre_constant_merge(vars):
...
@@ -434,12 +536,13 @@ def pre_constant_merge(vars):
const_sig_inv
[
sig
]
=
var
const_sig_inv
[
sig
]
=
var
return
var
return
var
if
var
.
owner
:
if
var
.
owner
:
for
idx
,
inp
in
enumerate
(
var
.
owner
.
inputs
):
for
idx
,
inp
in
enumerate
(
var
.
owner
.
inputs
):
var
.
owner
.
inputs
[
idx
]
=
recursive_merge
(
inp
)
var
.
owner
.
inputs
[
idx
]
=
recursive_merge
(
inp
)
return
var
return
var
return
map
(
recursive_merge
,
vars
)
return
map
(
recursive_merge
,
vars
)
########################
########################
### Local Optimizers ###
### Local Optimizers ###
########################
########################
...
@@ -463,25 +566,31 @@ class LocalOptimizer(object):
...
@@ -463,25 +566,31 @@ class LocalOptimizer(object):
Subclasses should implement this function so that it returns one of two
Subclasses should implement this function so that it returns one of two
kinds of things:
kinds of things:
- False to indicate that no optimization can be applied to this `node`; or
- False to indicate that no optimization can be applied to this `node`;
or
- <list of variables> to use in place of `node`'s outputs in the greater graph.
- <list of variables> to use in place of `node`'s outputs in the
greater graph.
:type node: an Apply instance
:type node: an Apply instance
"""
"""
raise
utils
.
MethodNotDefined
(
"transform"
,
type
(
self
),
self
.
__class__
.
__name__
)
raise
utils
.
MethodNotDefined
(
"transform"
,
type
(
self
),
self
.
__class__
.
__name__
)
def
add_requirements
(
self
,
env
):
def
add_requirements
(
self
,
env
):
"""If this local optimization wants to add some requirements to the env,
"""
This is the place to do it."""
If this local optimization wants to add some requirements to the env,
This is the place to do it.
"""
# Added by default
# Added by default
#env.extend(toolbox.ReplaceValidate())
#env.extend(toolbox.ReplaceValidate())
pass
pass
def
print_summary
(
self
,
stream
=
sys
.
stdout
,
level
=
0
,
depth
=-
1
):
def
print_summary
(
self
,
stream
=
sys
.
stdout
,
level
=
0
,
depth
=-
1
):
print
>>
stream
,
"
%
s
%
s id=
%
i"
%
(
' '
*
level
,
self
.
__class__
.
__name__
,
id
(
self
))
print
>>
stream
,
"
%
s
%
s id=
%
i"
%
(
(
' '
*
level
),
self
.
__class__
.
__name__
,
id
(
self
))
class
FromFunctionLocalOptimizer
(
LocalOptimizer
):
class
FromFunctionLocalOptimizer
(
LocalOptimizer
):
"""WRITEME"""
"""WRITEME"""
...
@@ -490,15 +599,21 @@ class FromFunctionLocalOptimizer(LocalOptimizer):
...
@@ -490,15 +599,21 @@ class FromFunctionLocalOptimizer(LocalOptimizer):
tracks
=
[]
tracks
=
[]
self
.
transform
=
fn
self
.
transform
=
fn
self
.
_tracks
=
tracks
self
.
_tracks
=
tracks
def
tracks
(
self
):
def
tracks
(
self
):
return
self
.
_tracks
return
self
.
_tracks
def
__str__
(
self
):
def
__str__
(
self
):
return
getattr
(
self
,
'__name__'
,
'<FromFunctionLocalOptimizer instance>'
)
return
getattr
(
self
,
'__name__'
,
'<FromFunctionLocalOptimizer instance>'
)
def
print_summary
(
self
,
stream
=
sys
.
stdout
,
level
=
0
,
depth
=-
1
):
def
print_summary
(
self
,
stream
=
sys
.
stdout
,
level
=
0
,
depth
=-
1
):
print
>>
stream
,
"
%
s
%
s id=
%
i"
%
(
' '
*
level
,
print
>>
stream
,
"
%
s
%
s id=
%
i"
%
(
' '
*
level
,
str
(
self
.
transform
),
str
(
self
.
transform
),
id
(
self
))
id
(
self
))
def
local_optimizer
(
*
tracks
):
def
local_optimizer
(
*
tracks
):
def
decorator
(
f
):
def
decorator
(
f
):
"""WRITEME"""
"""WRITEME"""
...
@@ -513,11 +628,15 @@ class LocalOptGroup(LocalOptimizer):
...
@@ -513,11 +628,15 @@ class LocalOptGroup(LocalOptimizer):
def
__init__
(
self
,
*
optimizers
):
def
__init__
(
self
,
*
optimizers
):
self
.
opts
=
optimizers
self
.
opts
=
optimizers
self
.
reentrant
=
any
(
getattr
(
opt
,
'reentrant'
,
True
)
for
opt
in
optimizers
)
self
.
reentrant
=
any
(
getattr
(
opt
,
'reentrant'
,
True
)
self
.
retains_inputs
=
all
(
getattr
(
opt
,
'retains_inputs'
,
False
)
for
opt
in
optimizers
)
for
opt
in
optimizers
)
self
.
retains_inputs
=
all
(
getattr
(
opt
,
'retains_inputs'
,
False
)
for
opt
in
optimizers
)
def
__str__
(
self
):
def
__str__
(
self
):
return
getattr
(
self
,
'__name__'
,
'<theano.gof.opt.LocalOptGroup instance>'
+
str
([
str
(
o
)
for
o
in
self
.
opts
]))
return
getattr
(
self
,
'__name__'
,
(
'<theano.gof.opt.LocalOptGroup instance>'
+
str
([
str
(
o
)
for
o
in
self
.
opts
])))
def
transform
(
self
,
node
):
def
transform
(
self
,
node
):
for
opt
in
self
.
opts
:
for
opt
in
self
.
opts
:
...
@@ -526,11 +645,12 @@ class LocalOptGroup(LocalOptimizer):
...
@@ -526,11 +645,12 @@ class LocalOptGroup(LocalOptimizer):
return
repl
return
repl
def
print_summary
(
self
,
stream
=
sys
.
stdout
,
level
=
0
,
depth
=-
1
):
def
print_summary
(
self
,
stream
=
sys
.
stdout
,
level
=
0
,
depth
=-
1
):
print
>>
stream
,
"
%
s
%
s id=
%
i"
%
(
' '
*
level
,
self
.
__class__
.
__name__
,
id
(
self
))
print
>>
stream
,
"
%
s
%
s id=
%
i"
%
(
(
' '
*
level
),
self
.
__class__
.
__name__
,
id
(
self
))
if
depth
!=
0
:
if
depth
!=
0
:
depth
-=
1
depth
-=
1
for
lopt
in
self
.
opts
:
for
lopt
in
self
.
opts
:
lopt
.
print_summary
(
stream
,
level
=
level
+
2
,
depth
=
depth
)
lopt
.
print_summary
(
stream
,
level
=
(
level
+
2
)
,
depth
=
depth
)
class
_LocalOpKeyOptGroup
(
LocalOptGroup
):
class
_LocalOpKeyOptGroup
(
LocalOptGroup
):
...
@@ -550,13 +670,16 @@ class OpSub(LocalOptimizer):
...
@@ -550,13 +670,16 @@ class OpSub(LocalOptimizer):
Replaces the application of a certain op by the application of
Replaces the application of a certain op by the application of
another op that take the same inputs as what they are replacing.
another op that take the same inputs as what they are replacing.
e.g. OpSub(add, sub) ==> add(div(x, y), add(y, x)) -> sub(div(x, y), sub(y, x))
e.g. OpSub(add, sub) ==>
add(div(x, y), add(y, x)) -> sub(div(x, y), sub(y, x))
"""
"""
reentrant
=
False
# an OpSub does not apply to the nodes it produces
# an OpSub does not apply to the nodes it produces
retains_inputs
=
True
# all the inputs of the original node are transferred to the outputs
reentrant
=
False
# all the inputs of the original node are transferred to the outputs
retains_inputs
=
True
def
__init__
(
self
,
op1
,
op2
,
transfer_tags
=
True
):
def
__init__
(
self
,
op1
,
op2
,
transfer_tags
=
True
):
"""
"""
op1.make_node and op2.make_node must take the same number of
op1.make_node and op2.make_node must take the same number of
inputs and have the same number of outputs.
inputs and have the same number of outputs.
...
@@ -611,7 +734,8 @@ class OpRemove(LocalOptimizer):
...
@@ -611,7 +734,8 @@ class OpRemove(LocalOptimizer):
return
"
%
s(x) -> x"
%
(
self
.
op
)
return
"
%
s(x) -> x"
%
(
self
.
op
)
def
print_summary
(
self
,
stream
=
sys
.
stdout
,
level
=
0
,
depth
=-
1
):
def
print_summary
(
self
,
stream
=
sys
.
stdout
,
level
=
0
,
depth
=-
1
):
print
>>
stream
,
"
%
s
%
s(
%
s) id=
%
i"
%
(
' '
*
level
,
print
>>
stream
,
"
%
s
%
s(
%
s) id=
%
i"
%
(
' '
*
level
,
self
.
__class__
.
__name__
,
self
.
__class__
.
__name__
,
str
(
self
.
op
),
str
(
self
.
op
),
id
(
self
))
id
(
self
))
...
@@ -662,12 +786,12 @@ class PatternSub(LocalOptimizer):
...
@@ -662,12 +786,12 @@ class PatternSub(LocalOptimizer):
PatternSub((subtract, (add, 'x', 'y'), 'y'), 'x')
PatternSub((subtract, (add, 'x', 'y'), 'y'), 'x')
PatternSub((power, 'x', Constant(double, 2.0)), (square, 'x'))
PatternSub((power, 'x', Constant(double, 2.0)), (square, 'x'))
PatternSub((boggle, {'pattern': 'x',
PatternSub((boggle, {'pattern': 'x',
'constraint': lambda expr: expr.type == scrabble}),
'constraint': lambda expr: expr.type == scrabble}),
(scrabble, 'x'))
(scrabble, 'x'))
"""
"""
def
__init__
(
self
,
in_pattern
,
out_pattern
,
allow_multiple_clients
=
False
,
def
__init__
(
self
,
in_pattern
,
out_pattern
,
allow_multiple_clients
=
False
,
skip_identities_fn
=
None
,
name
=
None
,
pdb
=
False
):
skip_identities_fn
=
None
,
name
=
None
,
pdb
=
False
):
"""
"""
Creates a PatternSub that replaces occurrences of
Creates a PatternSub that replaces occurrences of
in_pattern by occurrences of out_pattern.
in_pattern by occurrences of out_pattern.
...
@@ -677,7 +801,8 @@ class PatternSub(LocalOptimizer):
...
@@ -677,7 +801,8 @@ class PatternSub(LocalOptimizer):
:param allow_multiple_clients: if False, the pattern matching will fail
:param allow_multiple_clients: if False, the pattern matching will fail
if one of the subpatterns has more than
if one of the subpatterns has more than
one client.
one client.
:param pdb: if True, we invoke pdb when the first node in the pattern match.
:param pdb: if True, we invoke pdb when the first node in the
pattern match.
"""
"""
self
.
in_pattern
=
in_pattern
self
.
in_pattern
=
in_pattern
self
.
out_pattern
=
out_pattern
self
.
out_pattern
=
out_pattern
...
@@ -686,8 +811,11 @@ class PatternSub(LocalOptimizer):
...
@@ -686,8 +811,11 @@ class PatternSub(LocalOptimizer):
elif
isinstance
(
in_pattern
,
dict
):
elif
isinstance
(
in_pattern
,
dict
):
self
.
op
=
self
.
in_pattern
[
'pattern'
][
0
]
self
.
op
=
self
.
in_pattern
[
'pattern'
][
0
]
else
:
else
:
raise
TypeError
(
"The pattern to search for must start with a specific Op instance."
)
raise
TypeError
(
"The pattern to search for must start with "
self
.
__doc__
=
self
.
__class__
.
__doc__
+
"
\n\n
This instance does: "
+
str
(
self
)
+
"
\n
"
"a specific Op instance."
)
self
.
__doc__
=
(
self
.
__class__
.
__doc__
+
"
\n\n
This instance does: "
+
str
(
self
)
+
"
\n
"
)
self
.
allow_multiple_clients
=
allow_multiple_clients
self
.
allow_multiple_clients
=
allow_multiple_clients
self
.
skip_identities_fn
=
skip_identities_fn
self
.
skip_identities_fn
=
skip_identities_fn
if
name
:
if
name
:
...
@@ -722,7 +850,7 @@ class PatternSub(LocalOptimizer):
...
@@ -722,7 +850,7 @@ class PatternSub(LocalOptimizer):
if
node
.
op
!=
self
.
op
:
if
node
.
op
!=
self
.
op
:
return
False
return
False
def
match
(
pattern
,
expr
,
u
,
allow_multiple_clients
=
False
,
pdb
=
False
):
def
match
(
pattern
,
expr
,
u
,
allow_multiple_clients
=
False
,
pdb
=
False
):
def
retry_with_equiv
():
def
retry_with_equiv
():
expr_equiv
=
self
.
skip_identities
(
expr
)
expr_equiv
=
self
.
skip_identities
(
expr
)
if
expr_equiv
is
None
:
if
expr_equiv
is
None
:
...
@@ -735,7 +863,9 @@ class PatternSub(LocalOptimizer):
...
@@ -735,7 +863,9 @@ class PatternSub(LocalOptimizer):
if
isinstance
(
pattern
,
(
list
,
tuple
)):
if
isinstance
(
pattern
,
(
list
,
tuple
)):
if
expr
.
owner
is
None
:
if
expr
.
owner
is
None
:
return
False
return
False
if
not
(
expr
.
owner
.
op
==
pattern
[
0
])
or
(
not
allow_multiple_clients
and
len
(
expr
.
clients
)
>
1
):
if
(
not
(
expr
.
owner
.
op
==
pattern
[
0
])
or
(
not
allow_multiple_clients
and
len
(
expr
.
clients
)
>
1
)):
return
retry_with_equiv
()
return
retry_with_equiv
()
if
len
(
pattern
)
-
1
!=
len
(
expr
.
owner
.
inputs
):
if
len
(
pattern
)
-
1
!=
len
(
expr
.
owner
.
inputs
):
return
retry_with_equiv
()
return
retry_with_equiv
()
...
@@ -747,10 +877,14 @@ class PatternSub(LocalOptimizer):
...
@@ -747,10 +877,14 @@ class PatternSub(LocalOptimizer):
try
:
try
:
real_pattern
=
pattern
[
'pattern'
]
real_pattern
=
pattern
[
'pattern'
]
except
KeyError
:
except
KeyError
:
raise
KeyError
(
"Malformed pattern:
%
s (expected key 'pattern')"
%
pattern
)
raise
KeyError
(
"Malformed pattern:
%
s (expected key 'pattern')"
%
pattern
)
constraint
=
pattern
.
get
(
'constraint'
,
lambda
expr
:
True
)
constraint
=
pattern
.
get
(
'constraint'
,
lambda
expr
:
True
)
if
constraint
(
expr
):
if
constraint
(
expr
):
return
match
(
real_pattern
,
expr
,
u
,
pattern
.
get
(
'allow_multiple_clients'
,
allow_multiple_clients
))
return
match
(
real_pattern
,
expr
,
u
,
pattern
.
get
(
'allow_multiple_clients'
,
allow_multiple_clients
))
else
:
else
:
return
retry_with_equiv
()
return
retry_with_equiv
()
elif
isinstance
(
pattern
,
basestring
):
elif
isinstance
(
pattern
,
basestring
):
...
@@ -759,17 +893,22 @@ class PatternSub(LocalOptimizer):
...
@@ -759,17 +893,22 @@ class PatternSub(LocalOptimizer):
return
retry_with_equiv
()
return
retry_with_equiv
()
else
:
else
:
u
=
u
.
merge
(
expr
,
v
)
u
=
u
.
merge
(
expr
,
v
)
elif
isinstance
(
pattern
,
(
int
,
float
))
and
isinstance
(
expr
,
graph
.
Constant
):
elif
(
isinstance
(
pattern
,
(
int
,
float
))
if
numpy
.
all
(
theano
.
tensor
.
constant
(
pattern
)
.
value
==
expr
.
value
):
and
isinstance
(
expr
,
graph
.
Constant
)):
if
numpy
.
all
(
theano
.
tensor
.
constant
(
pattern
)
.
value
==
expr
.
value
):
return
u
return
u
else
:
else
:
return
retry_with_equiv
()
return
retry_with_equiv
()
elif
isinstance
(
pattern
,
graph
.
Constant
)
and
isinstance
(
expr
,
graph
.
Constant
)
and
pattern
.
equals
(
expr
):
elif
(
isinstance
(
pattern
,
graph
.
Constant
)
and
isinstance
(
expr
,
graph
.
Constant
)
and
pattern
.
equals
(
expr
)):
return
u
return
u
else
:
else
:
return
retry_with_equiv
()
return
retry_with_equiv
()
if
pdb
:
if
pdb
:
import
pdb
;
pdb
.
set_trace
()
import
pdb
pdb
.
set_trace
()
return
u
return
u
def
build
(
pattern
,
u
):
def
build
(
pattern
,
u
):
...
@@ -778,11 +917,12 @@ class PatternSub(LocalOptimizer):
...
@@ -778,11 +917,12 @@ class PatternSub(LocalOptimizer):
return
pattern
[
0
](
*
args
)
return
pattern
[
0
](
*
args
)
elif
isinstance
(
pattern
,
basestring
):
elif
isinstance
(
pattern
,
basestring
):
return
u
[
unify
.
Var
(
pattern
)]
return
u
[
unify
.
Var
(
pattern
)]
elif
isinstance
(
pattern
,
(
int
,
float
)):
elif
isinstance
(
pattern
,
(
int
,
float
)):
return
pattern
return
pattern
else
:
else
:
return
pattern
.
clone
()
return
pattern
.
clone
()
u
=
match
(
self
.
in_pattern
,
node
.
out
,
unify
.
Unification
(),
True
,
self
.
pdb
)
u
=
match
(
self
.
in_pattern
,
node
.
out
,
unify
.
Unification
(),
True
,
self
.
pdb
)
if
u
:
if
u
:
p
=
self
.
out_pattern
p
=
self
.
out_pattern
new
=
build
(
p
,
u
)
new
=
build
(
p
,
u
)
...
@@ -792,23 +932,31 @@ class PatternSub(LocalOptimizer):
...
@@ -792,23 +932,31 @@ class PatternSub(LocalOptimizer):
return
False
return
False
def
__str__
(
self
):
def
__str__
(
self
):
if
getattr
(
self
,
'__name__'
,
None
):
if
getattr
(
self
,
'__name__'
,
None
):
return
self
.
__name__
return
self
.
__name__
def
pattern_to_str
(
pattern
):
def
pattern_to_str
(
pattern
):
if
isinstance
(
pattern
,
(
list
,
tuple
)):
if
isinstance
(
pattern
,
(
list
,
tuple
)):
return
"
%
s(
%
s)"
%
(
str
(
pattern
[
0
]),
", "
.
join
([
pattern_to_str
(
p
)
for
p
in
pattern
[
1
:]]))
return
"
%
s(
%
s)"
%
(
str
(
pattern
[
0
]),
", "
.
join
([
pattern_to_str
(
p
)
for
p
in
pattern
[
1
:]]))
elif
isinstance
(
pattern
,
dict
):
elif
isinstance
(
pattern
,
dict
):
return
"
%
s subject to
%
s"
%
(
pattern_to_str
(
pattern
[
'pattern'
]),
str
(
pattern
.
get
(
'constraint'
,
'no conditions'
)))
return
"
%
s subject to
%
s"
%
(
pattern_to_str
(
pattern
[
'pattern'
]),
str
(
pattern
.
get
(
'constraint'
,
'no conditions'
)))
else
:
else
:
return
str
(
pattern
)
return
str
(
pattern
)
return
"
%
s ->
%
s"
%
(
pattern_to_str
(
self
.
in_pattern
),
pattern_to_str
(
self
.
out_pattern
))
return
"
%
s ->
%
s"
%
(
pattern_to_str
(
self
.
in_pattern
),
pattern_to_str
(
self
.
out_pattern
))
def
__repr__
(
self
):
def
__repr__
(
self
):
return
str
(
self
)
return
str
(
self
)
def
print_summary
(
self
,
stream
=
sys
.
stdout
,
level
=
0
,
depth
=-
1
):
def
print_summary
(
self
,
stream
=
sys
.
stdout
,
level
=
0
,
depth
=-
1
):
name
=
getattr
(
self
,
'__name__'
,
getattr
(
self
,
'name'
,
None
))
name
=
getattr
(
self
,
'__name__'
,
getattr
(
self
,
'name'
,
None
))
print
>>
stream
,
"
%
s
%
s
%
s(
%
s,
%
s) id=
%
i"
%
(
' '
*
level
,
print
>>
stream
,
"
%
s
%
s
%
s(
%
s,
%
s) id=
%
i"
%
(
' '
*
level
,
self
.
__class__
.
__name__
,
self
.
__class__
.
__name__
,
name
,
name
,
str
(
self
.
in_pattern
),
str
(
self
.
in_pattern
),
...
@@ -836,37 +984,48 @@ class NavigatorOptimizer(Optimizer):
...
@@ -836,37 +984,48 @@ class NavigatorOptimizer(Optimizer):
_logger
.
error
(
traceback
.
format_exc
())
_logger
.
error
(
traceback
.
format_exc
())
if
isinstance
(
exc
,
AssertionError
)
or
config
.
on_opt_error
==
'raise'
:
if
isinstance
(
exc
,
AssertionError
)
or
config
.
on_opt_error
==
'raise'
:
raise
exc
raise
exc
@staticmethod
@staticmethod
def
warn_inplace
(
exc
,
nav
,
repl_pairs
,
local_opt
):
def
warn_inplace
(
exc
,
nav
,
repl_pairs
,
local_opt
):
"""failure_callback for NavigatorOptimizer: ignore InconsistencyErrors, print traceback
"""failure_callback for NavigatorOptimizer
ignore InconsistencyErrors, print traceback
"""
"""
if
isinstance
(
exc
,
InconsistencyError
):
if
isinstance
(
exc
,
InconsistencyError
):
return
return
return
NavigatorOptimizer
.
warn
(
exc
,
nav
,
repl_pairs
,
local_opt
)
return
NavigatorOptimizer
.
warn
(
exc
,
nav
,
repl_pairs
,
local_opt
)
@staticmethod
@staticmethod
def
warn_ignore
(
exc
,
nav
,
repl_pairs
,
local_opt
):
def
warn_ignore
(
exc
,
nav
,
repl_pairs
,
local_opt
):
"""failure_callback for NavigatorOptimizer: ignore all errors
"""failure_callback for NavigatorOptimizer: ignore all errors
"""
"""
pass
pass
def
__init__
(
self
,
local_opt
,
ignore_newtrees
=
'auto'
,
failure_callback
=
None
):
def
__init__
(
self
,
local_opt
,
ignore_newtrees
=
'auto'
,
failure_callback
=
None
):
"""
"""
:param local_opt: a LocalOptimizer to apply over a Env (or None is Ok too).
:param local_opt: a LocalOptimizer to apply over a Env
(or None is Ok too).
:param ignore_newtrees:
:param ignore_newtrees:
- True: new subgraphs returned by an optimization is not a candidate for optimization
- True: new subgraphs returned by an optimization is not a
- False: new subgraphs returned by an optimization is a candidate for optimization
candidate for optimization
- 'auto': let the local_opt set this parameter via its 'reentrant' attribute.
- False: new subgraphs returned by an optimization is a candidate
for optimization
- 'auto': let the local_opt set this parameter via its 'reentrant'
attribute.
:param failure_callback:
:param failure_callback:
a function that takes (exception, navigator, [(old, new),
a function that takes (exception, navigator, [(old, new),
(old,new),...]) and we call it if there's an exception.
(old,new),...]) and we call it if there's an exception.
If the trouble is from local_opt.transform(), the new variables will be 'None'.
If the trouble is from local_opt.transform(), the new variables
will be 'None'.
If the trouble is from validation (the new types don't match for
If the trouble is from validation (the new types don't match for
example) then the new variables will be the ones created by
example) then the new variables will be the ones created by
transform().
transform().
If this parameter is None, then exceptions are not caught here (raised normally).
If this parameter is None, then exceptions are not caught here
(raised normally).
"""
"""
self
.
local_opt
=
local_opt
self
.
local_opt
=
local_opt
if
ignore_newtrees
==
'auto'
:
if
ignore_newtrees
==
'auto'
:
...
@@ -875,15 +1034,19 @@ class NavigatorOptimizer(Optimizer):
...
@@ -875,15 +1034,19 @@ class NavigatorOptimizer(Optimizer):
self
.
ignore_newtrees
=
ignore_newtrees
self
.
ignore_newtrees
=
ignore_newtrees
self
.
failure_callback
=
failure_callback
self
.
failure_callback
=
failure_callback
def
attach_updater
(
self
,
env
,
importer
,
pruner
,
chin
=
None
):
def
attach_updater
(
self
,
env
,
importer
,
pruner
,
chin
=
None
):
"""Install some Env listeners to help the navigator deal with the ignore_trees-related functionality.
"""
Install some Env listeners to help the navigator deal with the
ignore_trees-related functionality.
:param importer: function that will be called whenever when optimizations add stuff to the graph.
:param importer: function that will be called whenever when
:param pruner: function to be called when optimizations remove stuff from graph.
optimizations add stuff to the graph.
:param pruner: function to be called when optimizations remove stuff
from graph.
:param chin: "on change input" called whenever an node's inputs change.
:param chin: "on change input" called whenever an node's inputs change.
:returns: The Env plugin that handles the three tasks.
Keep this around so that you can detach later!
:returns: The Env plugin that handles the three tasks.
Keep this around so that you can detach later!
"""
"""
if
self
.
ignore_newtrees
:
if
self
.
ignore_newtrees
:
importer
=
None
importer
=
None
...
@@ -916,21 +1079,22 @@ class NavigatorOptimizer(Optimizer):
...
@@ -916,21 +1079,22 @@ class NavigatorOptimizer(Optimizer):
if
u
is
not
None
:
if
u
is
not
None
:
env
.
remove_feature
(
u
)
env
.
remove_feature
(
u
)
def
process_node
(
self
,
env
,
node
,
lopt
=
None
):
def
process_node
(
self
,
env
,
node
,
lopt
=
None
):
"""
"""
This function will use `lopt` to `transform` the `node`. The `transform` method will
This function will use `lopt` to `transform` the `node`. The
return either False or a list of Variables that are intended to replace `node.outputs`.
`transform` method will return either False or a list of Variables
that are intended to replace `node.outputs`.
If the env accepts the replacement, then the optimization is
successful, and this
If the env accepts the replacement, then the optimization is
function returns True.
successful, and this
function returns True.
If there are no replacement candidates or the env rejects the
replacements, this
If there are no replacement candidates or the env rejects the
function returns False.
replacements, this
function returns False.
:param env: an Env
:param env: an Env
:param node: an Apply instance in `env`
:param node: an Apply instance in `env`
:param lopt: a LocalOptimizer instance that may have a better idea for
how to compute
:param lopt: a LocalOptimizer instance that may have a better idea for
node's outputs.
how to compute
node's outputs.
:rtype: Bool
:rtype: Bool
:returns: True iff the `node`'s outputs were replaced in the `env`.
:returns: True iff the `node`'s outputs were replaced in the `env`.
...
@@ -940,16 +1104,19 @@ class NavigatorOptimizer(Optimizer):
...
@@ -940,16 +1104,19 @@ class NavigatorOptimizer(Optimizer):
replacements
=
lopt
.
transform
(
node
)
replacements
=
lopt
.
transform
(
node
)
except
Exception
,
e
:
except
Exception
,
e
:
if
self
.
failure_callback
is
not
None
:
if
self
.
failure_callback
is
not
None
:
self
.
failure_callback
(
e
,
self
,
[(
x
,
None
)
for
x
in
node
.
outputs
],
lopt
)
self
.
failure_callback
(
e
,
self
,
[(
x
,
None
)
for
x
in
node
.
outputs
],
lopt
)
return
False
return
False
else
:
else
:
raise
raise
if
replacements
is
False
or
replacements
is
None
:
if
replacements
is
False
or
replacements
is
None
:
return
False
return
False
if
not
isinstance
(
replacements
,
(
tuple
,
list
)):
if
not
isinstance
(
replacements
,
(
tuple
,
list
)):
raise
TypeError
(
'Optimizer
%
s gave wrong type of replacement. Expected list or tuple.'
%
lopt
)
raise
TypeError
(
'Optimizer
%
s gave wrong type of replacement. '
'Expected list or tuple.'
%
lopt
)
if
len
(
node
.
outputs
)
!=
len
(
replacements
):
if
len
(
node
.
outputs
)
!=
len
(
replacements
):
raise
ValueError
(
'Optimizer
%
s gave wrong number of replacements'
%
lopt
)
raise
ValueError
(
'Optimizer
%
s gave wrong number of replacements'
%
lopt
)
# If an output would be replaced by itself, no need to perform
# If an output would be replaced by itself, no need to perform
# the replacement
# the replacement
repl_pairs
=
[(
r
,
rnew
)
for
r
,
rnew
in
zip
(
node
.
outputs
,
replacements
)
repl_pairs
=
[(
r
,
rnew
)
for
r
,
rnew
in
zip
(
node
.
outputs
,
replacements
)
...
@@ -962,8 +1129,8 @@ class NavigatorOptimizer(Optimizer):
...
@@ -962,8 +1129,8 @@ class NavigatorOptimizer(Optimizer):
except
Exception
,
e
:
except
Exception
,
e
:
# This means the replacements were rejected by the env.
# This means the replacements were rejected by the env.
#
#
# This is not supposed to happen. The default failure_callback
will print a
# This is not supposed to happen. The default failure_callback
# traceback as a warning.
#
will print a
traceback as a warning.
if
self
.
failure_callback
is
not
None
:
if
self
.
failure_callback
is
not
None
:
self
.
failure_callback
(
e
,
self
,
repl_pairs
,
lopt
)
self
.
failure_callback
(
e
,
self
,
repl_pairs
,
lopt
)
return
False
return
False
...
@@ -978,26 +1145,33 @@ class NavigatorOptimizer(Optimizer):
...
@@ -978,26 +1145,33 @@ class NavigatorOptimizer(Optimizer):
self
.
local_opt
.
add_requirements
(
env
)
self
.
local_opt
.
add_requirements
(
env
)
def
print_summary
(
self
,
stream
=
sys
.
stdout
,
level
=
0
,
depth
=-
1
):
def
print_summary
(
self
,
stream
=
sys
.
stdout
,
level
=
0
,
depth
=-
1
):
print
>>
stream
,
"
%
s
%
s (
%
i)"
%
(
' '
*
level
,
self
.
__class__
.
__name__
,
id
(
self
))
print
>>
stream
,
"
%
s
%
s (
%
i)"
%
(
(
' '
*
level
),
self
.
__class__
.
__name__
,
id
(
self
))
if
depth
!=
0
:
if
depth
!=
0
:
self
.
local_opt
.
print_summary
(
stream
,
level
=
level
+
2
,
depth
=
depth
-
1
)
self
.
local_opt
.
print_summary
(
stream
,
level
=
(
level
+
2
),
depth
=
(
depth
-
1
))
class
TopoOptimizer
(
NavigatorOptimizer
):
class
TopoOptimizer
(
NavigatorOptimizer
):
"""WRITEME"""
"""WRITEME"""
def
__init__
(
self
,
local_opt
,
order
=
'in_to_out'
,
ignore_newtrees
=
False
,
failure_callback
=
None
):
def
__init__
(
self
,
local_opt
,
order
=
'in_to_out'
,
ignore_newtrees
=
False
,
failure_callback
=
None
):
if
order
not
in
[
'out_to_in'
,
'in_to_out'
]:
if
order
not
in
[
'out_to_in'
,
'in_to_out'
]:
raise
ValueError
(
"order must be 'out_to_in' or 'in_to_out'"
)
raise
ValueError
(
"order must be 'out_to_in' or 'in_to_out'"
)
self
.
order
=
order
self
.
order
=
order
NavigatorOptimizer
.
__init__
(
self
,
local_opt
,
ignore_newtrees
,
failure_callback
)
NavigatorOptimizer
.
__init__
(
self
,
local_opt
,
ignore_newtrees
,
failure_callback
)
def
apply
(
self
,
env
,
start_from
=
None
):
def
apply
(
self
,
env
,
start_from
=
None
):
if
start_from
is
None
:
start_from
=
env
.
outputs
if
start_from
is
None
:
start_from
=
env
.
outputs
q
=
deque
(
graph
.
io_toposort
(
env
.
inputs
,
start_from
))
q
=
deque
(
graph
.
io_toposort
(
env
.
inputs
,
start_from
))
def
importer
(
node
):
def
importer
(
node
):
if
node
is
not
current_node
:
if
node
is
not
current_node
:
q
.
append
(
node
)
q
.
append
(
node
)
def
pruner
(
node
):
def
pruner
(
node
):
if
node
is
not
current_node
:
if
node
is
not
current_node
:
try
:
try
:
...
@@ -1020,14 +1194,16 @@ class TopoOptimizer(NavigatorOptimizer):
...
@@ -1020,14 +1194,16 @@ class TopoOptimizer(NavigatorOptimizer):
self
.
detach_updater
(
env
,
u
)
self
.
detach_updater
(
env
,
u
)
class
OpKeyOptimizer
(
NavigatorOptimizer
):
class
OpKeyOptimizer
(
NavigatorOptimizer
):
"""WRITEME"""
"""WRITEME"""
def
__init__
(
self
,
local_opt
,
ignore_newtrees
=
False
,
failure_callback
=
None
):
def
__init__
(
self
,
local_opt
,
ignore_newtrees
=
False
,
failure_callback
=
None
):
if
not
hasattr
(
local_opt
,
'op_key'
):
if
not
hasattr
(
local_opt
,
'op_key'
):
raise
TypeError
(
"LocalOptimizer for OpKeyOptimizer must have an 'op_key' method."
)
raise
TypeError
(
"LocalOptimizer for OpKeyOptimizer must have "
NavigatorOptimizer
.
__init__
(
self
,
local_opt
,
ignore_newtrees
,
failure_callback
)
"an 'op_key' method."
)
NavigatorOptimizer
.
__init__
(
self
,
local_opt
,
ignore_newtrees
,
failure_callback
)
def
apply
(
self
,
env
):
def
apply
(
self
,
env
):
op
=
self
.
local_opt
.
op_key
()
op
=
self
.
local_opt
.
op_key
()
...
@@ -1035,9 +1211,12 @@ class OpKeyOptimizer(NavigatorOptimizer):
...
@@ -1035,9 +1211,12 @@ class OpKeyOptimizer(NavigatorOptimizer):
q
=
reduce
(
list
.
__iadd__
,
map
(
env
.
get_nodes
,
op
))
q
=
reduce
(
list
.
__iadd__
,
map
(
env
.
get_nodes
,
op
))
else
:
else
:
q
=
list
(
env
.
get_nodes
(
op
))
q
=
list
(
env
.
get_nodes
(
op
))
def
importer
(
node
):
def
importer
(
node
):
if
node
is
not
current_node
:
if
node
is
not
current_node
:
if
node
.
op
==
op
:
q
.
append
(
node
)
if
node
.
op
==
op
:
q
.
append
(
node
)
def
pruner
(
node
):
def
pruner
(
node
):
if
node
is
not
current_node
and
node
.
op
==
op
:
if
node
is
not
current_node
and
node
.
op
==
op
:
try
:
try
:
...
@@ -1065,7 +1244,6 @@ class OpKeyOptimizer(NavigatorOptimizer):
...
@@ -1065,7 +1244,6 @@ class OpKeyOptimizer(NavigatorOptimizer):
env
.
extend
(
toolbox
.
NodeFinder
())
env
.
extend
(
toolbox
.
NodeFinder
())
class
ChangeTracker
:
class
ChangeTracker
:
def
__init__
(
self
):
def
__init__
(
self
):
self
.
changed
=
False
self
.
changed
=
False
...
@@ -1082,17 +1260,19 @@ class ChangeTracker:
...
@@ -1082,17 +1260,19 @@ class ChangeTracker:
def
on_attach
(
self
,
env
):
def
on_attach
(
self
,
env
):
env
.
change_tracker
=
self
env
.
change_tracker
=
self
class
EquilibriumOptimizer
(
NavigatorOptimizer
):
class
EquilibriumOptimizer
(
NavigatorOptimizer
):
def
__init__
(
self
,
def
__init__
(
self
,
optimizers
,
optimizers
,
failure_callback
=
None
,
failure_callback
=
None
,
max_depth
=
None
,
max_depth
=
None
,
max_use_ratio
=
None
):
max_use_ratio
=
None
):
"""
"""
:param optimizers: list or set of local or global optimizations to
apply until
:param optimizers: list or set of local or global optimizations to
equilibrium.
apply until
equilibrium.
:param max_use_ratio: each optimizer can be applied at most (size of graph * this number)
:param max_use_ratio: each optimizer can be applied at most
(size of graph * this number) times
:param max_depth: TODO what does this do? (EquilibriumDB sets it to 5)
:param max_depth: TODO what does this do? (EquilibriumDB sets it to 5)
...
@@ -1100,8 +1280,8 @@ class EquilibriumOptimizer(NavigatorOptimizer):
...
@@ -1100,8 +1280,8 @@ class EquilibriumOptimizer(NavigatorOptimizer):
super
(
EquilibriumOptimizer
,
self
)
.
__init__
(
super
(
EquilibriumOptimizer
,
self
)
.
__init__
(
None
,
None
,
ignore_newtrees
=
True
,
ignore_newtrees
=
True
,
failure_callback
=
failure_callback
)
failure_callback
=
failure_callback
)
self
.
local_optimizers
=
[]
self
.
local_optimizers
=
[]
self
.
global_optimizers
=
[]
self
.
global_optimizers
=
[]
...
@@ -1112,13 +1292,18 @@ class EquilibriumOptimizer(NavigatorOptimizer):
...
@@ -1112,13 +1292,18 @@ class EquilibriumOptimizer(NavigatorOptimizer):
self
.
global_optimizers
.
append
(
opt
)
self
.
global_optimizers
.
append
(
opt
)
self
.
max_depth
=
max_depth
self
.
max_depth
=
max_depth
self
.
max_use_ratio
=
max_use_ratio
self
.
max_use_ratio
=
max_use_ratio
assert
self
.
max_use_ratio
is
not
None
,
'max_use_ratio has to be a number'
assert
self
.
max_use_ratio
is
not
None
,
(
'max_use_ratio has to be a number'
)
def
add_requirements
(
self
,
env
):
def
add_requirements
(
self
,
env
):
super
(
EquilibriumOptimizer
,
self
)
.
add_requirements
(
env
)
super
(
EquilibriumOptimizer
,
self
)
.
add_requirements
(
env
)
env
.
extend
(
ChangeTracker
())
env
.
extend
(
ChangeTracker
())
for
opt
in
self
.
local_optimizers
:
opt
.
add_requirements
(
env
)
for
opt
in
self
.
global_optimizers
:
opt
.
add_requirements
(
env
)
def
apply
(
self
,
env
,
start_from
=
None
):
def
apply
(
self
,
env
,
start_from
=
None
):
if
start_from
is
None
:
if
start_from
is
None
:
start_from
=
env
.
outputs
start_from
=
env
.
outputs
changed
=
True
changed
=
True
...
@@ -1153,9 +1338,11 @@ class EquilibriumOptimizer(NavigatorOptimizer):
...
@@ -1153,9 +1338,11 @@ class EquilibriumOptimizer(NavigatorOptimizer):
nb_nodes
.
append
(
len
(
q
))
nb_nodes
.
append
(
len
(
q
))
max_nb_nodes
=
max
(
max_nb_nodes
,
len
(
q
))
max_nb_nodes
=
max
(
max_nb_nodes
,
len
(
q
))
max_use
=
max_nb_nodes
*
self
.
max_use_ratio
max_use
=
max_nb_nodes
*
self
.
max_use_ratio
def
importer
(
node
):
def
importer
(
node
):
if
node
is
not
current_node
:
if
node
is
not
current_node
:
q
.
append
(
node
)
q
.
append
(
node
)
def
pruner
(
node
):
def
pruner
(
node
):
if
node
is
not
current_node
:
if
node
is
not
current_node
:
try
:
try
:
...
@@ -1179,12 +1366,13 @@ class EquilibriumOptimizer(NavigatorOptimizer):
...
@@ -1179,12 +1366,13 @@ class EquilibriumOptimizer(NavigatorOptimizer):
opt_name
=
(
getattr
(
lopt
,
"name"
,
None
)
opt_name
=
(
getattr
(
lopt
,
"name"
,
None
)
or
getattr
(
lopt
,
"__name__"
,
""
))
or
getattr
(
lopt
,
"__name__"
,
""
))
if
node
not
in
env
.
nodes
:
if
node
not
in
env
.
nodes
:
break
# go to next node
# go to next node
break
finally
:
finally
:
self
.
detach_updater
(
env
,
u
)
self
.
detach_updater
(
env
,
u
)
self
.
detach_updater
(
env
,
u
)
#TODO: erase this line, it's redundant at best
loop_timing
.
append
(
float
(
time
.
time
()
-
t0
))
loop_timing
.
append
(
float
(
time
.
time
()
-
t0
))
if
max_use_abort
:
if
max_use_abort
:
_logger
.
error
(
"EquilibriumOptimizer max'ed out by '
%
s'"
%
opt_name
_logger
.
error
(
"EquilibriumOptimizer max'ed out by '
%
s'"
%
opt_name
+
". You can safely raise the current threshold of "
+
". You can safely raise the current threshold of "
...
@@ -1216,10 +1404,12 @@ class EquilibriumOptimizer(NavigatorOptimizer):
...
@@ -1216,10 +1404,12 @@ class EquilibriumOptimizer(NavigatorOptimizer):
def
print_summary
(
self
,
stream
=
sys
.
stdout
,
level
=
0
,
depth
=-
1
):
def
print_summary
(
self
,
stream
=
sys
.
stdout
,
level
=
0
,
depth
=-
1
):
name
=
getattr
(
self
,
'name'
,
None
)
name
=
getattr
(
self
,
'name'
,
None
)
print
>>
stream
,
"
%
s
%
s
%
s id=
%
i"
%
(
' '
*
level
,
self
.
__class__
.
__name__
,
name
,
id
(
self
))
print
>>
stream
,
"
%
s
%
s
%
s id=
%
i"
%
(
(
' '
*
level
),
self
.
__class__
.
__name__
,
name
,
id
(
self
))
if
depth
!=
0
:
if
depth
!=
0
:
for
lopt
in
self
.
local_optimizers
:
for
lopt
in
self
.
local_optimizers
:
lopt
.
print_summary
(
stream
,
level
=
level
+
2
,
depth
=
depth
-
1
)
lopt
.
print_summary
(
stream
,
level
=
(
level
+
2
),
depth
=
(
depth
-
1
))
#################
#################
...
@@ -1242,7 +1432,8 @@ def _check_chain(r, chain):
...
@@ -1242,7 +1432,8 @@ def _check_chain(r, chain):
return
False
return
False
else
:
else
:
try
:
try
:
if
issubclass
(
elem
,
op
.
Op
)
and
not
isinstance
(
r
.
owner
.
op
,
elem
):
if
(
issubclass
(
elem
,
op
.
Op
)
and
not
isinstance
(
r
.
owner
.
op
,
elem
)):
return
False
return
False
except
TypeError
:
except
TypeError
:
return
False
return
False
...
@@ -1256,6 +1447,7 @@ def _check_chain(r, chain):
...
@@ -1256,6 +1447,7 @@ def _check_chain(r, chain):
return
(
r
is
not
None
)
return
(
r
is
not
None
)
#_check_chain.n_calls = 0
#_check_chain.n_calls = 0
def
check_chain
(
r
,
*
chain
):
def
check_chain
(
r
,
*
chain
):
"""WRITEME"""
"""WRITEME"""
if
isinstance
(
r
,
graph
.
Apply
):
if
isinstance
(
r
,
graph
.
Apply
):
...
@@ -1280,7 +1472,7 @@ def pre_greedy_local_optimizer(list_optimizations, out):
...
@@ -1280,7 +1472,7 @@ def pre_greedy_local_optimizer(list_optimizations, out):
add additional node to the inputs of the node, it can
add additional node to the inputs of the node, it can
be needed to call this function multiple time.
be needed to call this function multiple time.
'''
'''
def
local_recursive_function
(
list_opt
,
out
,
optimized_vars
,
depth
):
def
local_recursive_function
(
list_opt
,
out
,
optimized_vars
,
depth
):
if
not
getattr
(
out
,
'owner'
,
None
):
if
not
getattr
(
out
,
'owner'
,
None
):
return
[
out
],
optimized_vars
return
[
out
],
optimized_vars
node
=
out
.
owner
node
=
out
.
owner
...
@@ -1292,11 +1484,11 @@ def pre_greedy_local_optimizer(list_optimizations, out):
...
@@ -1292,11 +1484,11 @@ def pre_greedy_local_optimizer(list_optimizations, out):
else
:
else
:
if
inp
.
owner
:
if
inp
.
owner
:
outs
,
optimized_vars
=
local_recursive_function
(
outs
,
optimized_vars
=
local_recursive_function
(
list_opt
list_opt
,
,
inp
inp
,
,
optimized_vars
optimized_vars
,
,
depth
+
1
)
depth
+
1
)
for
k
,
v
in
zip
(
inp
.
owner
.
outputs
,
outs
):
for
k
,
v
in
zip
(
inp
.
owner
.
outputs
,
outs
):
optimized_vars
[
k
]
=
v
optimized_vars
[
k
]
=
v
nw_in
=
outs
[
inp
.
owner
.
outputs
.
index
(
inp
)]
nw_in
=
outs
[
inp
.
owner
.
outputs
.
index
(
inp
)]
...
@@ -1310,10 +1502,10 @@ def pre_greedy_local_optimizer(list_optimizations, out):
...
@@ -1310,10 +1502,10 @@ def pre_greedy_local_optimizer(list_optimizations, out):
ret
=
opt
.
transform
(
node
)
ret
=
opt
.
transform
(
node
)
if
ret
is
not
False
and
ret
is
not
None
:
if
ret
is
not
False
and
ret
is
not
None
:
assert
len
(
ret
)
==
len
(
node
.
outputs
)
assert
len
(
ret
)
==
len
(
node
.
outputs
)
for
k
,
v
in
zip
(
node
.
outputs
,
ret
):
for
k
,
v
in
zip
(
node
.
outputs
,
ret
):
optimized_vars
[
k
]
=
v
optimized_vars
[
k
]
=
v
results
=
ret
results
=
ret
if
ret
[
0
]
.
owner
:
if
ret
[
0
]
.
owner
:
node
=
out
.
owner
node
=
out
.
owner
else
:
else
:
break
break
...
@@ -1324,8 +1516,6 @@ def pre_greedy_local_optimizer(list_optimizations, out):
...
@@ -1324,8 +1516,6 @@ def pre_greedy_local_optimizer(list_optimizations, out):
return
final_outs
[
0
]
return
final_outs
[
0
]
############
############
### Misc ###
### Misc ###
############
############
...
...
theano/tensor/opt.py
浏览文件 @
d4dfbf2a
...
@@ -1823,7 +1823,6 @@ def local_subtensor_merge(node):
...
@@ -1823,7 +1823,6 @@ def local_subtensor_merge(node):
merged_slices
.
append
(
slice1
)
merged_slices
.
append
(
slice1
)
pos_1
+=
1
pos_1
+=
1
if
pos_2
<
len
(
slices2
):
if
pos_2
<
len
(
slices2
):
merged_slices
+=
slices2
[
pos_2
:]
merged_slices
+=
slices2
[
pos_2
:]
else
:
else
:
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论