Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
608c1a07
提交
608c1a07
authored
2月 18, 2009
作者:
Joseph Turian
浏览文件
操作
浏览文件
下载
差异文件
merge
上级
bd1133b9
824f5fa7
隐藏空白字符变更
内嵌
并排
正在显示
4 个修改的文件
包含
117 行增加
和
36 行删除
+117
-36
env.py
theano/gof/env.py
+15
-4
opt.py
theano/gof/opt.py
+27
-20
debugmode.py
theano/sandbox/debugmode.py
+73
-10
opt.py
theano/tensor/opt.py
+2
-2
没有找到文件。
theano/gof/env.py
浏览文件 @
608c1a07
...
...
@@ -171,6 +171,10 @@ class Env(utils.object2):
Updates the list of clients of r with new_clients.
"""
if
set
(
r
.
clients
)
.
intersection
(
set
(
new_clients
)):
print
'RCLIENTS of'
,
r
,
[(
n
,
i
,
type
(
n
),
id
(
n
))
for
n
,
i
in
r
.
clients
]
print
'NCLIENTS of'
,
r
,
[(
n
,
i
,
type
(
n
),
id
(
n
))
for
n
,
i
in
new_clients
]
assert
not
set
(
r
.
clients
)
.
intersection
(
set
(
new_clients
))
r
.
clients
+=
new_clients
def
__remove_clients__
(
self
,
r
,
clients_to_remove
,
prune
=
True
):
...
...
@@ -182,6 +186,10 @@ class Env(utils.object2):
"""
for
entry
in
clients_to_remove
:
r
.
clients
.
remove
(
entry
)
if
entry
in
r
.
clients
:
print
'ENTRY'
,
repr
(
entry
),
type
(
entry
[
0
])
print
'CLIENTS'
,
repr
(
r
.
clients
)
assert
entry
not
in
r
.
clients
# an op,i pair should be unique
if
not
r
.
clients
:
if
prune
:
self
.
__prune_r__
([
r
])
...
...
@@ -194,8 +202,11 @@ class Env(utils.object2):
def
__import_r__
(
self
,
results
):
# Imports the owners of the results
for
node
in
set
(
r
.
owner
for
r
in
results
if
r
.
owner
is
not
None
):
self
.
__import__
(
node
)
r_owner_done
=
set
()
for
node
in
[
r
.
owner
for
r
in
results
if
r
.
owner
is
not
None
]:
if
node
not
in
r_owner_done
:
r_owner_done
.
add
(
node
)
self
.
__import__
(
node
)
for
r
in
results
:
if
r
.
owner
is
None
and
not
isinstance
(
r
,
graph
.
Value
)
and
r
not
in
self
.
inputs
:
raise
TypeError
(
"Undeclared input"
,
r
)
...
...
@@ -319,8 +330,8 @@ class Env(utils.object2):
# because it makes it easier to implement some optimizations for multiple-output ops
return
for
node
,
i
in
list
(
r
.
clients
):
assert
node
==
'output'
and
self
.
outputs
[
i
]
is
r
or
node
.
inputs
[
i
]
is
r
for
node
,
i
in
list
(
r
.
clients
):
#copy the client list for iteration
assert
(
node
==
'output'
and
self
.
outputs
[
i
]
is
r
)
or
(
node
.
inputs
[
i
]
is
r
)
self
.
change_input
(
node
,
i
,
new_r
,
reason
=
reason
)
def
replace_all
(
self
,
pairs
,
reason
=
None
):
...
...
theano/gof/opt.py
浏览文件 @
608c1a07
...
...
@@ -187,21 +187,27 @@ class MergeOptimizer(Optimizer):
env
.
extend
(
toolbox
.
ReplaceValidate
())
def
apply_constant_merge
(
self
,
env
):
seen_constants
=
set
()
const_sig
=
_metadict
()
# result -> result.signature() (for constants)
const_sig_inv
=
_metadict
()
# signature -> result (for constants)
for
i
,
c
in
enumerate
([
r
for
r
in
env
.
results
if
isinstance
(
r
,
graph
.
Constant
)]):
sig
=
c
.
signature
()
other_c
=
const_sig_inv
.
get
(
sig
,
None
)
if
other_c
is
not
None
:
# multiple names will clobber each other..
# we adopt convention to keep the last name
if
c
.
name
:
other_c
.
name
=
c
.
name
env
.
replace_validate
(
c
,
other_c
,
reason
=
'Constant Merge'
)
else
:
#this is a new constant
const_sig
[
c
]
=
sig
const_sig_inv
[
sig
]
=
c
for
node
in
_list_of_nodes
(
env
):
for
i
,
c
in
enumerate
([
r
for
r
in
node
.
inputs
if
isinstance
(
r
,
graph
.
Constant
)]):
if
id
(
c
)
in
seen_constants
:
continue
else
:
seen_constants
.
add
(
id
(
c
))
sig
=
c
.
signature
()
other_c
=
const_sig_inv
.
get
(
sig
,
None
)
if
other_c
is
not
None
:
# multiple names will clobber each other..
# we adopt convention to keep the last name
if
c
.
name
:
other_c
.
name
=
c
.
name
env
.
replace_validate
(
c
,
other_c
,
reason
=
'Constant Merge'
)
else
:
#this is a new constant
const_sig
[
c
]
=
sig
const_sig_inv
[
sig
]
=
c
def
exptime_apply_node_merge
(
self
,
env
):
# we clear the dicts because the Constants signatures are not necessarily hashable
...
...
@@ -242,19 +248,20 @@ class MergeOptimizer(Optimizer):
# we clear the dicts because the Constants signatures are not necessarily hashable
# and it's more efficient to give them an integer like the other Results
nodes_seen
=
set
()
nodes_seen
=
{}
for
node
in
_list_of_nodes
(
env
):
for
node
_idx
,
node
in
enumerate
(
_list_of_nodes
(
env
)
):
#
# these asserts ensure that the env has set the clients field properly the clients
# should at least contain `node` itself!
#
assert
len
(
node
.
inputs
[
0
]
.
clients
)
>
0
assert
(
node
,
0
)
in
node
.
inputs
[
0
]
.
clients
merge_candidates
=
[
c
for
(
c
,
i
)
in
node
.
inputs
[
0
]
.
clients
if
c
in
nodes_seen
]
nodes_seen
.
add
(
node
)
merge_candidates
=
[(
nodes_seen
[
c
],
c
)
for
(
c
,
i
)
in
node
.
inputs
[
0
]
.
clients
if
c
in
nodes_seen
]
merge_candidates
.
sort
()
nodes_seen
[
node
]
=
node_idx
#print 'NODE', node, merge_candidates, node.inputs[0].clients
for
candidate
in
merge_candidates
:
for
candidate
_idx
,
candidate
in
merge_candidates
:
if
len
(
node
.
inputs
)
!=
len
(
candidate
.
inputs
):
continue
inputs_match
=
all
(
node_in
is
cand_in
for
node_in
,
cand_in
in
zip
(
node
.
inputs
,
candidate
.
inputs
))
...
...
@@ -626,8 +633,8 @@ class NavigatorOptimizer(Optimizer):
def
warn
(
exc
,
nav
,
repl_pairs
,
local_opt
):
"""failure_callback for NavigatorOptimizer: print traceback
"""
print
"WARNING: Optimization failure due to: "
,
local_opt
print
"TRACEBACK:"
print
>>
sys
.
stderr
,
"WARNING: Optimization failure due to: "
,
local_opt
print
>>
sys
.
stderr
,
"TRACEBACK:"
traceback
.
print_exc
()
@staticmethod
def
warn_inplace
(
exc
,
nav
,
repl_pairs
,
local_opt
):
...
...
theano/sandbox/debugmode.py
浏览文件 @
608c1a07
...
...
@@ -43,6 +43,38 @@ def debugprint(a, prefix='', depth=-1, done=None, file=sys.stdout):
return
file
class
Event
(
object
):
def
__init__
(
self
,
kind
,
node
,
idx
=
None
,
reason
=
None
):
self
.
kind
=
kind
if
node
==
'output'
:
self
.
node
=
'output'
self
.
op
=
'output'
else
:
self
.
node
=
node
self
.
op
=
node
.
op
self
.
idx
=
idx
self
.
reason
=
reason
def
__str__
(
self
):
if
self
.
kind
==
'change'
:
return
' '
.
join
([
'change'
,
self
.
reason
,
str
(
self
.
op
),
str
(
self
.
idx
),
str
(
len
(
self
.
node
.
inputs
))])
else
:
return
str
(
self
.
__dict__
)
def
__eq__
(
self
,
other
):
rval
=
type
(
self
)
==
type
(
other
)
if
rval
:
for
attr
in
[
'kind'
,
'op'
,
'idx'
,
'reason'
]:
rval
=
rval
and
getattr
(
self
,
attr
)
==
getattr
(
other
,
attr
)
return
rval
def
__ne__
(
self
,
other
):
return
not
(
self
==
other
)
class
ResultEquivalenceTracker
(
object
):
def
__init__
(
self
):
self
.
env
=
None
...
...
@@ -57,12 +89,14 @@ class ResultEquivalenceTracker(object):
self
.
reasons
=
{}
self
.
replaced_by
=
{}
self
.
snapshots
=
{}
self
.
event_list
=
[]
def
on_detach
(
self
,
env
):
assert
env
is
self
.
env
self
.
env
=
None
def
on_prune
(
self
,
env
,
node
):
self
.
event_list
.
append
(
Event
(
'prune'
,
node
))
#print 'PRUNING NODE', node, id(node)
assert
node
in
self
.
active_nodes
assert
node
not
in
self
.
inactive_nodes
...
...
@@ -70,6 +104,8 @@ class ResultEquivalenceTracker(object):
self
.
inactive_nodes
.
add
(
node
)
def
on_import
(
self
,
env
,
node
):
self
.
event_list
.
append
(
Event
(
'import'
,
node
))
#print 'NEW NODE', node, id(node)
assert
node
not
in
self
.
active_nodes
self
.
active_nodes
.
add
(
node
)
...
...
@@ -93,6 +129,7 @@ class ResultEquivalenceTracker(object):
def
on_change_input
(
self
,
env
,
node
,
i
,
r
,
new_r
,
reason
=
None
):
#print 'CHANGE by', reason, 'to use', new_r, type(new_r)
self
.
event_list
.
append
(
Event
(
'change'
,
node
,
reason
=
str
(
reason
),
idx
=
i
))
self
.
reasons
.
setdefault
(
new_r
,
[])
self
.
replaced_by
.
setdefault
(
new_r
,
[])
...
...
@@ -291,7 +328,6 @@ class OptCheckLinker(OpWiseCLinker):
# because the incorrect result detected here will cause
# subsequent outputs to be incorrect.
raise
Exception
(
"OptCheckFailure"
)
print
>>
sys
.
stderr
,
'OptCheck PASS'
if
0
:
#OLD CODE
#print out the summary of the first problematic equivalence group
...
...
@@ -321,7 +357,9 @@ NODEFAULT = ['NODEFAULT']
class
OptCheckFunctionMaker
(
FunctionMaker
):
def
__init__
(
self
,
inputs
,
outputs
,
optimizer
,
accept_inplace
=
False
,
function_builder
=
Function
):
chances_for_optimizer_to_screw_up
=
10
,
accept_inplace
=
False
,
function_builder
=
Function
):
"""
:type inputs: a list of SymbolicInput instances
...
...
@@ -350,17 +388,39 @@ class OptCheckFunctionMaker(FunctionMaker):
expanded_inputs
=
reduce
(
list
.
__add__
,
[
list
(
z
)
for
x
,
y
,
z
in
indices
],
[])
# make the env
env
,
additional_outputs
,
equivalence_tracker
=
optcheck_env
(
expanded_inputs
,
outputs
,
accept_inplace
)
self
.
env
=
env
for
i
in
xrange
(
chances_for_optimizer_to_screw_up
):
env
,
additional_outputs
,
equivalence_tracker
=
optcheck_env
(
expanded_inputs
,
outputs
,
accept_inplace
)
env
.
equivalence_tracker
=
equivalence_tracker
# optimize the env
optimizer
(
env
)
if
i
:
li
=
env
.
equivalence_tracker
.
event_list
l0
=
env0
.
equivalence_tracker
.
event_list
if
li
!=
l0
:
print
>>
sys
.
stderr
,
"WARNING: Optimization process is unstable"
for
j
in
xrange
(
max
(
len
(
li
),
len
(
l0
))):
if
li
[
j
]
!=
l0
[
j
]:
print
>>
sys
.
stderr
,
"* "
,
j
print
>>
sys
.
stderr
,
" "
,
str
(
li
[
j
])
if
j
<
len
(
li
)
else
'-'
print
>>
sys
.
stderr
,
" "
,
str
(
l0
[
j
])
if
j
<
len
(
l0
)
else
'-'
else
:
pass
linker
=
OptCheckLinker
()
print
>>
sys
.
stderr
,
"EXITING"
sys
.
exit
(
1
)
break
else
:
print
>>
sys
.
stdout
,
"OPTCHECK: optimization"
,
i
,
"of"
,
len
(
li
),
"events was stable."
else
:
env0
=
env
# optimize the env
optimizer
(
env
)
env
.
equivalence_tracker
=
equivalence_tracker
del
env0
self
.
env
=
env
#equivalence_tracker.printstuff()
linker
=
OptCheckLinker
()
#the 'no_borrow' outputs are the ones for which that we can't return the internal storage pointer.
no_borrow
=
[
output
for
output
,
spec
in
zip
(
env
.
outputs
,
outputs
+
additional_outputs
)
if
not
spec
.
borrow
]
...
...
@@ -487,11 +547,14 @@ class OptCheck(Mode):
# function_module.function
def
function_maker
(
self
,
i
,
o
,
m
,
*
args
,
**
kwargs
):
assert
m
is
self
return
OptCheckFunctionMaker
(
i
,
o
,
self
.
optimizer
,
*
args
,
**
kwargs
)
def
__init__
(
self
,
optimizer
=
'fast_run'
):
return
OptCheckFunctionMaker
(
i
,
o
,
self
.
optimizer
,
chances_for_optimizer_to_screw_up
=
self
.
stability_patience
,
*
args
,
**
kwargs
)
def
__init__
(
self
,
optimizer
=
'fast_run'
,
stability_patience
=
10
):
super
(
OptCheck
,
self
)
.
__init__
(
optimizer
=
optimizer
,
linker
=
OptCheckLinker
)
self
.
stability_patience
=
stability_patience
theano/tensor/opt.py
浏览文件 @
608c1a07
...
...
@@ -5,7 +5,7 @@
from
..
import
gof
from
..gof
import
opt
,
InconsistencyError
,
TopoOptimizer
from
..gof
import
opt
,
InconsistencyError
,
TopoOptimizer
,
graph
from
elemwise
import
Elemwise
,
DimShuffle
from
..
import
scalar
import
basic
as
T
...
...
@@ -47,7 +47,7 @@ def insert_inplace_optimizer(env):
x + y + z -> x += y += z
(x + y) * (x * y) -> (x += y) *= (x * y) or (x + y) *= (x *= y)
"""
for
node
in
list
(
env
.
nodes
):
for
node
in
list
(
graph
.
io_toposort
(
env
.
inputs
,
env
.
outputs
)
):
op
=
node
.
op
if
not
isinstance
(
op
,
Elemwise
):
continue
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论