Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
608c1a07
提交
608c1a07
authored
2月 18, 2009
作者:
Joseph Turian
浏览文件
操作
浏览文件
下载
差异文件
merge
上级
bd1133b9
824f5fa7
隐藏空白字符变更
内嵌
并排
正在显示
4 个修改的文件
包含
117 行增加
和
36 行删除
+117
-36
env.py
theano/gof/env.py
+15
-4
opt.py
theano/gof/opt.py
+27
-20
debugmode.py
theano/sandbox/debugmode.py
+73
-10
opt.py
theano/tensor/opt.py
+2
-2
没有找到文件。
theano/gof/env.py
浏览文件 @
608c1a07
...
@@ -171,6 +171,10 @@ class Env(utils.object2):
...
@@ -171,6 +171,10 @@ class Env(utils.object2):
Updates the list of clients of r with new_clients.
Updates the list of clients of r with new_clients.
"""
"""
if
set
(
r
.
clients
)
.
intersection
(
set
(
new_clients
)):
print
'RCLIENTS of'
,
r
,
[(
n
,
i
,
type
(
n
),
id
(
n
))
for
n
,
i
in
r
.
clients
]
print
'NCLIENTS of'
,
r
,
[(
n
,
i
,
type
(
n
),
id
(
n
))
for
n
,
i
in
new_clients
]
assert
not
set
(
r
.
clients
)
.
intersection
(
set
(
new_clients
))
r
.
clients
+=
new_clients
r
.
clients
+=
new_clients
def
__remove_clients__
(
self
,
r
,
clients_to_remove
,
prune
=
True
):
def
__remove_clients__
(
self
,
r
,
clients_to_remove
,
prune
=
True
):
...
@@ -182,6 +186,10 @@ class Env(utils.object2):
...
@@ -182,6 +186,10 @@ class Env(utils.object2):
"""
"""
for
entry
in
clients_to_remove
:
for
entry
in
clients_to_remove
:
r
.
clients
.
remove
(
entry
)
r
.
clients
.
remove
(
entry
)
if
entry
in
r
.
clients
:
print
'ENTRY'
,
repr
(
entry
),
type
(
entry
[
0
])
print
'CLIENTS'
,
repr
(
r
.
clients
)
assert
entry
not
in
r
.
clients
# an op,i pair should be unique
if
not
r
.
clients
:
if
not
r
.
clients
:
if
prune
:
if
prune
:
self
.
__prune_r__
([
r
])
self
.
__prune_r__
([
r
])
...
@@ -194,8 +202,11 @@ class Env(utils.object2):
...
@@ -194,8 +202,11 @@ class Env(utils.object2):
def
__import_r__
(
self
,
results
):
def
__import_r__
(
self
,
results
):
# Imports the owners of the results
# Imports the owners of the results
for
node
in
set
(
r
.
owner
for
r
in
results
if
r
.
owner
is
not
None
):
r_owner_done
=
set
()
self
.
__import__
(
node
)
for
node
in
[
r
.
owner
for
r
in
results
if
r
.
owner
is
not
None
]:
if
node
not
in
r_owner_done
:
r_owner_done
.
add
(
node
)
self
.
__import__
(
node
)
for
r
in
results
:
for
r
in
results
:
if
r
.
owner
is
None
and
not
isinstance
(
r
,
graph
.
Value
)
and
r
not
in
self
.
inputs
:
if
r
.
owner
is
None
and
not
isinstance
(
r
,
graph
.
Value
)
and
r
not
in
self
.
inputs
:
raise
TypeError
(
"Undeclared input"
,
r
)
raise
TypeError
(
"Undeclared input"
,
r
)
...
@@ -319,8 +330,8 @@ class Env(utils.object2):
...
@@ -319,8 +330,8 @@ class Env(utils.object2):
# because it makes it easier to implement some optimizations for multiple-output ops
# because it makes it easier to implement some optimizations for multiple-output ops
return
return
for
node
,
i
in
list
(
r
.
clients
):
for
node
,
i
in
list
(
r
.
clients
):
#copy the client list for iteration
assert
node
==
'output'
and
self
.
outputs
[
i
]
is
r
or
node
.
inputs
[
i
]
is
r
assert
(
node
==
'output'
and
self
.
outputs
[
i
]
is
r
)
or
(
node
.
inputs
[
i
]
is
r
)
self
.
change_input
(
node
,
i
,
new_r
,
reason
=
reason
)
self
.
change_input
(
node
,
i
,
new_r
,
reason
=
reason
)
def
replace_all
(
self
,
pairs
,
reason
=
None
):
def
replace_all
(
self
,
pairs
,
reason
=
None
):
...
...
theano/gof/opt.py
浏览文件 @
608c1a07
...
@@ -187,21 +187,27 @@ class MergeOptimizer(Optimizer):
...
@@ -187,21 +187,27 @@ class MergeOptimizer(Optimizer):
env
.
extend
(
toolbox
.
ReplaceValidate
())
env
.
extend
(
toolbox
.
ReplaceValidate
())
def
apply_constant_merge
(
self
,
env
):
def
apply_constant_merge
(
self
,
env
):
seen_constants
=
set
()
const_sig
=
_metadict
()
# result -> result.signature() (for constants)
const_sig
=
_metadict
()
# result -> result.signature() (for constants)
const_sig_inv
=
_metadict
()
# signature -> result (for constants)
const_sig_inv
=
_metadict
()
# signature -> result (for constants)
for
i
,
c
in
enumerate
([
r
for
r
in
env
.
results
if
isinstance
(
r
,
graph
.
Constant
)]):
for
node
in
_list_of_nodes
(
env
):
sig
=
c
.
signature
()
for
i
,
c
in
enumerate
([
r
for
r
in
node
.
inputs
if
isinstance
(
r
,
graph
.
Constant
)]):
other_c
=
const_sig_inv
.
get
(
sig
,
None
)
if
id
(
c
)
in
seen_constants
:
if
other_c
is
not
None
:
continue
# multiple names will clobber each other..
else
:
# we adopt convention to keep the last name
seen_constants
.
add
(
id
(
c
))
if
c
.
name
:
sig
=
c
.
signature
()
other_c
.
name
=
c
.
name
other_c
=
const_sig_inv
.
get
(
sig
,
None
)
env
.
replace_validate
(
c
,
other_c
,
reason
=
'Constant Merge'
)
if
other_c
is
not
None
:
else
:
# multiple names will clobber each other..
#this is a new constant
# we adopt convention to keep the last name
const_sig
[
c
]
=
sig
if
c
.
name
:
const_sig_inv
[
sig
]
=
c
other_c
.
name
=
c
.
name
env
.
replace_validate
(
c
,
other_c
,
reason
=
'Constant Merge'
)
else
:
#this is a new constant
const_sig
[
c
]
=
sig
const_sig_inv
[
sig
]
=
c
def
exptime_apply_node_merge
(
self
,
env
):
def
exptime_apply_node_merge
(
self
,
env
):
# we clear the dicts because the Constants signatures are not necessarily hashable
# we clear the dicts because the Constants signatures are not necessarily hashable
...
@@ -242,19 +248,20 @@ class MergeOptimizer(Optimizer):
...
@@ -242,19 +248,20 @@ class MergeOptimizer(Optimizer):
# we clear the dicts because the Constants signatures are not necessarily hashable
# we clear the dicts because the Constants signatures are not necessarily hashable
# and it's more efficient to give them an integer like the other Results
# and it's more efficient to give them an integer like the other Results
nodes_seen
=
set
()
nodes_seen
=
{}
for
node
in
_list_of_nodes
(
env
):
for
node
_idx
,
node
in
enumerate
(
_list_of_nodes
(
env
)
):
#
#
# these asserts ensure that the env has set the clients field properly the clients
# these asserts ensure that the env has set the clients field properly the clients
# should at least contain `node` itself!
# should at least contain `node` itself!
#
#
assert
len
(
node
.
inputs
[
0
]
.
clients
)
>
0
assert
len
(
node
.
inputs
[
0
]
.
clients
)
>
0
assert
(
node
,
0
)
in
node
.
inputs
[
0
]
.
clients
assert
(
node
,
0
)
in
node
.
inputs
[
0
]
.
clients
merge_candidates
=
[
c
for
(
c
,
i
)
in
node
.
inputs
[
0
]
.
clients
if
c
in
nodes_seen
]
merge_candidates
=
[(
nodes_seen
[
c
],
c
)
for
(
c
,
i
)
in
node
.
inputs
[
0
]
.
clients
if
c
in
nodes_seen
]
nodes_seen
.
add
(
node
)
merge_candidates
.
sort
()
nodes_seen
[
node
]
=
node_idx
#print 'NODE', node, merge_candidates, node.inputs[0].clients
#print 'NODE', node, merge_candidates, node.inputs[0].clients
for
candidate
in
merge_candidates
:
for
candidate
_idx
,
candidate
in
merge_candidates
:
if
len
(
node
.
inputs
)
!=
len
(
candidate
.
inputs
):
if
len
(
node
.
inputs
)
!=
len
(
candidate
.
inputs
):
continue
continue
inputs_match
=
all
(
node_in
is
cand_in
for
node_in
,
cand_in
in
zip
(
node
.
inputs
,
candidate
.
inputs
))
inputs_match
=
all
(
node_in
is
cand_in
for
node_in
,
cand_in
in
zip
(
node
.
inputs
,
candidate
.
inputs
))
...
@@ -626,8 +633,8 @@ class NavigatorOptimizer(Optimizer):
...
@@ -626,8 +633,8 @@ class NavigatorOptimizer(Optimizer):
def
warn
(
exc
,
nav
,
repl_pairs
,
local_opt
):
def
warn
(
exc
,
nav
,
repl_pairs
,
local_opt
):
"""failure_callback for NavigatorOptimizer: print traceback
"""failure_callback for NavigatorOptimizer: print traceback
"""
"""
print
"WARNING: Optimization failure due to: "
,
local_opt
print
>>
sys
.
stderr
,
"WARNING: Optimization failure due to: "
,
local_opt
print
"TRACEBACK:"
print
>>
sys
.
stderr
,
"TRACEBACK:"
traceback
.
print_exc
()
traceback
.
print_exc
()
@staticmethod
@staticmethod
def
warn_inplace
(
exc
,
nav
,
repl_pairs
,
local_opt
):
def
warn_inplace
(
exc
,
nav
,
repl_pairs
,
local_opt
):
...
...
theano/sandbox/debugmode.py
浏览文件 @
608c1a07
...
@@ -43,6 +43,38 @@ def debugprint(a, prefix='', depth=-1, done=None, file=sys.stdout):
...
@@ -43,6 +43,38 @@ def debugprint(a, prefix='', depth=-1, done=None, file=sys.stdout):
return
file
return
file
class
Event
(
object
):
def
__init__
(
self
,
kind
,
node
,
idx
=
None
,
reason
=
None
):
self
.
kind
=
kind
if
node
==
'output'
:
self
.
node
=
'output'
self
.
op
=
'output'
else
:
self
.
node
=
node
self
.
op
=
node
.
op
self
.
idx
=
idx
self
.
reason
=
reason
def
__str__
(
self
):
if
self
.
kind
==
'change'
:
return
' '
.
join
([
'change'
,
self
.
reason
,
str
(
self
.
op
),
str
(
self
.
idx
),
str
(
len
(
self
.
node
.
inputs
))])
else
:
return
str
(
self
.
__dict__
)
def
__eq__
(
self
,
other
):
rval
=
type
(
self
)
==
type
(
other
)
if
rval
:
for
attr
in
[
'kind'
,
'op'
,
'idx'
,
'reason'
]:
rval
=
rval
and
getattr
(
self
,
attr
)
==
getattr
(
other
,
attr
)
return
rval
def
__ne__
(
self
,
other
):
return
not
(
self
==
other
)
class
ResultEquivalenceTracker
(
object
):
class
ResultEquivalenceTracker
(
object
):
def
__init__
(
self
):
def
__init__
(
self
):
self
.
env
=
None
self
.
env
=
None
...
@@ -57,12 +89,14 @@ class ResultEquivalenceTracker(object):
...
@@ -57,12 +89,14 @@ class ResultEquivalenceTracker(object):
self
.
reasons
=
{}
self
.
reasons
=
{}
self
.
replaced_by
=
{}
self
.
replaced_by
=
{}
self
.
snapshots
=
{}
self
.
snapshots
=
{}
self
.
event_list
=
[]
def
on_detach
(
self
,
env
):
def
on_detach
(
self
,
env
):
assert
env
is
self
.
env
assert
env
is
self
.
env
self
.
env
=
None
self
.
env
=
None
def
on_prune
(
self
,
env
,
node
):
def
on_prune
(
self
,
env
,
node
):
self
.
event_list
.
append
(
Event
(
'prune'
,
node
))
#print 'PRUNING NODE', node, id(node)
#print 'PRUNING NODE', node, id(node)
assert
node
in
self
.
active_nodes
assert
node
in
self
.
active_nodes
assert
node
not
in
self
.
inactive_nodes
assert
node
not
in
self
.
inactive_nodes
...
@@ -70,6 +104,8 @@ class ResultEquivalenceTracker(object):
...
@@ -70,6 +104,8 @@ class ResultEquivalenceTracker(object):
self
.
inactive_nodes
.
add
(
node
)
self
.
inactive_nodes
.
add
(
node
)
def
on_import
(
self
,
env
,
node
):
def
on_import
(
self
,
env
,
node
):
self
.
event_list
.
append
(
Event
(
'import'
,
node
))
#print 'NEW NODE', node, id(node)
#print 'NEW NODE', node, id(node)
assert
node
not
in
self
.
active_nodes
assert
node
not
in
self
.
active_nodes
self
.
active_nodes
.
add
(
node
)
self
.
active_nodes
.
add
(
node
)
...
@@ -93,6 +129,7 @@ class ResultEquivalenceTracker(object):
...
@@ -93,6 +129,7 @@ class ResultEquivalenceTracker(object):
def
on_change_input
(
self
,
env
,
node
,
i
,
r
,
new_r
,
reason
=
None
):
def
on_change_input
(
self
,
env
,
node
,
i
,
r
,
new_r
,
reason
=
None
):
#print 'CHANGE by', reason, 'to use', new_r, type(new_r)
#print 'CHANGE by', reason, 'to use', new_r, type(new_r)
self
.
event_list
.
append
(
Event
(
'change'
,
node
,
reason
=
str
(
reason
),
idx
=
i
))
self
.
reasons
.
setdefault
(
new_r
,
[])
self
.
reasons
.
setdefault
(
new_r
,
[])
self
.
replaced_by
.
setdefault
(
new_r
,
[])
self
.
replaced_by
.
setdefault
(
new_r
,
[])
...
@@ -291,7 +328,6 @@ class OptCheckLinker(OpWiseCLinker):
...
@@ -291,7 +328,6 @@ class OptCheckLinker(OpWiseCLinker):
# because the incorrect result detected here will cause
# because the incorrect result detected here will cause
# subsequent outputs to be incorrect.
# subsequent outputs to be incorrect.
raise
Exception
(
"OptCheckFailure"
)
raise
Exception
(
"OptCheckFailure"
)
print
>>
sys
.
stderr
,
'OptCheck PASS'
if
0
:
#OLD CODE
if
0
:
#OLD CODE
#print out the summary of the first problematic equivalence group
#print out the summary of the first problematic equivalence group
...
@@ -321,7 +357,9 @@ NODEFAULT = ['NODEFAULT']
...
@@ -321,7 +357,9 @@ NODEFAULT = ['NODEFAULT']
class
OptCheckFunctionMaker
(
FunctionMaker
):
class
OptCheckFunctionMaker
(
FunctionMaker
):
def
__init__
(
self
,
inputs
,
outputs
,
optimizer
,
def
__init__
(
self
,
inputs
,
outputs
,
optimizer
,
accept_inplace
=
False
,
function_builder
=
Function
):
chances_for_optimizer_to_screw_up
=
10
,
accept_inplace
=
False
,
function_builder
=
Function
):
"""
"""
:type inputs: a list of SymbolicInput instances
:type inputs: a list of SymbolicInput instances
...
@@ -350,17 +388,39 @@ class OptCheckFunctionMaker(FunctionMaker):
...
@@ -350,17 +388,39 @@ class OptCheckFunctionMaker(FunctionMaker):
expanded_inputs
=
reduce
(
list
.
__add__
,
[
list
(
z
)
for
x
,
y
,
z
in
indices
],
[])
expanded_inputs
=
reduce
(
list
.
__add__
,
[
list
(
z
)
for
x
,
y
,
z
in
indices
],
[])
# make the env
# make the env
env
,
additional_outputs
,
equivalence_tracker
=
optcheck_env
(
expanded_inputs
,
outputs
,
accept_inplace
)
for
i
in
xrange
(
chances_for_optimizer_to_screw_up
):
self
.
env
=
env
env
,
additional_outputs
,
equivalence_tracker
=
optcheck_env
(
expanded_inputs
,
outputs
,
accept_inplace
)
env
.
equivalence_tracker
=
equivalence_tracker
# optimize the env
optimizer
(
env
)
if
i
:
li
=
env
.
equivalence_tracker
.
event_list
l0
=
env0
.
equivalence_tracker
.
event_list
if
li
!=
l0
:
print
>>
sys
.
stderr
,
"WARNING: Optimization process is unstable"
for
j
in
xrange
(
max
(
len
(
li
),
len
(
l0
))):
if
li
[
j
]
!=
l0
[
j
]:
print
>>
sys
.
stderr
,
"* "
,
j
print
>>
sys
.
stderr
,
" "
,
str
(
li
[
j
])
if
j
<
len
(
li
)
else
'-'
print
>>
sys
.
stderr
,
" "
,
str
(
l0
[
j
])
if
j
<
len
(
l0
)
else
'-'
else
:
pass
linker
=
OptCheckLinker
()
print
>>
sys
.
stderr
,
"EXITING"
sys
.
exit
(
1
)
break
else
:
print
>>
sys
.
stdout
,
"OPTCHECK: optimization"
,
i
,
"of"
,
len
(
li
),
"events was stable."
else
:
env0
=
env
# optimize the env
optimizer
(
env
)
env
.
equivalence_tracker
=
equivalence_tracker
del
env0
self
.
env
=
env
#equivalence_tracker.printstuff()
#equivalence_tracker.printstuff()
linker
=
OptCheckLinker
()
#the 'no_borrow' outputs are the ones for which that we can't return the internal storage pointer.
#the 'no_borrow' outputs are the ones for which that we can't return the internal storage pointer.
no_borrow
=
[
output
for
output
,
spec
in
zip
(
env
.
outputs
,
outputs
+
additional_outputs
)
if
not
spec
.
borrow
]
no_borrow
=
[
output
for
output
,
spec
in
zip
(
env
.
outputs
,
outputs
+
additional_outputs
)
if
not
spec
.
borrow
]
...
@@ -487,11 +547,14 @@ class OptCheck(Mode):
...
@@ -487,11 +547,14 @@ class OptCheck(Mode):
# function_module.function
# function_module.function
def
function_maker
(
self
,
i
,
o
,
m
,
*
args
,
**
kwargs
):
def
function_maker
(
self
,
i
,
o
,
m
,
*
args
,
**
kwargs
):
assert
m
is
self
assert
m
is
self
return
OptCheckFunctionMaker
(
i
,
o
,
self
.
optimizer
,
*
args
,
**
kwargs
)
return
OptCheckFunctionMaker
(
i
,
o
,
self
.
optimizer
,
def
__init__
(
self
,
optimizer
=
'fast_run'
):
chances_for_optimizer_to_screw_up
=
self
.
stability_patience
,
*
args
,
**
kwargs
)
def
__init__
(
self
,
optimizer
=
'fast_run'
,
stability_patience
=
10
):
super
(
OptCheck
,
self
)
.
__init__
(
super
(
OptCheck
,
self
)
.
__init__
(
optimizer
=
optimizer
,
optimizer
=
optimizer
,
linker
=
OptCheckLinker
)
linker
=
OptCheckLinker
)
self
.
stability_patience
=
stability_patience
theano/tensor/opt.py
浏览文件 @
608c1a07
...
@@ -5,7 +5,7 @@
...
@@ -5,7 +5,7 @@
from
..
import
gof
from
..
import
gof
from
..gof
import
opt
,
InconsistencyError
,
TopoOptimizer
from
..gof
import
opt
,
InconsistencyError
,
TopoOptimizer
,
graph
from
elemwise
import
Elemwise
,
DimShuffle
from
elemwise
import
Elemwise
,
DimShuffle
from
..
import
scalar
from
..
import
scalar
import
basic
as
T
import
basic
as
T
...
@@ -47,7 +47,7 @@ def insert_inplace_optimizer(env):
...
@@ -47,7 +47,7 @@ def insert_inplace_optimizer(env):
x + y + z -> x += y += z
x + y + z -> x += y += z
(x + y) * (x * y) -> (x += y) *= (x * y) or (x + y) *= (x *= y)
(x + y) * (x * y) -> (x += y) *= (x * y) or (x + y) *= (x *= y)
"""
"""
for
node
in
list
(
env
.
nodes
):
for
node
in
list
(
graph
.
io_toposort
(
env
.
inputs
,
env
.
outputs
)
):
op
=
node
.
op
op
=
node
.
op
if
not
isinstance
(
op
,
Elemwise
):
if
not
isinstance
(
op
,
Elemwise
):
continue
continue
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论