Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
ff91a554
提交
ff91a554
authored
9月 09, 2014
作者:
abergeron
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #2074 from nouiz/opt_order
Make optimization more deterministic
上级
8dea107d
75464f23
显示空白字符变更
内嵌
并排
正在显示
2 个修改的文件
包含
71 行增加
和
22 行删除
+71
-22
optdb.py
theano/gof/optdb.py
+14
-12
ordered_set.py
theano/misc/ordered_set.py
+57
-10
没有找到文件。
theano/gof/optdb.py
浏览文件 @
ff91a554
import
sys
import
sys
from
theano.gof.python25
import
DefaultOrderedDict
import
numpy
import
numpy
from
theano.gof.python25
import
DefaultOrderedDict
from
theano.misc.ordered_set
import
OrderedSet
from
theano.compat.six
import
StringIO
from
theano.compat.six
import
StringIO
from
theano.gof
import
opt
from
theano.gof
import
opt
from
theano.configparser
import
AddConfigVar
,
FloatParam
from
theano.configparser
import
AddConfigVar
,
FloatParam
...
@@ -26,7 +27,7 @@ class DB(object):
...
@@ -26,7 +27,7 @@ class DB(object):
return
self
.
_optimizer_idx
return
self
.
_optimizer_idx
def
__init__
(
self
):
def
__init__
(
self
):
self
.
__db__
=
DefaultOrderedDict
(
s
et
)
self
.
__db__
=
DefaultOrderedDict
(
OrderedS
et
)
self
.
_names
=
set
()
self
.
_names
=
set
()
self
.
name
=
None
# will be reset by register
self
.
name
=
None
# will be reset by register
#(via obj.name by the thing doing the registering)
#(via obj.name by the thing doing the registering)
...
@@ -51,7 +52,7 @@ class DB(object):
...
@@ -51,7 +52,7 @@ class DB(object):
raise
ValueError
(
'''You can
\'
t register the same optimization
raise
ValueError
(
'''You can
\'
t register the same optimization
multiple time in a DB. Tryed to register "
%
s" again under the new name "
%
s".
multiple time in a DB. Tryed to register "
%
s" again under the new name "
%
s".
Use theano.gof.ProxyDB to work around that'''
%
(
obj
.
name
,
name
))
Use theano.gof.ProxyDB to work around that'''
%
(
obj
.
name
,
name
))
self
.
__db__
[
name
]
=
s
et
([
obj
])
self
.
__db__
[
name
]
=
OrderedS
et
([
obj
])
self
.
_names
.
add
(
name
)
self
.
_names
.
add
(
name
)
self
.
__db__
[
obj
.
__class__
.
__name__
]
.
add
(
obj
)
self
.
__db__
[
obj
.
__class__
.
__name__
]
.
add
(
obj
)
self
.
add_tags
(
name
,
*
tags
)
self
.
add_tags
(
name
,
*
tags
)
...
@@ -79,15 +80,16 @@ multiple time in a DB. Tryed to register "%s" again under the new name "%s".
...
@@ -79,15 +80,16 @@ multiple time in a DB. Tryed to register "%s" again under the new name "%s".
def
__query__
(
self
,
q
):
def
__query__
(
self
,
q
):
if
not
isinstance
(
q
,
Query
):
if
not
isinstance
(
q
,
Query
):
raise
TypeError
(
'Expected a Query.'
,
q
)
raise
TypeError
(
'Expected a Query.'
,
q
)
variables
=
set
()
# The ordered set is needed for deterministic optimization.
variables
=
OrderedSet
()
for
tag
in
q
.
include
:
for
tag
in
q
.
include
:
variables
.
update
(
self
.
__db__
[
tag
])
variables
.
update
(
self
.
__db__
[
tag
])
for
tag
in
q
.
require
:
for
tag
in
q
.
require
:
variables
.
intersection_update
(
self
.
__db__
[
tag
])
variables
.
intersection_update
(
self
.
__db__
[
tag
])
for
tag
in
q
.
exclude
:
for
tag
in
q
.
exclude
:
variables
.
difference_update
(
self
.
__db__
[
tag
])
variables
.
difference_update
(
self
.
__db__
[
tag
])
remove
=
s
et
()
remove
=
OrderedS
et
()
add
=
s
et
()
add
=
OrderedS
et
()
for
obj
in
variables
:
for
obj
in
variables
:
if
isinstance
(
obj
,
DB
):
if
isinstance
(
obj
,
DB
):
sq
=
q
.
subquery
.
get
(
obj
.
name
,
q
)
sq
=
q
.
subquery
.
get
(
obj
.
name
,
q
)
...
@@ -143,15 +145,15 @@ class Query(object):
...
@@ -143,15 +145,15 @@ class Query(object):
:param position_cutoff: Used by SequenceDB to keep only optimizer that
:param position_cutoff: Used by SequenceDB to keep only optimizer that
are positioned before the cut_off point.
are positioned before the cut_off point.
"""
"""
self
.
include
=
s
et
(
include
)
self
.
include
=
OrderedS
et
(
include
)
self
.
require
=
require
or
s
et
()
self
.
require
=
require
or
OrderedS
et
()
self
.
exclude
=
exclude
or
s
et
()
self
.
exclude
=
exclude
or
OrderedS
et
()
self
.
subquery
=
subquery
or
{}
self
.
subquery
=
subquery
or
{}
self
.
position_cutoff
=
position_cutoff
self
.
position_cutoff
=
position_cutoff
if
isinstance
(
self
.
require
,
(
list
,
tuple
)):
if
isinstance
(
self
.
require
,
(
list
,
tuple
)):
self
.
require
=
s
et
(
self
.
require
)
self
.
require
=
OrderedS
et
(
self
.
require
)
if
isinstance
(
self
.
exclude
,
(
list
,
tuple
)):
if
isinstance
(
self
.
exclude
,
(
list
,
tuple
)):
self
.
exclude
=
s
et
(
self
.
exclude
)
self
.
exclude
=
OrderedS
et
(
self
.
exclude
)
#add all opt with this tag
#add all opt with this tag
def
including
(
self
,
*
tags
):
def
including
(
self
,
*
tags
):
...
...
theano/misc/ordered_set.py
浏览文件 @
ff91a554
...
@@ -7,6 +7,7 @@ except ImportError:
...
@@ -7,6 +7,7 @@ except ImportError:
from
theano.gof.python25
import
OrderedDict
from
theano.gof.python25
import
OrderedDict
import
types
import
types
def
check_deterministic
(
iterable
):
def
check_deterministic
(
iterable
):
# Most places where OrderedSet is used, theano interprets any exception
# Most places where OrderedSet is used, theano interprets any exception
# whatsoever as a problem that an optimization introduced into the graph.
# whatsoever as a problem that an optimization introduced into the graph.
...
@@ -40,11 +41,28 @@ if MutableSet is not None:
...
@@ -40,11 +41,28 @@ if MutableSet is not None:
# SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
# SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
## {{{ http://code.activestate.com/recipes/576696/ (r5)
## {{{ http://code.activestate.com/recipes/576696/ (r5)
import
collections
import
collections
from
weakref
import
proxy
import
weakref
class
Link
(
object
):
class
Link
(
object
):
__slots__
=
'prev'
,
'next'
,
'key'
,
'__weakref__'
__slots__
=
'prev'
,
'next'
,
'key'
,
'__weakref__'
def
__getstate__
(
self
):
# weakref.proxy don't pickle well, so we use weakref.ref
# manually and don't pickle the weakref.
# We restore the weakref when we unpickle.
ret
=
[
self
.
prev
(),
self
.
next
()]
try
:
ret
.
append
(
self
.
key
)
except
AttributeError
:
pass
return
ret
def
__setstate__
(
self
,
state
):
self
.
prev
=
weakref
.
ref
(
state
[
0
])
self
.
next
=
weakref
.
ref
(
state
[
1
])
if
len
(
state
)
==
3
:
self
.
key
=
state
[
2
]
class
OrderedSet
(
collections
.
MutableSet
):
class
OrderedSet
(
collections
.
MutableSet
):
'Set the remembers the order elements were added'
'Set the remembers the order elements were added'
# Big-O running times for all methods are the same as for regular sets.
# Big-O running times for all methods are the same as for regular sets.
...
@@ -65,7 +83,7 @@ if MutableSet is not None:
...
@@ -65,7 +83,7 @@ if MutableSet is not None:
# Checks added by IG
# Checks added by IG
check_deterministic
(
iterable
)
check_deterministic
(
iterable
)
self
.
__root
=
root
=
Link
()
# sentinel node for doubly linked list
self
.
__root
=
root
=
Link
()
# sentinel node for doubly linked list
root
.
prev
=
root
.
next
=
root
root
.
prev
=
root
.
next
=
weakref
.
ref
(
root
)
self
.
__map
=
{}
# key --> link
self
.
__map
=
{}
# key --> link
if
iterable
is
not
None
:
if
iterable
is
not
None
:
self
|=
iterable
self
|=
iterable
...
@@ -82,32 +100,61 @@ if MutableSet is not None:
...
@@ -82,32 +100,61 @@ if MutableSet is not None:
self
.
__map
[
key
]
=
link
=
Link
()
self
.
__map
[
key
]
=
link
=
Link
()
root
=
self
.
__root
root
=
self
.
__root
last
=
root
.
prev
last
=
root
.
prev
link
.
prev
,
link
.
next
,
link
.
key
=
last
,
root
,
key
link
.
prev
,
link
.
next
,
link
.
key
=
last
,
weakref
.
ref
(
root
),
key
last
.
next
=
root
.
prev
=
proxy
(
link
)
last
()
.
next
=
root
.
prev
=
weakref
.
ref
(
link
)
def
union
(
self
,
s
):
check_deterministic
(
s
)
n
=
self
.
copy
()
for
elem
in
s
:
if
elem
not
in
n
:
n
.
add
(
elem
)
return
n
def
intersection_update
(
self
,
s
):
l
=
[]
for
elem
in
self
:
if
elem
not
in
s
:
l
.
append
(
elem
)
for
elem
in
l
:
self
.
remove
(
elem
)
return
self
def
difference_update
(
self
,
s
):
check_deterministic
(
s
)
for
elem
in
s
:
if
elem
in
self
:
self
.
remove
(
elem
)
return
self
def
copy
(
self
):
n
=
OrderedSet
()
n
.
update
(
self
)
return
n
def
discard
(
self
,
key
):
def
discard
(
self
,
key
):
# Remove an existing item using self.__map to find the link which is
# Remove an existing item using self.__map to find the link which is
# then removed by updating the links in the predecessor and successors.
# then removed by updating the links in the predecessor and successors.
if
key
in
self
.
__map
:
if
key
in
self
.
__map
:
link
=
self
.
__map
.
pop
(
key
)
link
=
self
.
__map
.
pop
(
key
)
link
.
prev
.
next
=
link
.
next
link
.
prev
()
.
next
=
link
.
next
link
.
next
.
prev
=
link
.
prev
link
.
next
()
.
prev
=
link
.
prev
def
__iter__
(
self
):
def
__iter__
(
self
):
# Traverse the linked list in order.
# Traverse the linked list in order.
root
=
self
.
__root
root
=
self
.
__root
curr
=
root
.
next
curr
=
root
.
next
()
while
curr
is
not
root
:
while
curr
is
not
root
:
yield
curr
.
key
yield
curr
.
key
curr
=
curr
.
next
curr
=
curr
.
next
()
def
__reversed__
(
self
):
def
__reversed__
(
self
):
# Traverse the linked list in reverse order.
# Traverse the linked list in reverse order.
root
=
self
.
__root
root
=
self
.
__root
curr
=
root
.
prev
curr
=
root
.
prev
()
while
curr
is
not
root
:
while
curr
is
not
root
:
yield
curr
.
key
yield
curr
.
key
curr
=
curr
.
prev
curr
=
curr
.
prev
()
def
pop
(
self
,
last
=
True
):
def
pop
(
self
,
last
=
True
):
if
not
self
:
if
not
self
:
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论