Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
d4dfbf2a
提交
d4dfbf2a
authored
4月 19, 2012
作者:
nouiz
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #612 from lamblin/merge_feature_rebased
Merge feature (rebased)
上级
a96d5716
c42d1808
隐藏空白字符变更
内嵌
并排
正在显示
3 个修改的文件
包含
531 行增加
和
289 行删除
+531
-289
mode.py
theano/compile/mode.py
+132
-79
opt.py
theano/gof/opt.py
+399
-209
opt.py
theano/tensor/opt.py
+0
-1
没有找到文件。
theano/compile/mode.py
浏览文件 @
d4dfbf2a
"""WRITEME
"""
import
os
,
logging
,
warnings
import
logging
import
numpy
,
theano
import
numpy
import
theano
from
theano
import
gof
import
theano.gof.vm
from
theano.configparser
import
config
,
AddConfigVar
,
StrParam
,
EnumStr
from
theano.configparser
import
config
,
AddConfigVar
,
StrParam
_logger
=
logging
.
getLogger
(
'theano.compile.mode'
)
AddConfigVar
(
'optimizer_excluding'
,
"When using the default mode, we will remove optimizer with that tag. Separate many tags with ':'."
,
(
"When using the default mode, we will remove optimizer with these "
"tags. Separate tags with ':'."
),
StrParam
(
""
,
allow_override
=
False
),
in_c_key
=
False
)
AddConfigVar
(
'optimizer_including'
,
"When using the default mode, we will add optimizer with that tag. Separate many tags with ':'."
,
(
"When using the default mode, we will add optimizer with these tags. "
"Separate tags with ':'."
),
StrParam
(
""
,
allow_override
=
False
),
in_c_key
=
False
)
AddConfigVar
(
'optimizer_requiring'
,
"When using the default mode, we will require optimizer with that tag. Separate many tags with ':'."
,
(
"When using the default mode, we will require optimizer with these "
"tags. Separate tags with ':'."
),
StrParam
(
""
,
allow_override
=
False
),
in_c_key
=
False
)
def
check_equal
(
x
,
y
):
"""
Returns True iff x[0] and y[0] are equal (checks the dtype and
...
...
@@ -32,35 +38,37 @@ def check_equal(x, y):
import
scipy.sparse
as
sp
x
,
y
=
x
[
0
],
y
[
0
]
# TODO: bug in current scipy, two sparse matrices are never equal, remove when moving to 0.7
# TODO: bug in current scipy, two sparse matrices are never equal,
# remove when moving to 0.7
if
sp
.
issparse
(
x
):
x
=
x
.
todense
()
if
sp
.
issparse
(
y
):
y
=
y
.
todense
()
if
isinstance
(
x
,
numpy
.
ndarray
)
and
isinstance
(
y
,
numpy
.
ndarray
):
if
x
.
dtype
!=
y
.
dtype
or
x
.
shape
!=
y
.
shape
or
numpy
.
any
(
abs
(
x
-
y
)
>
1e-10
):
raise
Exception
(
"Output mismatch."
,
{
'performlinker'
:
x
,
'clinker'
:
y
})
if
(
x
.
dtype
!=
y
.
dtype
or
x
.
shape
!=
y
.
shape
or
numpy
.
any
(
abs
(
x
-
y
)
>
1e-10
)):
raise
Exception
(
"Output mismatch."
,
{
'performlinker'
:
x
,
'clinker'
:
y
})
else
:
if
x
!=
y
:
raise
Exception
(
"Output mismatch."
,
{
'performlinker'
:
x
,
'clinker'
:
y
})
raise
Exception
(
"Output mismatch."
,
{
'performlinker'
:
x
,
'clinker'
:
y
})
# If a string is passed as the linker argument in the constructor for
# Mode, it will be used as the key to retrieve the real linker in this
# dictionary
predefined_linkers
=
{
'py'
:
gof
.
PerformLinker
(),
'c'
:
gof
.
CLinker
(),
'c|py'
:
gof
.
OpWiseCLinker
(
allow_gc
=
True
),
'c|py_nogc'
:
gof
.
OpWiseCLinker
(
allow_gc
=
False
),
'c&py'
:
gof
.
DualLinker
(
checker
=
check_equal
),
'vm'
:
gof
.
vm
.
VM_Linker
(
allow_gc
=
True
,
use_cloop
=
False
),
'cvm'
:
gof
.
vm
.
VM_Linker
(
allow_gc
=
True
,
use_cloop
=
True
),
'vm_nogc'
:
gof
.
vm
.
VM_Linker
(
allow_gc
=
False
,
use_cloop
=
False
),
'py'
:
gof
.
PerformLinker
(),
'c'
:
gof
.
CLinker
(),
'c|py'
:
gof
.
OpWiseCLinker
(
allow_gc
=
True
),
'c|py_nogc'
:
gof
.
OpWiseCLinker
(
allow_gc
=
False
),
'c&py'
:
gof
.
DualLinker
(
checker
=
check_equal
),
'vm'
:
gof
.
vm
.
VM_Linker
(
allow_gc
=
True
,
use_cloop
=
False
),
'cvm'
:
gof
.
vm
.
VM_Linker
(
allow_gc
=
True
,
use_cloop
=
True
),
'vm_nogc'
:
gof
.
vm
.
VM_Linker
(
allow_gc
=
False
,
use_cloop
=
False
),
'cvm_nogc'
:
gof
.
vm
.
VM_Linker
(
allow_gc
=
False
,
use_cloop
=
True
),
}
...
...
@@ -72,37 +80,37 @@ def register_linker(name, linker):
predefined_linkers
[
name
]
=
linker
# If a string is passed as the optimizer argument in the constructor
# for Mode, it will be used as the key to retrieve the real optimizer
# in this dictionary
OPT_FAST_RUN
=
gof
.
Query
(
include
=
[
'fast_run'
])
OPT_FAST_RUN
=
gof
.
Query
(
include
=
[
'fast_run'
])
OPT_FAST_RUN_STABLE
=
OPT_FAST_RUN
.
requiring
(
'stable'
)
OPT_FAST_COMPILE
=
gof
.
Query
(
include
=
[
'fast_compile'
])
OPT_STABILIZE
=
gof
.
Query
(
include
=
[
'fast_run'
])
OPT_FAST_COMPILE
=
gof
.
Query
(
include
=
[
'fast_compile'
])
OPT_STABILIZE
=
gof
.
Query
(
include
=
[
'fast_run'
])
OPT_STABILIZE
.
position_cutoff
=
1.5000001
predefined_optimizers
=
{
None
:
lambda
env
:
None
,
'None'
:
lambda
env
:
None
,
'merge'
:
gof
.
MergeOptimizer
(),
'fast_run'
:
OPT_FAST_RUN
,
'fast_run_stable'
:
OPT_FAST_RUN_STABLE
,
'fast_compile'
:
OPT_FAST_COMPILE
,
None
:
(
lambda
env
:
None
)
,
'None'
:
(
lambda
env
:
None
)
,
'merge'
:
gof
.
MergeOptimizer
(),
'fast_run'
:
OPT_FAST_RUN
,
'fast_run_stable'
:
OPT_FAST_RUN_STABLE
,
'fast_compile'
:
OPT_FAST_COMPILE
,
'stabilize'
:
OPT_STABILIZE
}
def
register_optimizer
(
name
,
opt
):
"""Add a `Optimizer` which can be referred to by `name` in `Mode`."""
if
name
in
predefined_optimizers
:
raise
ValueError
(
'Optimizer name already taken:
%
s'
%
name
)
predefined_optimizers
[
name
]
=
opt
def
register_OutputGuard_c_code
(
type
):
OutputGuard
.
c_code_types
.
append
(
type
)
class
OutputGuard
(
gof
.
Op
):
"""
This op is used only internally by Theano.
...
...
@@ -120,20 +128,24 @@ class OutputGuard(gof.Op):
TODO: find a current full explanation.
"""
destroy_map
=
{
0
:[
0
]}
view_map
=
{
0
:[
0
]}
destroy_map
=
{
0
:
[
0
]}
view_map
=
{
0
:
[
0
]}
c_code_types
=
[]
def
make_node
(
self
,
x
):
return
gof
.
Apply
(
self
,
[
x
],
[
x
.
type
()])
def
__eq__
(
self
,
other
):
return
type
(
self
)
==
type
(
other
)
def
__hash__
(
self
):
return
hash
(
type
(
self
))
def
perform
(
self
,
node
,
inp
,
out
):
x
,
=
inp
z
,
=
out
z
[
0
]
=
x
def
__str__
(
self
):
return
'
%
s'
%
self
.
__class__
.
__name__
...
...
@@ -141,7 +153,8 @@ class OutputGuard(gof.Op):
x
,
=
inp
z
,
=
out
if
isinstance
(
node
.
inputs
[
0
]
.
type
,
theano
.
scalar
.
Scalar
):
# Scalars are C objects on the stacks, and should not be inc/decrefed
# Scalars are C objects on the stack,
# and should not be inc/decrefed
return
"""
%(z)
s =
%(x)
s;
"""
%
locals
()
...
...
@@ -161,71 +174,99 @@ class OutputGuard(gof.Op):
_output_guard
=
OutputGuard
()
class
AddDestroyHandler
(
gof
.
Optimizer
):
"""This optimizer performs two important functions:
1) it has a 'requirement' of the destroyhandler. This means that the env will include it
as a feature for this optimization, and keep this feature enabled for subsequent
optimizations. All optimizations that work inplace on any of their inputs must run *after*
this optimization to ensure that the DestroyHandler has been included in the env.
1) it has a 'requirement' of the destroyhandler. This means that the env
will include it as a feature for this optimization, and keep this feature
enabled for subsequent optimizations. All optimizations that work inplace
on any of their inputs must run *after* this optimization to ensure that
the DestroyHandler has been included in the env.
2) It tries to replace each output with an Op that purports to destroy it
(but it won't I
promise). If this replacement succeeds it means that there is a bug in theano. It should
not be possible to destroy outputs.
2) It tries to replace each output with an Op that purports to destroy it
(but it won't I promise). If this replacement succeeds it means that
there is a bug in theano. It should
not be possible to destroy outputs.
"""
def
apply
(
self
,
env
):
for
o
in
env
.
outputs
:
try
:
env
.
replace_validate
(
o
,
_output_guard
(
o
),
reason
=
'output_guard'
)
_logger
.
info
(
"Output variable
%
s required output_guard,"
" how was this output left unprotected against destructive operations?"
env
.
replace_validate
(
o
,
_output_guard
(
o
),
reason
=
'output_guard'
)
_logger
.
info
(
"Output variable
%
s required output_guard, "
"how was this output left unprotected against "
"destructive operations?"
%
o
)
except
gof
.
InconsistencyError
:
#this output is already impossible to destroy. no guard necessary
# This output is already impossible to destroy.
# No guard necessary
pass
def
add_requirements
(
self
,
env
):
super
(
AddDestroyHandler
,
self
)
.
add_requirements
(
env
)
env
.
extend
(
gof
.
DestroyHandler
())
class
PrintCurrentEnv
(
gof
.
Optimizer
):
"""This optimizer is for debugging.
Toss it into the optimization pipeline to see the state of things at any given point.
Toss it into the optimization pipeline to see the state of things at any
given point.
"""
def
__init__
(
self
,
header
):
self
.
header
=
header
self
.
header
=
header
def
apply
(
self
,
env
):
import
theano.printing
print
"PrintCurrentEnv:"
,
self
.
header
theano
.
printing
.
debugprint
(
env
.
outputs
)
optdb
=
gof
.
SequenceDB
()
optdb
.
register
(
'merge1'
,
gof
.
MergeOptimizer
(),
0
,
'fast_run'
,
'fast_compile'
)
optdb
.
register
(
'canonicalize'
,
gof
.
EquilibriumDB
(),
# rearranges elemwise expressions
# rearranges elemwise expressions
optdb
.
register
(
'canonicalize'
,
gof
.
EquilibriumDB
(),
1
,
'fast_run'
,
'fast_compile'
)
optdb
.
register
(
'merge1.2'
,
gof
.
MergeOptimizer
(
skip_const_merge
=
False
),
optdb
.
register
(
'merge1.2'
,
gof
.
MergeOptimizer
(),
1.2
,
'fast_run'
,
'fast_compile'
)
optdb
.
register
(
'Print1.21'
,
PrintCurrentEnv
(
'Post-canonicalize'
),
1.21
,)
# 'fast_run', 'fast_compile')
1.21
,)
# 'fast_run', 'fast_compile')
optdb
.
register
(
'stabilize'
,
gof
.
EquilibriumDB
(),
# replace unstable subgraphs
# replace unstable subgraphs
optdb
.
register
(
'stabilize'
,
gof
.
EquilibriumDB
(),
1.5
,
'fast_run'
)
optdb
.
register
(
'Print1.51'
,
PrintCurrentEnv
(
'Post-stabilize'
),
1.51
,)
#'fast_run', 'fast_compile')
optdb
.
register
(
'specialize'
,
gof
.
EquilibriumDB
(),
# misc special cases for speed
1.51
,)
# 'fast_run', 'fast_compile')
# misc special cases for speed
optdb
.
register
(
'specialize'
,
gof
.
EquilibriumDB
(),
2
,
'fast_run'
)
optdb
.
register
(
'Print2.01'
,
PrintCurrentEnv
(
'Post-specialize'
),
2.01
,
)
#'fast_run', 'fast_compile')
optdb
.
register
(
'uncanonicalize'
,
gof
.
EquilibriumDB
(),
# misc special cases for speed that break canonicalization
2.01
,)
# 'fast_run', 'fast_compile')
# misc special cases for speed that break canonicalization
optdb
.
register
(
'uncanonicalize'
,
gof
.
EquilibriumDB
(),
3
,
'fast_run'
)
optdb
.
register
(
'specialize_device'
,
gof
.
EquilibriumDB
(),
# misc special cases for speed that are dependent on the device.
48.6
,
'fast_run'
)
#must be after gpu stuff at 48.5
optdb
.
register
(
'merge2'
,
gof
.
MergeOptimizer
(),
# especially constant merge
# misc special cases for speed that are dependent on the device.
optdb
.
register
(
'specialize_device'
,
gof
.
EquilibriumDB
(),
48.6
,
'fast_run'
)
# must be after gpu stuff at 48.5
# especially constant merge
optdb
.
register
(
'merge2'
,
gof
.
MergeOptimizer
(),
49
,
'fast_run'
)
optdb
.
register
(
'add_destroy_handler'
,
AddDestroyHandler
(),
49.5
,
'fast_run'
,
'inplace'
)
optdb
.
register
(
'merge3'
,
gof
.
MergeOptimizer
(),
# final pass just to make sure
# final pass just to make sure
optdb
.
register
(
'merge3'
,
gof
.
MergeOptimizer
(),
100
,
'fast_run'
)
...
...
@@ -251,12 +292,15 @@ class Mode(object):
if
optimizer
is
None
:
optimizer
=
config
.
optimizer
self
.
__setstate__
((
linker
,
optimizer
))
#self.provided_optimizer - typically the `optimizer` arg. But if the `optimizer` arg is
# keyword corresponding to a predefined Query, then this stores the query
#self._optimizer - typically same as provided_optimizer??
#self.__get_optimizer - returns self._optimizer (possibly querying optdb with self._optimizer)
#self.optimizer - property that returns __get_optimizer()
# self.provided_optimizer - typically the `optimizer` arg.
# But if the `optimizer` arg is keyword corresponding to a predefined
# Query, then this stores the query
# self._optimizer - typically same as provided_optimizer??
# self.__get_optimizer - returns self._optimizer (possibly querying
# optdb with self._optimizer)
# self.optimizer - property that returns __get_optimizer()
def
__getstate__
(
self
):
return
(
self
.
provided_linker
,
self
.
provided_optimizer
)
...
...
@@ -275,12 +319,13 @@ class Mode(object):
self
.
_optimizer
=
optimizer
self
.
call_time
=
0
self
.
fn_time
=
0
linker
.
mode
=
self
#
TODO: WHY IS THIS HERE?
linker
.
mode
=
self
#
TODO: WHY IS THIS HERE?
self
.
optimizer_time
=
0
self
.
linker_time
=
0
def
__str__
(
self
):
return
"Mode(linker =
%
s, optimizer =
%
s)"
%
(
self
.
provided_linker
,
self
.
provided_optimizer
)
return
"Mode(linker =
%
s, optimizer =
%
s)"
%
(
self
.
provided_linker
,
self
.
provided_optimizer
)
def
__get_optimizer
(
self
):
if
isinstance
(
self
.
_optimizer
,
gof
.
Query
):
...
...
@@ -298,17 +343,20 @@ class Mode(object):
return
(
linker
,
optimizer
)
def
including
(
self
,
*
tags
):
link
,
opt
=
self
.
get_linker_optimizer
(
self
.
provided_linker
,
self
.
provided_optimizer
)
link
,
opt
=
self
.
get_linker_optimizer
(
self
.
provided_linker
,
self
.
provided_optimizer
)
#N.B. opt might be a Query instance, not sure what else it might be...
# string? Optimizer? OptDB? who knows???
return
self
.
__class__
(
linker
=
link
,
optimizer
=
opt
.
including
(
*
tags
))
def
excluding
(
self
,
*
tags
):
link
,
opt
=
self
.
get_linker_optimizer
(
self
.
provided_linker
,
self
.
provided_optimizer
)
link
,
opt
=
self
.
get_linker_optimizer
(
self
.
provided_linker
,
self
.
provided_optimizer
)
return
self
.
__class__
(
linker
=
link
,
optimizer
=
opt
.
excluding
(
*
tags
))
def
requiring
(
self
,
*
tags
):
link
,
opt
=
self
.
get_linker_optimizer
(
self
.
provided_linker
,
self
.
provided_optimizer
)
link
,
opt
=
self
.
get_linker_optimizer
(
self
.
provided_linker
,
self
.
provided_optimizer
)
return
self
.
__class__
(
linker
=
link
,
optimizer
=
opt
.
requiring
(
*
tags
))
# If a string is passed as the mode argument in function or
...
...
@@ -321,20 +369,22 @@ predefined_modes = {'FAST_COMPILE': FAST_COMPILE,
'FAST_RUN'
:
FAST_RUN
,
}
instanciated_default_mode
=
None
instanciated_default_mode
=
None
def
get_mode
(
orig_string
):
if
orig_string
is
None
:
string
=
config
.
mode
else
:
string
=
orig_string
if
not
isinstance
(
string
,
basestring
):
return
string
#
it is hopefully already a mode...
return
string
#
it is hopefully already a mode...
global
instanciated_default_mode
# The default mode is cached. However, config.mode can change
# If instanciated_default_mode has the right class, use it.
if
orig_string
is
None
and
instanciated_default_mode
:
if
predefined_modes
.
has_key
(
string
)
:
if
string
in
predefined_modes
:
default_mode_class
=
predefined_modes
[
string
]
.
__class__
.
__name__
else
:
default_mode_class
=
string
...
...
@@ -342,7 +392,7 @@ def get_mode(orig_string):
default_mode_class
):
return
instanciated_default_mode
if
string
in
[
'Mode'
,
'ProfileMode'
,
'DebugMode'
]:
if
string
in
[
'Mode'
,
'ProfileMode'
,
'DebugMode'
]:
if
string
==
'DebugMode'
:
#need to import later to break circular dependency.
from
debugmode
import
DebugMode
...
...
@@ -350,12 +400,13 @@ def get_mode(orig_string):
ret
=
DebugMode
(
optimizer
=
config
.
optimizer
)
else
:
# The import is needed in case string is ProfileMode
from
profilemode
import
ProfileMode
,
prof_mode_instance_to_print
ret
=
eval
(
string
+
'(linker=config.linker, optimizer=config.optimizer)'
)
elif
predefined_modes
.
has_key
(
string
):
from
profilemode
import
ProfileMode
,
prof_mode_instance_to_print
ret
=
eval
(
string
+
'(linker=config.linker, optimizer=config.optimizer)'
)
elif
string
in
predefined_modes
:
ret
=
predefined_modes
[
string
]
else
:
raise
Exception
(
"No predefined mode exist for string:
%
s"
%
string
)
raise
Exception
(
"No predefined mode exist for string:
%
s"
%
string
)
if
orig_string
is
None
:
# Build and cache the default mode
...
...
@@ -374,12 +425,14 @@ def get_mode(orig_string):
return
ret
def
get_default_mode
():
return
get_mode
(
None
)
# Removed: use config.mode instead.
#default_mode = config.mode
def
register_mode
(
name
,
mode
):
"""Add a `Mode` which can be referred to by `name` in `function`."""
if
name
in
predefined_modes
:
...
...
theano/gof/opt.py
浏览文件 @
d4dfbf2a
...
...
@@ -3,21 +3,23 @@ Defines the base class for optimizations as well as a certain
amount of useful generic optimization tools.
"""
import
copy
,
logging
,
sys
,
time
import
copy
import
logging
import
sys
import
time
import
numpy
import
graph
from
env
import
InconsistencyError
import
op
import
utils
import
unify
import
toolbox
import
op
import
theano
from
theano
import
config
from
theano.gof.python25
import
any
,
all
,
deque
from
theano.configparser
import
AddConfigVar
,
BoolParam
,
config
from
theano.configparser
import
AddConfigVar
,
BoolParam
#if sys.version_info[:2] >= (2,5):
# from collections import defaultdict
...
...
@@ -39,9 +41,11 @@ import traceback
_optimizer_idx
=
[
0
]
def
_list_of_nodes
(
env
):
return
list
(
graph
.
io_toposort
(
env
.
inputs
,
env
.
outputs
))
class
Optimizer
(
object
):
"""WRITEME
An L{Optimizer} can be applied to an L{Env} to transform it.
...
...
@@ -91,26 +95,30 @@ class Optimizer(object):
def
print_summary
(
self
,
stream
=
sys
.
stdout
,
level
=
0
,
depth
=-
1
):
name
=
getattr
(
self
,
'name'
,
None
)
print
>>
stream
,
"
%
s
%
s
%
s id=
%
i"
%
(
' '
*
level
,
self
.
__class__
.
__name__
,
name
,
id
(
self
))
print
>>
stream
,
"
%
s
%
s
%
s id=
%
i"
%
(
(
' '
*
level
),
self
.
__class__
.
__name__
,
name
,
id
(
self
))
class
FromFunctionOptimizer
(
Optimizer
):
"""WRITEME"""
def
__init__
(
self
,
fn
):
self
.
apply
=
fn
def
add_requirements
(
self
,
env
):
# Added by default
#env.extend(toolbox.ReplaceValidate())
pass
def
print_summary
(
self
,
stream
=
sys
.
stdout
,
level
=
0
,
depth
=-
1
):
print
>>
stream
,
"
%
s
%
s id=
%
i"
%
(
' '
*
level
,
print
>>
stream
,
"
%
s
%
s id=
%
i"
%
(
' '
*
level
,
str
(
self
.
apply
),
id
(
self
))
def
__call__
(
self
,
*
args
,
**
kwargs
):
return
self
.
fn
(
*
args
,
**
kwargs
)
def
optimizer
(
f
):
"""decorator for FromFunctionOptimizer"""
rval
=
FromFunctionOptimizer
(
f
)
...
...
@@ -118,7 +126,6 @@ def optimizer(f):
return
rval
class
SeqOptimizer
(
Optimizer
,
list
):
#inherit from Optimizer first to get Optimizer.__hash__
"""WRITEME
...
...
@@ -129,7 +136,7 @@ class SeqOptimizer(Optimizer, list):
def
warn
(
exc
,
self
,
optimizer
):
"""Default failure_callback for SeqOptimizer
"""
_logger
.
error
(
"SeqOptimizer apply
%
s"
%
str
(
optimizer
))
_logger
.
error
(
"SeqOptimizer apply
%
s"
%
str
(
optimizer
))
_logger
.
error
(
"Traceback:"
)
_logger
.
error
(
traceback
.
format_exc
())
if
config
.
on_opt_error
==
'raise'
:
...
...
@@ -146,14 +153,15 @@ class SeqOptimizer(Optimizer, list):
"""WRITEME
Applies each L{Optimizer} in self in turn.
"""
l
=
[]
l
=
[]
nb_node_before
=
len
(
env
.
nodes
)
for
optimizer
in
self
:
try
:
t0
=
time
.
time
()
t0
=
time
.
time
()
optimizer
.
optimize
(
env
)
l
.
append
(
float
(
time
.
time
()
-
t0
))
except
AssertionError
:
# do not catch Assertion failures
l
.
append
(
float
(
time
.
time
()
-
t0
))
except
AssertionError
:
# do not catch Assertion failures
raise
except
Exception
,
e
:
if
self
.
failure_callback
:
...
...
@@ -192,7 +200,6 @@ class SeqOptimizer(Optimizer, list):
#added to override the list's __neq__ implementation
return
id
(
self
)
!=
id
(
other
)
def
__str__
(
self
):
return
"SeqOpt(
%
s)"
%
list
.
__str__
(
self
)
...
...
@@ -201,14 +208,13 @@ class SeqOptimizer(Optimizer, list):
def
print_summary
(
self
,
stream
=
sys
.
stdout
,
level
=
0
,
depth
=-
1
):
name
=
getattr
(
self
,
'name'
,
None
)
print
>>
stream
,
"
%
s
%
s
%
s id=
%
i"
%
(
' '
*
level
,
self
.
__class__
.
__name__
,
name
,
id
(
self
))
print
>>
stream
,
"
%
s
%
s
%
s id=
%
i"
%
(
(
' '
*
level
),
self
.
__class__
.
__name__
,
name
,
id
(
self
))
# This way, -1 will do all depth
if
depth
!=
0
:
depth
-=
1
for
opt
in
self
:
opt
.
print_summary
(
stream
,
level
=
level
+
2
,
depth
=
depth
)
opt
.
print_summary
(
stream
,
level
=
(
level
+
2
),
depth
=
depth
)
class
_metadict
:
...
...
@@ -219,17 +225,39 @@ class _metadict:
def
__init__
(
self
):
self
.
d
=
{}
self
.
l
=
[]
def
__getitem__
(
self
,
item
):
return
self
.
get
(
item
,
None
)
def
__setitem__
(
self
,
item
,
value
):
try
:
self
.
d
[
item
]
=
value
except
Exception
:
for
i
,
(
key
,
val
)
in
enumerate
(
self
.
l
):
for
i
,
(
key
,
val
)
in
enumerate
(
self
.
l
):
if
key
==
item
:
self
.
l
[
i
]
=
(
item
,
value
)
return
self
.
l
.
append
((
item
,
value
))
def
__delitem__
(
self
,
item
):
if
item
in
self
.
d
:
del
self
.
d
[
item
]
else
:
for
i
,
(
key
,
val
)
in
enumerate
(
self
.
l
):
if
key
==
item
:
del
self
.
l
[
i
]
return
raise
KeyError
(
item
)
def
discard
(
self
,
item
):
if
item
in
self
.
d
:
del
self
.
d
[
item
]
else
:
for
i
,
(
key
,
val
)
in
enumerate
(
self
.
l
):
if
key
==
item
:
del
self
.
l
[
i
]
return
def
get
(
self
,
item
,
default
):
try
:
return
self
.
d
[
item
]
...
...
@@ -245,13 +273,148 @@ class _metadict:
return
value
else
:
return
default
def
clear
(
self
):
self
.
d
=
{}
self
.
l
=
[]
def
__str__
(
self
):
return
"(
%
s,
%
s)"
%
(
self
.
d
,
self
.
l
)
class
MergeFeature
(
object
):
"""
Keeps track of variables in env that cannot be merged together.
That way, the MergeOptimizer can remember the result of the last merge
pass on the env.
"""
def
on_attach
(
self
,
env
):
assert
not
hasattr
(
env
,
'merge_feature'
)
env
.
merge_feature
=
self
## For constants
self
.
seen_constants
=
set
()
# variable -> signature (for constants)
self
.
const_sig
=
_metadict
()
# signature -> variable (for constants)
self
.
const_sig_inv
=
_metadict
()
## For all variables
# Set of distinct (not mergeable) nodes
self
.
nodes_seen
=
set
()
# Each element of scheduled is a list of list of (out, new_out) pairs.
# Each list of pairs represent the substitution needed to replace all
# the outputs of a node with the outputs of a replacement candidate.
# Each node can have several candidates. For instance, if "node" has
# 2 outputs, and there are 3 replacement candidates, we will have:
# shelf.scheduled = [
# [[(node.out1, cand1.out1), (node.out2, cand1.out2)],
# [(node.out1, cand2.out1), (node.out2, cand2.out2)],
# [(node.out1, cand3.out1), (node.out2, cand3.out2)]]]
self
.
scheduled
=
[]
# List of (node, candidate) pairs, where we tried to replace node by
# candidate, but it failed. This is used to avoid infinite loops
# during the replacement phase.
self
.
blacklist
=
[]
for
node
in
env
.
toposort
():
self
.
on_import
(
env
,
node
)
def
on_change_input
(
self
,
env
,
node
,
i
,
r
,
new_r
):
# If inputs to node change, it is not guaranteed that it is distinct
# from the other nodes in nodes_seen
if
node
in
self
.
nodes_seen
:
self
.
nodes_seen
.
discard
(
node
)
self
.
process_node
(
env
,
node
)
if
isinstance
(
new_r
,
graph
.
Constant
):
self
.
process_constant
(
env
,
new_r
)
def
on_import
(
self
,
env
,
node
):
for
c
in
node
.
inputs
:
if
isinstance
(
c
,
graph
.
Constant
):
self
.
process_constant
(
env
,
c
)
self
.
process_node
(
env
,
node
)
def
on_prune
(
self
,
env
,
node
):
self
.
nodes_seen
.
discard
(
node
)
for
c
in
node
.
inputs
:
if
isinstance
(
c
,
graph
.
Constant
)
and
(
len
(
c
.
clients
)
<=
1
):
# This was the last node using this constant
sig
=
self
.
const_sig
[
c
]
self
.
const_sig
.
discard
(
c
)
self
.
const_sig_inv
.
discard
(
sig
)
self
.
seen_constants
.
discard
(
id
(
c
))
def
process_constant
(
self
,
env
,
c
):
"""Check if a constant can be merged, and queue that replacement"""
if
id
(
c
)
in
self
.
seen_constants
:
return
sig
=
c
.
signature
()
other_c
=
self
.
const_sig_inv
.
get
(
sig
,
None
)
if
other_c
is
not
None
:
# multiple names will clobber each other..
# we adopt convention to keep the last name
if
c
.
name
:
other_c
.
name
=
c
.
name
self
.
scheduled
.
append
([[(
c
,
other_c
)]])
else
:
#this is a new constant
self
.
const_sig
[
c
]
=
sig
self
.
const_sig_inv
[
sig
]
=
c
self
.
seen_constants
.
add
(
id
(
c
))
def
process_node
(
self
,
env
,
node
):
"""Check if a node can be merged, and queue that replacement."""
if
node
in
self
.
nodes_seen
:
return
# These asserts ensure that the env has set the clients field properly.
# The clients should at least contain `node` itself!
if
node
.
inputs
:
assert
len
(
node
.
inputs
[
0
]
.
clients
)
>
0
assert
(
node
,
0
)
in
node
.
inputs
[
0
]
.
clients
merge_candidates
=
[
c
for
(
c
,
i
)
in
node
.
inputs
[
0
]
.
clients
if
c
in
self
.
nodes_seen
]
else
:
merge_candidates
=
[]
replacement_candidates
=
[]
for
candidate
in
merge_candidates
:
if
candidate
is
node
:
continue
if
len
(
node
.
inputs
)
!=
len
(
candidate
.
inputs
):
continue
inputs_match
=
all
(
node_in
is
cand_in
for
node_in
,
cand_in
in
zip
(
node
.
inputs
,
candidate
.
inputs
))
if
inputs_match
and
node
.
op
==
candidate
.
op
:
if
(
node
,
candidate
)
in
self
.
blacklist
:
# They were already tried, and there was an error
continue
# Schedule transfer of clients from node to candidate
pairs
=
zip
(
node
.
outputs
,
candidate
.
outputs
)
#transfer names
for
node_output
,
cand_output
in
pairs
:
#clobber old name with new one
#it's arbitrary... one of the names has to go
if
node_output
.
name
:
cand_output
.
name
=
node_output
.
name
replacement_candidates
.
append
(
pairs
)
if
replacement_candidates
:
self
.
scheduled
.
append
(
replacement_candidates
)
else
:
self
.
nodes_seen
.
add
(
node
)
class
MergeOptimizer
(
Optimizer
):
"""
Merges parts of the graph that are identical and redundant.
...
...
@@ -264,94 +427,32 @@ class MergeOptimizer(Optimizer):
The first step of merging is constant-merging, so that all clients of an
int(1) for example, are transferred to a particular instance of int(1).
"""
def
__init__
(
self
,
skip_const_merge
=
False
):
self
.
skip_const_merge
=
skip_const_merge
def
add_requirements
(
self
,
env
):
# Added by default
#env.extend(toolbox.ReplaceValidate())
pass
if
not
hasattr
(
env
,
'merge_feature'
):
env
.
extend
(
MergeFeature
())
def
apply_constant_merge
(
self
,
env
):
seen_constants
=
set
()
const_sig
=
_metadict
()
# variable -> variable.signature() (for constants)
const_sig_inv
=
_metadict
()
# signature -> variable (for constants)
for
node
in
_list_of_nodes
(
env
):
for
i
,
c
in
enumerate
([
r
for
r
in
node
.
inputs
if
isinstance
(
r
,
graph
.
Constant
)]):
if
id
(
c
)
in
seen_constants
:
continue
else
:
seen_constants
.
add
(
id
(
c
))
sig
=
c
.
signature
()
other_c
=
const_sig_inv
.
get
(
sig
,
None
)
if
other_c
is
not
None
:
# multiple names will clobber each other..
# we adopt convention to keep the last name
if
c
.
name
:
other_c
.
name
=
c
.
name
env
.
replace_validate
(
c
,
other_c
,
reason
=
'Constant Merge'
)
else
:
#this is a new constant
const_sig
[
c
]
=
sig
const_sig_inv
[
sig
]
=
c
def
apply_node_merge
(
self
,
env
):
# we clear the dicts because the Constants signatures are not necessarily hashable
# and it's more efficient to give them an integer like the other Variables
nodes_seen
=
{}
for
node_idx
,
node
in
enumerate
(
_list_of_nodes
(
env
)):
#
# these asserts ensure that the env has set the clients field properly the clients
# should at least contain `node` itself!
#
if
node
.
inputs
:
assert
len
(
node
.
inputs
[
0
]
.
clients
)
>
0
assert
(
node
,
0
)
in
node
.
inputs
[
0
]
.
clients
merge_candidates
=
[(
nodes_seen
[
c
],
c
)
for
(
c
,
i
)
in
node
.
inputs
[
0
]
.
clients
if
c
in
nodes_seen
]
else
:
merge_candidates
=
[]
merge_candidates
.
sort
()
nodes_seen
[
node
]
=
node_idx
#print 'NODE', node, merge_candidates, node.inputs[0].clients
for
candidate_idx
,
candidate
in
merge_candidates
:
if
len
(
node
.
inputs
)
!=
len
(
candidate
.
inputs
):
continue
inputs_match
=
all
(
node_in
is
cand_in
for
node_in
,
cand_in
in
zip
(
node
.
inputs
,
candidate
.
inputs
))
if
inputs_match
and
node
.
op
==
candidate
.
op
:
assert
node
is
not
candidate
#
#transfer clients from node to candidate
#
success
=
True
assert
len
(
node
.
outputs
)
==
len
(
candidate
.
outputs
)
pairs
=
zip
(
node
.
outputs
,
candidate
.
outputs
)
#transfer names
for
node_output
,
cand_output
in
pairs
:
#clobber old name with new one
#it's arbitrary... one of the names has to go
if
node_output
.
name
:
cand_output
.
name
=
node_output
.
name
try
:
env
.
replace_all_validate
(
pairs
,
reason
=
"Merge"
)
except
InconsistencyError
,
e
:
success
=
False
if
success
:
#break out of the candidate loop
break
else
:
#try the next candidate
pass
#TODO: Consider splitting this into a separate optimizer (SeqOptimizer)
def
apply
(
self
,
env
):
if
not
self
.
skip_const_merge
:
self
.
apply_constant_merge
(
env
)
self
.
apply_node_merge
(
env
)
# Constant and non-constant are now applied in the same phase.
# I am not sure why, but it seems to be faster this way.
sched
=
env
.
merge_feature
.
scheduled
while
sched
:
pairs_list
=
sched
.
pop
()
success
=
True
for
pairs
in
pairs_list
:
try
:
env
.
replace_all_validate
(
pairs
,
'Merge'
)
except
InconsistencyError
:
success
=
False
env
.
merge_feature
.
blacklist
.
append
(
(
pairs
[
0
][
0
]
.
owner
,
pairs
[
0
][
1
]
.
owner
))
if
success
:
break
# clear blacklist
env
.
merge_feature
.
blacklist
=
[]
merge_optimizer
=
MergeOptimizer
()
...
...
@@ -417,8 +518,9 @@ def pre_constant_merge(vars):
"""
seen_var
=
set
()
const_sig
=
{}
# variable -> variable.signature() (for constants)
const_sig_inv
=
{}
# signature -> variable (for constants)
# signature -> variable (for constants)
const_sig_inv
=
{}
def
recursive_merge
(
var
):
if
var
in
seen_var
:
return
var
...
...
@@ -434,12 +536,13 @@ def pre_constant_merge(vars):
const_sig_inv
[
sig
]
=
var
return
var
if
var
.
owner
:
for
idx
,
inp
in
enumerate
(
var
.
owner
.
inputs
):
for
idx
,
inp
in
enumerate
(
var
.
owner
.
inputs
):
var
.
owner
.
inputs
[
idx
]
=
recursive_merge
(
inp
)
return
var
return
map
(
recursive_merge
,
vars
)
########################
### Local Optimizers ###
########################
...
...
@@ -463,25 +566,31 @@ class LocalOptimizer(object):
Subclasses should implement this function so that it returns one of two
kinds of things:
- False to indicate that no optimization can be applied to this `node`; or
- <list of variables> to use in place of `node`'s outputs in the greater graph.
- False to indicate that no optimization can be applied to this `node`;
or
- <list of variables> to use in place of `node`'s outputs in the
greater graph.
:type node: an Apply instance
"""
raise
utils
.
MethodNotDefined
(
"transform"
,
type
(
self
),
self
.
__class__
.
__name__
)
raise
utils
.
MethodNotDefined
(
"transform"
,
type
(
self
),
self
.
__class__
.
__name__
)
def
add_requirements
(
self
,
env
):
"""If this local optimization wants to add some requirements to the env,
This is the place to do it."""
"""
If this local optimization wants to add some requirements to the env,
This is the place to do it.
"""
# Added by default
#env.extend(toolbox.ReplaceValidate())
pass
def
print_summary
(
self
,
stream
=
sys
.
stdout
,
level
=
0
,
depth
=-
1
):
print
>>
stream
,
"
%
s
%
s id=
%
i"
%
(
' '
*
level
,
self
.
__class__
.
__name__
,
id
(
self
))
print
>>
stream
,
"
%
s
%
s id=
%
i"
%
(
(
' '
*
level
),
self
.
__class__
.
__name__
,
id
(
self
))
class
FromFunctionLocalOptimizer
(
LocalOptimizer
):
"""WRITEME"""
...
...
@@ -490,15 +599,21 @@ class FromFunctionLocalOptimizer(LocalOptimizer):
tracks
=
[]
self
.
transform
=
fn
self
.
_tracks
=
tracks
def
tracks
(
self
):
return
self
.
_tracks
def
__str__
(
self
):
return
getattr
(
self
,
'__name__'
,
'<FromFunctionLocalOptimizer instance>'
)
return
getattr
(
self
,
'__name__'
,
'<FromFunctionLocalOptimizer instance>'
)
def
print_summary
(
self
,
stream
=
sys
.
stdout
,
level
=
0
,
depth
=-
1
):
print
>>
stream
,
"
%
s
%
s id=
%
i"
%
(
' '
*
level
,
print
>>
stream
,
"
%
s
%
s id=
%
i"
%
(
' '
*
level
,
str
(
self
.
transform
),
id
(
self
))
def
local_optimizer
(
*
tracks
):
def
decorator
(
f
):
"""WRITEME"""
...
...
@@ -513,11 +628,15 @@ class LocalOptGroup(LocalOptimizer):
def
__init__
(
self
,
*
optimizers
):
self
.
opts
=
optimizers
self
.
reentrant
=
any
(
getattr
(
opt
,
'reentrant'
,
True
)
for
opt
in
optimizers
)
self
.
retains_inputs
=
all
(
getattr
(
opt
,
'retains_inputs'
,
False
)
for
opt
in
optimizers
)
self
.
reentrant
=
any
(
getattr
(
opt
,
'reentrant'
,
True
)
for
opt
in
optimizers
)
self
.
retains_inputs
=
all
(
getattr
(
opt
,
'retains_inputs'
,
False
)
for
opt
in
optimizers
)
def
__str__
(
self
):
return
getattr
(
self
,
'__name__'
,
'<theano.gof.opt.LocalOptGroup instance>'
+
str
([
str
(
o
)
for
o
in
self
.
opts
]))
return
getattr
(
self
,
'__name__'
,
(
'<theano.gof.opt.LocalOptGroup instance>'
+
str
([
str
(
o
)
for
o
in
self
.
opts
])))
def
transform
(
self
,
node
):
for
opt
in
self
.
opts
:
...
...
@@ -526,11 +645,12 @@ class LocalOptGroup(LocalOptimizer):
return
repl
def
print_summary
(
self
,
stream
=
sys
.
stdout
,
level
=
0
,
depth
=-
1
):
print
>>
stream
,
"
%
s
%
s id=
%
i"
%
(
' '
*
level
,
self
.
__class__
.
__name__
,
id
(
self
))
print
>>
stream
,
"
%
s
%
s id=
%
i"
%
(
(
' '
*
level
),
self
.
__class__
.
__name__
,
id
(
self
))
if
depth
!=
0
:
depth
-=
1
for
lopt
in
self
.
opts
:
lopt
.
print_summary
(
stream
,
level
=
level
+
2
,
depth
=
depth
)
lopt
.
print_summary
(
stream
,
level
=
(
level
+
2
)
,
depth
=
depth
)
class
_LocalOpKeyOptGroup
(
LocalOptGroup
):
...
...
@@ -550,13 +670,16 @@ class OpSub(LocalOptimizer):
Replaces the application of a certain op by the application of
another op that take the same inputs as what they are replacing.
e.g. OpSub(add, sub) ==> add(div(x, y), add(y, x)) -> sub(div(x, y), sub(y, x))
e.g. OpSub(add, sub) ==>
add(div(x, y), add(y, x)) -> sub(div(x, y), sub(y, x))
"""
reentrant
=
False
# an OpSub does not apply to the nodes it produces
retains_inputs
=
True
# all the inputs of the original node are transferred to the outputs
# an OpSub does not apply to the nodes it produces
reentrant
=
False
# all the inputs of the original node are transferred to the outputs
retains_inputs
=
True
def
__init__
(
self
,
op1
,
op2
,
transfer_tags
=
True
):
def
__init__
(
self
,
op1
,
op2
,
transfer_tags
=
True
):
"""
op1.make_node and op2.make_node must take the same number of
inputs and have the same number of outputs.
...
...
@@ -611,7 +734,8 @@ class OpRemove(LocalOptimizer):
return
"
%
s(x) -> x"
%
(
self
.
op
)
def
print_summary
(
self
,
stream
=
sys
.
stdout
,
level
=
0
,
depth
=-
1
):
print
>>
stream
,
"
%
s
%
s(
%
s) id=
%
i"
%
(
' '
*
level
,
print
>>
stream
,
"
%
s
%
s(
%
s) id=
%
i"
%
(
' '
*
level
,
self
.
__class__
.
__name__
,
str
(
self
.
op
),
id
(
self
))
...
...
@@ -662,12 +786,12 @@ class PatternSub(LocalOptimizer):
PatternSub((subtract, (add, 'x', 'y'), 'y'), 'x')
PatternSub((power, 'x', Constant(double, 2.0)), (square, 'x'))
PatternSub((boggle, {'pattern': 'x',
'constraint': lambda expr: expr.type == scrabble}),
'constraint': lambda expr: expr.type == scrabble}),
(scrabble, 'x'))
"""
def
__init__
(
self
,
in_pattern
,
out_pattern
,
allow_multiple_clients
=
False
,
skip_identities_fn
=
None
,
name
=
None
,
pdb
=
False
):
def
__init__
(
self
,
in_pattern
,
out_pattern
,
allow_multiple_clients
=
False
,
skip_identities_fn
=
None
,
name
=
None
,
pdb
=
False
):
"""
Creates a PatternSub that replaces occurrences of
in_pattern by occurrences of out_pattern.
...
...
@@ -677,7 +801,8 @@ class PatternSub(LocalOptimizer):
:param allow_multiple_clients: if False, the pattern matching will fail
if one of the subpatterns has more than
one client.
:param pdb: if True, we invoke pdb when the first node in the pattern match.
:param pdb: if True, we invoke pdb when the first node in the
pattern match.
"""
self
.
in_pattern
=
in_pattern
self
.
out_pattern
=
out_pattern
...
...
@@ -686,8 +811,11 @@ class PatternSub(LocalOptimizer):
elif
isinstance
(
in_pattern
,
dict
):
self
.
op
=
self
.
in_pattern
[
'pattern'
][
0
]
else
:
raise
TypeError
(
"The pattern to search for must start with a specific Op instance."
)
self
.
__doc__
=
self
.
__class__
.
__doc__
+
"
\n\n
This instance does: "
+
str
(
self
)
+
"
\n
"
raise
TypeError
(
"The pattern to search for must start with "
"a specific Op instance."
)
self
.
__doc__
=
(
self
.
__class__
.
__doc__
+
"
\n\n
This instance does: "
+
str
(
self
)
+
"
\n
"
)
self
.
allow_multiple_clients
=
allow_multiple_clients
self
.
skip_identities_fn
=
skip_identities_fn
if
name
:
...
...
@@ -722,7 +850,7 @@ class PatternSub(LocalOptimizer):
if
node
.
op
!=
self
.
op
:
return
False
def
match
(
pattern
,
expr
,
u
,
allow_multiple_clients
=
False
,
pdb
=
False
):
def
match
(
pattern
,
expr
,
u
,
allow_multiple_clients
=
False
,
pdb
=
False
):
def
retry_with_equiv
():
expr_equiv
=
self
.
skip_identities
(
expr
)
if
expr_equiv
is
None
:
...
...
@@ -735,7 +863,9 @@ class PatternSub(LocalOptimizer):
if
isinstance
(
pattern
,
(
list
,
tuple
)):
if
expr
.
owner
is
None
:
return
False
if
not
(
expr
.
owner
.
op
==
pattern
[
0
])
or
(
not
allow_multiple_clients
and
len
(
expr
.
clients
)
>
1
):
if
(
not
(
expr
.
owner
.
op
==
pattern
[
0
])
or
(
not
allow_multiple_clients
and
len
(
expr
.
clients
)
>
1
)):
return
retry_with_equiv
()
if
len
(
pattern
)
-
1
!=
len
(
expr
.
owner
.
inputs
):
return
retry_with_equiv
()
...
...
@@ -747,10 +877,14 @@ class PatternSub(LocalOptimizer):
try
:
real_pattern
=
pattern
[
'pattern'
]
except
KeyError
:
raise
KeyError
(
"Malformed pattern:
%
s (expected key 'pattern')"
%
pattern
)
raise
KeyError
(
"Malformed pattern:
%
s (expected key 'pattern')"
%
pattern
)
constraint
=
pattern
.
get
(
'constraint'
,
lambda
expr
:
True
)
if
constraint
(
expr
):
return
match
(
real_pattern
,
expr
,
u
,
pattern
.
get
(
'allow_multiple_clients'
,
allow_multiple_clients
))
return
match
(
real_pattern
,
expr
,
u
,
pattern
.
get
(
'allow_multiple_clients'
,
allow_multiple_clients
))
else
:
return
retry_with_equiv
()
elif
isinstance
(
pattern
,
basestring
):
...
...
@@ -759,17 +893,22 @@ class PatternSub(LocalOptimizer):
return
retry_with_equiv
()
else
:
u
=
u
.
merge
(
expr
,
v
)
elif
isinstance
(
pattern
,
(
int
,
float
))
and
isinstance
(
expr
,
graph
.
Constant
):
if
numpy
.
all
(
theano
.
tensor
.
constant
(
pattern
)
.
value
==
expr
.
value
):
elif
(
isinstance
(
pattern
,
(
int
,
float
))
and
isinstance
(
expr
,
graph
.
Constant
)):
if
numpy
.
all
(
theano
.
tensor
.
constant
(
pattern
)
.
value
==
expr
.
value
):
return
u
else
:
return
retry_with_equiv
()
elif
isinstance
(
pattern
,
graph
.
Constant
)
and
isinstance
(
expr
,
graph
.
Constant
)
and
pattern
.
equals
(
expr
):
elif
(
isinstance
(
pattern
,
graph
.
Constant
)
and
isinstance
(
expr
,
graph
.
Constant
)
and
pattern
.
equals
(
expr
)):
return
u
else
:
return
retry_with_equiv
()
if
pdb
:
import
pdb
;
pdb
.
set_trace
()
import
pdb
pdb
.
set_trace
()
return
u
def
build
(
pattern
,
u
):
...
...
@@ -778,11 +917,12 @@ class PatternSub(LocalOptimizer):
return
pattern
[
0
](
*
args
)
elif
isinstance
(
pattern
,
basestring
):
return
u
[
unify
.
Var
(
pattern
)]
elif
isinstance
(
pattern
,
(
int
,
float
)):
elif
isinstance
(
pattern
,
(
int
,
float
)):
return
pattern
else
:
return
pattern
.
clone
()
u
=
match
(
self
.
in_pattern
,
node
.
out
,
unify
.
Unification
(),
True
,
self
.
pdb
)
u
=
match
(
self
.
in_pattern
,
node
.
out
,
unify
.
Unification
(),
True
,
self
.
pdb
)
if
u
:
p
=
self
.
out_pattern
new
=
build
(
p
,
u
)
...
...
@@ -792,23 +932,31 @@ class PatternSub(LocalOptimizer):
return
False
def
__str__
(
self
):
if
getattr
(
self
,
'__name__'
,
None
):
if
getattr
(
self
,
'__name__'
,
None
):
return
self
.
__name__
def
pattern_to_str
(
pattern
):
if
isinstance
(
pattern
,
(
list
,
tuple
)):
return
"
%
s(
%
s)"
%
(
str
(
pattern
[
0
]),
", "
.
join
([
pattern_to_str
(
p
)
for
p
in
pattern
[
1
:]]))
return
"
%
s(
%
s)"
%
(
str
(
pattern
[
0
]),
", "
.
join
([
pattern_to_str
(
p
)
for
p
in
pattern
[
1
:]]))
elif
isinstance
(
pattern
,
dict
):
return
"
%
s subject to
%
s"
%
(
pattern_to_str
(
pattern
[
'pattern'
]),
str
(
pattern
.
get
(
'constraint'
,
'no conditions'
)))
return
"
%
s subject to
%
s"
%
(
pattern_to_str
(
pattern
[
'pattern'
]),
str
(
pattern
.
get
(
'constraint'
,
'no conditions'
)))
else
:
return
str
(
pattern
)
return
"
%
s ->
%
s"
%
(
pattern_to_str
(
self
.
in_pattern
),
pattern_to_str
(
self
.
out_pattern
))
return
"
%
s ->
%
s"
%
(
pattern_to_str
(
self
.
in_pattern
),
pattern_to_str
(
self
.
out_pattern
))
def
__repr__
(
self
):
return
str
(
self
)
def
print_summary
(
self
,
stream
=
sys
.
stdout
,
level
=
0
,
depth
=-
1
):
name
=
getattr
(
self
,
'__name__'
,
getattr
(
self
,
'name'
,
None
))
print
>>
stream
,
"
%
s
%
s
%
s(
%
s,
%
s) id=
%
i"
%
(
' '
*
level
,
print
>>
stream
,
"
%
s
%
s
%
s(
%
s,
%
s) id=
%
i"
%
(
' '
*
level
,
self
.
__class__
.
__name__
,
name
,
str
(
self
.
in_pattern
),
...
...
@@ -836,37 +984,48 @@ class NavigatorOptimizer(Optimizer):
_logger
.
error
(
traceback
.
format_exc
())
if
isinstance
(
exc
,
AssertionError
)
or
config
.
on_opt_error
==
'raise'
:
raise
exc
@staticmethod
def
warn_inplace
(
exc
,
nav
,
repl_pairs
,
local_opt
):
"""failure_callback for NavigatorOptimizer: ignore InconsistencyErrors, print traceback
"""failure_callback for NavigatorOptimizer
ignore InconsistencyErrors, print traceback
"""
if
isinstance
(
exc
,
InconsistencyError
):
return
return
NavigatorOptimizer
.
warn
(
exc
,
nav
,
repl_pairs
,
local_opt
)
@staticmethod
def
warn_ignore
(
exc
,
nav
,
repl_pairs
,
local_opt
):
"""failure_callback for NavigatorOptimizer: ignore all errors
"""
pass
def
__init__
(
self
,
local_opt
,
ignore_newtrees
=
'auto'
,
failure_callback
=
None
):
def
__init__
(
self
,
local_opt
,
ignore_newtrees
=
'auto'
,
failure_callback
=
None
):
"""
:param local_opt: a LocalOptimizer to apply over a Env (or None is Ok too).
:param local_opt: a LocalOptimizer to apply over a Env
(or None is Ok too).
:param ignore_newtrees:
- True: new subgraphs returned by an optimization is not a candidate for optimization
- False: new subgraphs returned by an optimization is a candidate for optimization
- 'auto': let the local_opt set this parameter via its 'reentrant' attribute.
- True: new subgraphs returned by an optimization is not a
candidate for optimization
- False: new subgraphs returned by an optimization is a candidate
for optimization
- 'auto': let the local_opt set this parameter via its 'reentrant'
attribute.
:param failure_callback:
a function that takes (exception, navigator, [(old, new),
(old,new),...]) and we call it if there's an exception.
If the trouble is from local_opt.transform(), the new variables will be 'None'.
If the trouble is from local_opt.transform(), the new variables
will be 'None'.
If the trouble is from validation (the new types don't match for
example) then the new variables will be the ones created by
transform().
If this parameter is None, then exceptions are not caught here (raised normally).
If this parameter is None, then exceptions are not caught here
(raised normally).
"""
self
.
local_opt
=
local_opt
if
ignore_newtrees
==
'auto'
:
...
...
@@ -875,15 +1034,19 @@ class NavigatorOptimizer(Optimizer):
self
.
ignore_newtrees
=
ignore_newtrees
self
.
failure_callback
=
failure_callback
def
attach_updater
(
self
,
env
,
importer
,
pruner
,
chin
=
None
):
"""Install some Env listeners to help the navigator deal with the ignore_trees-related functionality.
def
attach_updater
(
self
,
env
,
importer
,
pruner
,
chin
=
None
):
"""
Install some Env listeners to help the navigator deal with the
ignore_trees-related functionality.
:param importer: function that will be called whenever when optimizations add stuff to the graph.
:param pruner: function to be called when optimizations remove stuff from graph.
:param importer: function that will be called whenever when
optimizations add stuff to the graph.
:param pruner: function to be called when optimizations remove stuff
from graph.
:param chin: "on change input" called whenever an node's inputs change.
:returns: The Env plugin that handles the three tasks.
Keep this around so that you can detach later!
:returns: The Env plugin that handles the three tasks.
Keep this around so that you can detach later!
"""
if
self
.
ignore_newtrees
:
importer
=
None
...
...
@@ -916,21 +1079,22 @@ class NavigatorOptimizer(Optimizer):
if
u
is
not
None
:
env
.
remove_feature
(
u
)
def
process_node
(
self
,
env
,
node
,
lopt
=
None
):
def
process_node
(
self
,
env
,
node
,
lopt
=
None
):
"""
This function will use `lopt` to `transform` the `node`. The `transform` method will
return either False or a list of Variables that are intended to replace `node.outputs`.
This function will use `lopt` to `transform` the `node`. The
`transform` method will return either False or a list of Variables
that are intended to replace `node.outputs`.
If the env accepts the replacement, then the optimization is
successful, and this
function returns True.
If the env accepts the replacement, then the optimization is
successful, and this
function returns True.
If there are no replacement candidates or the env rejects the
replacements, this
function returns False.
If there are no replacement candidates or the env rejects the
replacements, this
function returns False.
:param env: an Env
:param node: an Apply instance in `env`
:param lopt: a LocalOptimizer instance that may have a better idea for
how to compute
node's outputs.
:param lopt: a LocalOptimizer instance that may have a better idea for
how to compute
node's outputs.
:rtype: Bool
:returns: True iff the `node`'s outputs were replaced in the `env`.
...
...
@@ -940,16 +1104,19 @@ class NavigatorOptimizer(Optimizer):
replacements
=
lopt
.
transform
(
node
)
except
Exception
,
e
:
if
self
.
failure_callback
is
not
None
:
self
.
failure_callback
(
e
,
self
,
[(
x
,
None
)
for
x
in
node
.
outputs
],
lopt
)
self
.
failure_callback
(
e
,
self
,
[(
x
,
None
)
for
x
in
node
.
outputs
],
lopt
)
return
False
else
:
raise
if
replacements
is
False
or
replacements
is
None
:
return
False
if
not
isinstance
(
replacements
,
(
tuple
,
list
)):
raise
TypeError
(
'Optimizer
%
s gave wrong type of replacement. Expected list or tuple.'
%
lopt
)
raise
TypeError
(
'Optimizer
%
s gave wrong type of replacement. '
'Expected list or tuple.'
%
lopt
)
if
len
(
node
.
outputs
)
!=
len
(
replacements
):
raise
ValueError
(
'Optimizer
%
s gave wrong number of replacements'
%
lopt
)
raise
ValueError
(
'Optimizer
%
s gave wrong number of replacements'
%
lopt
)
# If an output would be replaced by itself, no need to perform
# the replacement
repl_pairs
=
[(
r
,
rnew
)
for
r
,
rnew
in
zip
(
node
.
outputs
,
replacements
)
...
...
@@ -962,8 +1129,8 @@ class NavigatorOptimizer(Optimizer):
except
Exception
,
e
:
# This means the replacements were rejected by the env.
#
# This is not supposed to happen. The default failure_callback
will print a
# traceback as a warning.
# This is not supposed to happen. The default failure_callback
#
will print a
traceback as a warning.
if
self
.
failure_callback
is
not
None
:
self
.
failure_callback
(
e
,
self
,
repl_pairs
,
lopt
)
return
False
...
...
@@ -978,26 +1145,33 @@ class NavigatorOptimizer(Optimizer):
self
.
local_opt
.
add_requirements
(
env
)
def
print_summary
(
self
,
stream
=
sys
.
stdout
,
level
=
0
,
depth
=-
1
):
print
>>
stream
,
"
%
s
%
s (
%
i)"
%
(
' '
*
level
,
self
.
__class__
.
__name__
,
id
(
self
))
print
>>
stream
,
"
%
s
%
s (
%
i)"
%
(
(
' '
*
level
),
self
.
__class__
.
__name__
,
id
(
self
))
if
depth
!=
0
:
self
.
local_opt
.
print_summary
(
stream
,
level
=
level
+
2
,
depth
=
depth
-
1
)
self
.
local_opt
.
print_summary
(
stream
,
level
=
(
level
+
2
),
depth
=
(
depth
-
1
))
class
TopoOptimizer
(
NavigatorOptimizer
):
"""WRITEME"""
def
__init__
(
self
,
local_opt
,
order
=
'in_to_out'
,
ignore_newtrees
=
False
,
failure_callback
=
None
):
def
__init__
(
self
,
local_opt
,
order
=
'in_to_out'
,
ignore_newtrees
=
False
,
failure_callback
=
None
):
if
order
not
in
[
'out_to_in'
,
'in_to_out'
]:
raise
ValueError
(
"order must be 'out_to_in' or 'in_to_out'"
)
self
.
order
=
order
NavigatorOptimizer
.
__init__
(
self
,
local_opt
,
ignore_newtrees
,
failure_callback
)
NavigatorOptimizer
.
__init__
(
self
,
local_opt
,
ignore_newtrees
,
failure_callback
)
def
apply
(
self
,
env
,
start_from
=
None
):
if
start_from
is
None
:
start_from
=
env
.
outputs
def
apply
(
self
,
env
,
start_from
=
None
):
if
start_from
is
None
:
start_from
=
env
.
outputs
q
=
deque
(
graph
.
io_toposort
(
env
.
inputs
,
start_from
))
def
importer
(
node
):
if
node
is
not
current_node
:
q
.
append
(
node
)
def
pruner
(
node
):
if
node
is
not
current_node
:
try
:
...
...
@@ -1020,14 +1194,16 @@ class TopoOptimizer(NavigatorOptimizer):
self
.
detach_updater
(
env
,
u
)
class
OpKeyOptimizer
(
NavigatorOptimizer
):
"""WRITEME"""
def
__init__
(
self
,
local_opt
,
ignore_newtrees
=
False
,
failure_callback
=
None
):
def
__init__
(
self
,
local_opt
,
ignore_newtrees
=
False
,
failure_callback
=
None
):
if
not
hasattr
(
local_opt
,
'op_key'
):
raise
TypeError
(
"LocalOptimizer for OpKeyOptimizer must have an 'op_key' method."
)
NavigatorOptimizer
.
__init__
(
self
,
local_opt
,
ignore_newtrees
,
failure_callback
)
raise
TypeError
(
"LocalOptimizer for OpKeyOptimizer must have "
"an 'op_key' method."
)
NavigatorOptimizer
.
__init__
(
self
,
local_opt
,
ignore_newtrees
,
failure_callback
)
def
apply
(
self
,
env
):
op
=
self
.
local_opt
.
op_key
()
...
...
@@ -1035,9 +1211,12 @@ class OpKeyOptimizer(NavigatorOptimizer):
q
=
reduce
(
list
.
__iadd__
,
map
(
env
.
get_nodes
,
op
))
else
:
q
=
list
(
env
.
get_nodes
(
op
))
def
importer
(
node
):
if
node
is
not
current_node
:
if
node
.
op
==
op
:
q
.
append
(
node
)
if
node
.
op
==
op
:
q
.
append
(
node
)
def
pruner
(
node
):
if
node
is
not
current_node
and
node
.
op
==
op
:
try
:
...
...
@@ -1065,7 +1244,6 @@ class OpKeyOptimizer(NavigatorOptimizer):
env
.
extend
(
toolbox
.
NodeFinder
())
class
ChangeTracker
:
def
__init__
(
self
):
self
.
changed
=
False
...
...
@@ -1082,17 +1260,19 @@ class ChangeTracker:
def
on_attach
(
self
,
env
):
env
.
change_tracker
=
self
class
EquilibriumOptimizer
(
NavigatorOptimizer
):
def
__init__
(
self
,
optimizers
,
failure_callback
=
None
,
max_depth
=
None
,
max_use_ratio
=
None
):
failure_callback
=
None
,
max_depth
=
None
,
max_use_ratio
=
None
):
"""
:param optimizers: list or set of local or global optimizations to
apply until
equilibrium.
:param optimizers: list or set of local or global optimizations to
apply until
equilibrium.
:param max_use_ratio: each optimizer can be applied at most (size of graph * this number)
:param max_use_ratio: each optimizer can be applied at most
(size of graph * this number) times
:param max_depth: TODO what does this do? (EquilibriumDB sets it to 5)
...
...
@@ -1100,8 +1280,8 @@ class EquilibriumOptimizer(NavigatorOptimizer):
super
(
EquilibriumOptimizer
,
self
)
.
__init__
(
None
,
ignore_newtrees
=
True
,
failure_callback
=
failure_callback
)
ignore_newtrees
=
True
,
failure_callback
=
failure_callback
)
self
.
local_optimizers
=
[]
self
.
global_optimizers
=
[]
...
...
@@ -1112,13 +1292,18 @@ class EquilibriumOptimizer(NavigatorOptimizer):
self
.
global_optimizers
.
append
(
opt
)
self
.
max_depth
=
max_depth
self
.
max_use_ratio
=
max_use_ratio
assert
self
.
max_use_ratio
is
not
None
,
'max_use_ratio has to be a number'
assert
self
.
max_use_ratio
is
not
None
,
(
'max_use_ratio has to be a number'
)
def
add_requirements
(
self
,
env
):
super
(
EquilibriumOptimizer
,
self
)
.
add_requirements
(
env
)
env
.
extend
(
ChangeTracker
())
for
opt
in
self
.
local_optimizers
:
opt
.
add_requirements
(
env
)
for
opt
in
self
.
global_optimizers
:
opt
.
add_requirements
(
env
)
def
apply
(
self
,
env
,
start_from
=
None
):
def
apply
(
self
,
env
,
start_from
=
None
):
if
start_from
is
None
:
start_from
=
env
.
outputs
changed
=
True
...
...
@@ -1153,9 +1338,11 @@ class EquilibriumOptimizer(NavigatorOptimizer):
nb_nodes
.
append
(
len
(
q
))
max_nb_nodes
=
max
(
max_nb_nodes
,
len
(
q
))
max_use
=
max_nb_nodes
*
self
.
max_use_ratio
def
importer
(
node
):
if
node
is
not
current_node
:
q
.
append
(
node
)
def
pruner
(
node
):
if
node
is
not
current_node
:
try
:
...
...
@@ -1179,12 +1366,13 @@ class EquilibriumOptimizer(NavigatorOptimizer):
opt_name
=
(
getattr
(
lopt
,
"name"
,
None
)
or
getattr
(
lopt
,
"__name__"
,
""
))
if
node
not
in
env
.
nodes
:
break
# go to next node
# go to next node
break
finally
:
self
.
detach_updater
(
env
,
u
)
self
.
detach_updater
(
env
,
u
)
#TODO: erase this line, it's redundant at best
loop_timing
.
append
(
float
(
time
.
time
()
-
t0
))
if
max_use_abort
:
_logger
.
error
(
"EquilibriumOptimizer max'ed out by '
%
s'"
%
opt_name
+
". You can safely raise the current threshold of "
...
...
@@ -1216,10 +1404,12 @@ class EquilibriumOptimizer(NavigatorOptimizer):
def
print_summary
(
self
,
stream
=
sys
.
stdout
,
level
=
0
,
depth
=-
1
):
name
=
getattr
(
self
,
'name'
,
None
)
print
>>
stream
,
"
%
s
%
s
%
s id=
%
i"
%
(
' '
*
level
,
self
.
__class__
.
__name__
,
name
,
id
(
self
))
print
>>
stream
,
"
%
s
%
s
%
s id=
%
i"
%
(
(
' '
*
level
),
self
.
__class__
.
__name__
,
name
,
id
(
self
))
if
depth
!=
0
:
for
lopt
in
self
.
local_optimizers
:
lopt
.
print_summary
(
stream
,
level
=
level
+
2
,
depth
=
depth
-
1
)
lopt
.
print_summary
(
stream
,
level
=
(
level
+
2
),
depth
=
(
depth
-
1
))
#################
...
...
@@ -1242,7 +1432,8 @@ def _check_chain(r, chain):
return
False
else
:
try
:
if
issubclass
(
elem
,
op
.
Op
)
and
not
isinstance
(
r
.
owner
.
op
,
elem
):
if
(
issubclass
(
elem
,
op
.
Op
)
and
not
isinstance
(
r
.
owner
.
op
,
elem
)):
return
False
except
TypeError
:
return
False
...
...
@@ -1256,6 +1447,7 @@ def _check_chain(r, chain):
return
(
r
is
not
None
)
#_check_chain.n_calls = 0
def
check_chain
(
r
,
*
chain
):
"""WRITEME"""
if
isinstance
(
r
,
graph
.
Apply
):
...
...
@@ -1280,7 +1472,7 @@ def pre_greedy_local_optimizer(list_optimizations, out):
add additional node to the inputs of the node, it can
be needed to call this function multiple time.
'''
def
local_recursive_function
(
list_opt
,
out
,
optimized_vars
,
depth
):
def
local_recursive_function
(
list_opt
,
out
,
optimized_vars
,
depth
):
if
not
getattr
(
out
,
'owner'
,
None
):
return
[
out
],
optimized_vars
node
=
out
.
owner
...
...
@@ -1292,11 +1484,11 @@ def pre_greedy_local_optimizer(list_optimizations, out):
else
:
if
inp
.
owner
:
outs
,
optimized_vars
=
local_recursive_function
(
list_opt
,
inp
,
optimized_vars
,
depth
+
1
)
for
k
,
v
in
zip
(
inp
.
owner
.
outputs
,
outs
):
list_opt
,
inp
,
optimized_vars
,
depth
+
1
)
for
k
,
v
in
zip
(
inp
.
owner
.
outputs
,
outs
):
optimized_vars
[
k
]
=
v
nw_in
=
outs
[
inp
.
owner
.
outputs
.
index
(
inp
)]
...
...
@@ -1310,10 +1502,10 @@ def pre_greedy_local_optimizer(list_optimizations, out):
ret
=
opt
.
transform
(
node
)
if
ret
is
not
False
and
ret
is
not
None
:
assert
len
(
ret
)
==
len
(
node
.
outputs
)
for
k
,
v
in
zip
(
node
.
outputs
,
ret
):
for
k
,
v
in
zip
(
node
.
outputs
,
ret
):
optimized_vars
[
k
]
=
v
results
=
ret
if
ret
[
0
]
.
owner
:
if
ret
[
0
]
.
owner
:
node
=
out
.
owner
else
:
break
...
...
@@ -1324,8 +1516,6 @@ def pre_greedy_local_optimizer(list_optimizations, out):
return
final_outs
[
0
]
############
### Misc ###
############
...
...
theano/tensor/opt.py
浏览文件 @
d4dfbf2a
...
...
@@ -1823,7 +1823,6 @@ def local_subtensor_merge(node):
merged_slices
.
append
(
slice1
)
pos_1
+=
1
if
pos_2
<
len
(
slices2
):
merged_slices
+=
slices2
[
pos_2
:]
else
:
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论