Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
cba9c812
提交
cba9c812
authored
7月 15, 2015
作者:
Frédéric Bastien
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #3130 from harlouci/flake8_gof
Flake8 gof
上级
a1783c0b
4bae84c6
全部展开
隐藏空白字符变更
内嵌
并排
正在显示
12 个修改的文件
包含
105 行增加
和
78 行删除
+105
-78
configparser.py
theano/configparser.py
+1
-1
cc.py
theano/gof/cc.py
+0
-0
cmodule.py
theano/gof/cmodule.py
+0
-0
destroyhandler.py
theano/gof/destroyhandler.py
+33
-31
fg.py
theano/gof/fg.py
+20
-21
link.py
theano/gof/link.py
+18
-16
opt.py
theano/gof/opt.py
+0
-0
utils.py
theano/gof/utils.py
+30
-0
nvcc_compiler.py
theano/sandbox/cuda/nvcc_compiler.py
+1
-1
utils.py
theano/sparse/utils.py
+1
-1
utils.py
theano/tensor/utils.py
+1
-1
test_flake8.py
theano/tests/test_flake8.py
+0
-6
没有找到文件。
theano/configparser.py
浏览文件 @
cba9c812
...
@@ -177,7 +177,7 @@ def get_config_md5():
...
@@ -177,7 +177,7 @@ def get_config_md5():
"""
"""
all_opts
=
sorted
([
c
for
c
in
_config_var_list
if
c
.
in_c_key
],
all_opts
=
sorted
([
c
for
c
in
_config_var_list
if
c
.
in_c_key
],
key
=
lambda
cv
:
cv
.
fullname
)
key
=
lambda
cv
:
cv
.
fullname
)
return
theano
.
gof
.
cc
.
hash_from_code
(
'
\n
'
.
join
(
return
theano
.
gof
.
utils
.
hash_from_code
(
'
\n
'
.
join
(
[
'
%
s =
%
s'
%
(
cv
.
fullname
,
cv
.
__get__
())
for
cv
in
all_opts
]))
[
'
%
s =
%
s'
%
(
cv
.
fullname
,
cv
.
__get__
())
for
cv
in
all_opts
]))
...
...
theano/gof/cc.py
浏览文件 @
cba9c812
差异被折叠。
点击展开。
theano/gof/cmodule.py
浏览文件 @
cba9c812
差异被折叠。
点击展开。
theano/gof/destroyhandler.py
浏览文件 @
cba9c812
...
@@ -46,7 +46,6 @@ def _contains_cycle(fgraph, orderings):
...
@@ -46,7 +46,6 @@ def _contains_cycle(fgraph, orderings):
"""
"""
# These are lists of Variable instances
# These are lists of Variable instances
inputs
=
fgraph
.
inputs
outputs
=
fgraph
.
outputs
outputs
=
fgraph
.
outputs
# this is hard-coded reimplementation of functions from graph.py
# this is hard-coded reimplementation of functions from graph.py
...
@@ -65,8 +64,6 @@ def _contains_cycle(fgraph, orderings):
...
@@ -65,8 +64,6 @@ def _contains_cycle(fgraph, orderings):
# (defaultdict runs faster than dict in the case where the key
# (defaultdict runs faster than dict in the case where the key
# is not in the dictionary, at least in CPython)
# is not in the dictionary, at least in CPython)
iset
=
set
(
inputs
)
# IG: I tried converting parent_counts to use an id for the key,
# IG: I tried converting parent_counts to use an id for the key,
# so that the dict would do reference counting on its keys.
# so that the dict would do reference counting on its keys.
# This caused a slowdown.
# This caused a slowdown.
...
@@ -236,9 +233,9 @@ def fast_inplace_check(inputs):
...
@@ -236,9 +233,9 @@ def fast_inplace_check(inputs):
protected_inputs
.
extend
(
fgraph
.
outputs
)
protected_inputs
.
extend
(
fgraph
.
outputs
)
inputs
=
[
i
for
i
in
inputs
if
inputs
=
[
i
for
i
in
inputs
if
not
isinstance
(
i
,
graph
.
Constant
)
not
isinstance
(
i
,
graph
.
Constant
)
and
and
not
fgraph
.
destroyers
(
i
)
not
fgraph
.
destroyers
(
i
)
and
and
i
not
in
protected_inputs
]
i
not
in
protected_inputs
]
return
inputs
return
inputs
if
0
:
if
0
:
...
@@ -293,7 +290,7 @@ if 0:
...
@@ -293,7 +290,7 @@ if 0:
TODO: WRITEME: what does this do besides the checks?
TODO: WRITEME: what does this do besides the checks?
"""
"""
#
###### Do the checking ##########
#
#
Do the checking
#
already_there
=
False
already_there
=
False
if
self
.
fgraph
not
in
[
None
,
fgraph
]:
if
self
.
fgraph
not
in
[
None
,
fgraph
]:
raise
Exception
(
"A DestroyHandler instance can only serve"
raise
Exception
(
"A DestroyHandler instance can only serve"
...
@@ -309,7 +306,7 @@ if 0:
...
@@ -309,7 +306,7 @@ if 0:
"DestroyHandler feature is already present or in"
"DestroyHandler feature is already present or in"
" conflict with another plugin."
)
" conflict with another plugin."
)
#
###### end of checking ###########
#
#
end of checking
#
def
get_destroyers_of
(
r
):
def
get_destroyers_of
(
r
):
droot
,
impact
,
root_destroyer
=
self
.
refresh_droot_impact
()
droot
,
impact
,
root_destroyer
=
self
.
refresh_droot_impact
()
...
@@ -362,8 +359,8 @@ if 0:
...
@@ -362,8 +359,8 @@ if 0:
"Multiple destroyers of
%
s"
%
input_root
)
"Multiple destroyers of
%
s"
%
input_root
)
droot
[
input_root
]
=
input_root
droot
[
input_root
]
=
input_root
root_destroyer
[
input_root
]
=
app
root_destroyer
[
input_root
]
=
app
#input_impact = set([input_root])
#
input_impact = set([input_root])
#add_impact(input_root, self.view_o, input_impact)
#
add_impact(input_root, self.view_o, input_impact)
input_impact
=
get_impact
(
input_root
,
self
.
view_o
)
input_impact
=
get_impact
(
input_root
,
self
.
view_o
)
for
v
in
input_impact
:
for
v
in
input_impact
:
assert
v
not
in
droot
assert
v
not
in
droot
...
@@ -390,7 +387,7 @@ if 0:
...
@@ -390,7 +387,7 @@ if 0:
def
on_import
(
self
,
fgraph
,
app
,
reason
):
def
on_import
(
self
,
fgraph
,
app
,
reason
):
"""Add Apply instance to set which must be computed"""
"""Add Apply instance to set which must be computed"""
#if app in self.debug_all_apps: raise ProtocolError("double import")
#
if app in self.debug_all_apps: raise ProtocolError("double import")
# self.debug_all_apps.add(app)
# self.debug_all_apps.add(app)
# print 'DH IMPORT', app, id(app), id(self), len(self.debug_all_apps)
# print 'DH IMPORT', app, id(app), id(self), len(self.debug_all_apps)
...
@@ -421,7 +418,7 @@ if 0:
...
@@ -421,7 +418,7 @@ if 0:
def
on_prune
(
self
,
fgraph
,
app
,
reason
):
def
on_prune
(
self
,
fgraph
,
app
,
reason
):
"""Remove Apply instance from set which must be computed"""
"""Remove Apply instance from set which must be computed"""
#if app not in self.debug_all_apps: raise ProtocolError("prune without import")
#
if app not in self.debug_all_apps: raise ProtocolError("prune without import")
# self.debug_all_apps.remove(app)
# self.debug_all_apps.remove(app)
# UPDATE self.clients
# UPDATE self.clients
...
@@ -458,7 +455,7 @@ if 0:
...
@@ -458,7 +455,7 @@ if 0:
# considered 'outputs' of the graph.
# considered 'outputs' of the graph.
pass
pass
else
:
else
:
#if app not in self.debug_all_apps: raise ProtocolError("change without import")
#
if app not in self.debug_all_apps: raise ProtocolError("change without import")
# UPDATE self.clients
# UPDATE self.clients
self
.
clients
[
old_r
][
app
]
-=
1
self
.
clients
[
old_r
][
app
]
-=
1
...
@@ -529,9 +526,10 @@ if 0:
...
@@ -529,9 +526,10 @@ if 0:
droot
,
impact
,
__ignore
=
self
.
refresh_droot_impact
()
droot
,
impact
,
__ignore
=
self
.
refresh_droot_impact
()
# check for destruction of constants
# check for destruction of constants
illegal_destroy
=
[
r
for
r
in
droot
if
illegal_destroy
=
[
getattr
(
r
.
tag
,
'indestructible'
,
False
)
or
r
for
r
in
droot
if
isinstance
(
r
,
graph
.
Constant
)]
getattr
(
r
.
tag
,
'indestructible'
,
False
)
or
isinstance
(
r
,
graph
.
Constant
)]
if
illegal_destroy
:
if
illegal_destroy
:
# print 'destroying illegally'
# print 'destroying illegally'
raise
InconsistencyError
(
raise
InconsistencyError
(
...
@@ -603,7 +601,7 @@ if 0:
...
@@ -603,7 +601,7 @@ if 0:
if
input
in
root_impact
\
if
input
in
root_impact
\
and
(
i
not
in
tolerated
or
input
is
not
destroyed_variable
):
and
(
i
not
in
tolerated
or
input
is
not
destroyed_variable
):
raise
InconsistencyError
(
"Input aliasing:
%
s (
%
i,
%
i)"
raise
InconsistencyError
(
"Input aliasing:
%
s (
%
i,
%
i)"
%
(
app
,
destroyed_idx
,
i
))
%
(
app
,
destroyed_idx
,
i
))
# add the rule: app must be preceded by all other Apply instances that
# add the rule: app must be preceded by all other Apply instances that
# depend on destroyed_input
# depend on destroyed_input
...
@@ -621,7 +619,7 @@ if 0:
...
@@ -621,7 +619,7 @@ if 0:
return
rval
return
rval
class
DestroyHandler
(
toolbox
.
Bookkeeper
):
class
DestroyHandler
(
toolbox
.
Bookkeeper
):
# noqa
"""
"""
The DestroyHandler class detects when a graph is impossible to evaluate
The DestroyHandler class detects when a graph is impossible to evaluate
because of aliasing and destructive operations.
because of aliasing and destructive operations.
...
@@ -702,7 +700,7 @@ class DestroyHandler(toolbox.Bookkeeper):
...
@@ -702,7 +700,7 @@ class DestroyHandler(toolbox.Bookkeeper):
TODO: WRITEME: what does this do besides the checks?
TODO: WRITEME: what does this do besides the checks?
"""
"""
#
###### Do the checking ##########
#
#
Do the checking
#
already_there
=
False
already_there
=
False
if
self
.
fgraph
is
fgraph
:
if
self
.
fgraph
is
fgraph
:
already_there
=
True
already_there
=
True
...
@@ -720,7 +718,7 @@ class DestroyHandler(toolbox.Bookkeeper):
...
@@ -720,7 +718,7 @@ class DestroyHandler(toolbox.Bookkeeper):
"DestroyHandler feature is already present"
"DestroyHandler feature is already present"
" or in conflict with another plugin."
)
" or in conflict with another plugin."
)
#
###### Annotate the FunctionGraph ###########
#
#
Annotate the FunctionGraph
#
self
.
unpickle
(
fgraph
)
self
.
unpickle
(
fgraph
)
fgraph
.
destroy_handler
=
self
fgraph
.
destroy_handler
=
self
...
@@ -945,12 +943,13 @@ class DestroyHandler(toolbox.Bookkeeper):
...
@@ -945,12 +943,13 @@ class DestroyHandler(toolbox.Bookkeeper):
droot
,
impact
,
__ignore
=
self
.
refresh_droot_impact
()
droot
,
impact
,
__ignore
=
self
.
refresh_droot_impact
()
# check for destruction of constants
# check for destruction of constants
illegal_destroy
=
[
r
for
r
in
droot
if
\
illegal_destroy
=
[
r
for
r
in
droot
if
getattr
(
r
.
tag
,
'indestructible'
,
False
)
or
\
getattr
(
r
.
tag
,
'indestructible'
,
False
)
or
isinstance
(
r
,
graph
.
Constant
)]
isinstance
(
r
,
graph
.
Constant
)]
if
illegal_destroy
:
if
illegal_destroy
:
raise
InconsistencyError
(
"Attempting to destroy indestructible variables:
%
s"
%
raise
InconsistencyError
(
illegal_destroy
)
"Attempting to destroy indestructible variables:
%
s"
%
illegal_destroy
)
# add destroyed variable clients as computational dependencies
# add destroyed variable clients as computational dependencies
for
app
in
self
.
destroyers
:
for
app
in
self
.
destroyers
:
...
@@ -995,24 +994,27 @@ class DestroyHandler(toolbox.Bookkeeper):
...
@@ -995,24 +994,27 @@ class DestroyHandler(toolbox.Bookkeeper):
# CHECK FOR INPUT ALIASING
# CHECK FOR INPUT ALIASING
# OPT: pre-compute this on import
# OPT: pre-compute this on import
tolerate_same
=
getattr
(
app
.
op
,
'destroyhandler_tolerate_same'
,
[])
tolerate_same
=
getattr
(
app
.
op
,
'destroyhandler_tolerate_same'
,
[])
assert
isinstance
(
tolerate_same
,
list
)
assert
isinstance
(
tolerate_same
,
list
)
tolerated
=
OrderedSet
(
idx1
for
idx0
,
idx1
in
tolerate_same
tolerated
=
OrderedSet
(
idx1
for
idx0
,
idx1
in
tolerate_same
if
idx0
==
destroyed_idx
)
if
idx0
==
destroyed_idx
)
tolerated
.
add
(
destroyed_idx
)
tolerated
.
add
(
destroyed_idx
)
tolerate_aliased
=
getattr
(
app
.
op
,
'destroyhandler_tolerate_aliased'
,
[])
tolerate_aliased
=
getattr
(
app
.
op
,
'destroyhandler_tolerate_aliased'
,
[])
assert
isinstance
(
tolerate_aliased
,
list
)
assert
isinstance
(
tolerate_aliased
,
list
)
ignored
=
OrderedSet
(
idx1
for
idx0
,
idx1
in
tolerate_aliased
ignored
=
OrderedSet
(
idx1
for
idx0
,
idx1
in
tolerate_aliased
if
idx0
==
destroyed_idx
)
if
idx0
==
destroyed_idx
)
# print 'tolerated', tolerated
# print 'tolerated', tolerated
# print 'ignored', ignored
# print 'ignored', ignored
for
i
,
input
in
enumerate
(
app
.
inputs
):
for
i
,
input
in
enumerate
(
app
.
inputs
):
if
i
in
ignored
:
if
i
in
ignored
:
continue
continue
if
input
in
root_impact
\
if
input
in
root_impact
\
and
(
i
not
in
tolerated
or
input
is
not
destroyed_variable
):
and
(
i
not
in
tolerated
or
input
is
not
destroyed_variable
):
raise
InconsistencyError
(
"Input aliasing:
%
s (
%
i,
%
i)"
raise
InconsistencyError
(
"Input aliasing:
%
s (
%
i,
%
i)"
%
(
app
,
destroyed_idx
,
i
))
%
(
app
,
destroyed_idx
,
i
))
# add the rule: app must be preceded by all other Apply instances that
# add the rule: app must be preceded by all other Apply instances that
# depend on destroyed_input
# depend on destroyed_input
...
...
theano/gof/fg.py
浏览文件 @
cba9c812
...
@@ -13,7 +13,6 @@ from theano.gof import graph
...
@@ -13,7 +13,6 @@ from theano.gof import graph
from
theano.gof
import
utils
from
theano.gof
import
utils
from
theano.gof
import
toolbox
from
theano.gof
import
toolbox
from
theano
import
config
from
theano
import
config
import
warnings
from
theano.compat
import
OrderedDict
from
theano.compat
import
OrderedDict
from
six
import
iteritems
,
itervalues
from
six
import
iteritems
,
itervalues
...
@@ -22,6 +21,7 @@ from theano.misc.ordered_set import OrderedSet
...
@@ -22,6 +21,7 @@ from theano.misc.ordered_set import OrderedSet
NullType
=
None
NullType
=
None
class
CachedConstantError
(
Exception
):
class
CachedConstantError
(
Exception
):
"""An exception thrown when we put in a FunctionGraph a Constant
"""An exception thrown when we put in a FunctionGraph a Constant
that is cached. This should not happen as the user can reuse this
that is cached. This should not happen as the user can reuse this
...
@@ -143,7 +143,7 @@ class FunctionGraph(utils.object2):
...
@@ -143,7 +143,7 @@ class FunctionGraph(utils.object2):
self
.
variable_locks
=
{}
self
.
variable_locks
=
{}
self
.
profile
=
None
self
.
profile
=
None
#
## Setup a Variable ##
#
#
Setup a Variable
#
def
__setup_r__
(
self
,
r
):
def
__setup_r__
(
self
,
r
):
# sets up r so it belongs to this fgraph
# sets up r so it belongs to this fgraph
if
getattr
(
r
,
'cached'
,
False
):
if
getattr
(
r
,
'cached'
,
False
):
...
@@ -152,12 +152,12 @@ class FunctionGraph(utils.object2):
...
@@ -152,12 +152,12 @@ class FunctionGraph(utils.object2):
" graph that has a cached constant. This should not happen."
" graph that has a cached constant. This should not happen."
" Clone the graph before building the FunctionGraph."
)
" Clone the graph before building the FunctionGraph."
)
if
(
hasattr
(
r
,
'fgraph'
)
and
if
(
hasattr
(
r
,
'fgraph'
)
and
r
.
fgraph
is
not
None
and
r
.
fgraph
is
not
None
and
r
.
fgraph
is
not
self
):
r
.
fgraph
is
not
self
):
raise
Exception
(
"
%
s is already owned by another fgraph"
%
r
)
raise
Exception
(
"
%
s is already owned by another fgraph"
%
r
)
r
.
fgraph
=
self
r
.
fgraph
=
self
r
.
clients
=
[]
r
.
clients
=
[]
#self.execute_callbacks('on_setup_variable', r)
#
self.execute_callbacks('on_setup_variable', r)
def
__setup_node__
(
self
,
node
):
def
__setup_node__
(
self
,
node
):
# sets up node so it belongs to this fgraph
# sets up node so it belongs to this fgraph
...
@@ -177,7 +177,7 @@ class FunctionGraph(utils.object2):
...
@@ -177,7 +177,7 @@ class FunctionGraph(utils.object2):
str
(
node
.
op
),
str
(
node
.
op
.
destroy_map
)))
str
(
node
.
op
),
str
(
node
.
op
.
destroy_map
)))
node
.
fgraph
=
self
node
.
fgraph
=
self
node
.
deps
=
{}
node
.
deps
=
{}
#self.execute_callbacks('on_setup_node', node)
#
self.execute_callbacks('on_setup_node', node)
def
disown
(
self
):
def
disown
(
self
):
""" WRITEME
""" WRITEME
...
@@ -201,7 +201,7 @@ class FunctionGraph(utils.object2):
...
@@ -201,7 +201,7 @@ class FunctionGraph(utils.object2):
self
.
inputs
=
None
self
.
inputs
=
None
self
.
outputs
=
None
self
.
outputs
=
None
#
## clients ##
#
#
clients
#
def
clients
(
self
,
r
):
def
clients
(
self
,
r
):
"""
"""
Set of all the (node, i) pairs such that node.inputs[i] is r.
Set of all the (node, i) pairs such that node.inputs[i] is r.
...
@@ -221,9 +221,9 @@ class FunctionGraph(utils.object2):
...
@@ -221,9 +221,9 @@ class FunctionGraph(utils.object2):
if
set
(
r
.
clients
)
.
intersection
(
set
(
new_clients
)):
if
set
(
r
.
clients
)
.
intersection
(
set
(
new_clients
)):
print
(
'ERROR: clients intersect!'
,
file
=
sys
.
stderr
)
print
(
'ERROR: clients intersect!'
,
file
=
sys
.
stderr
)
print
(
' RCLIENTS of'
,
r
,
[(
n
,
i
,
type
(
n
),
id
(
n
))
print
(
' RCLIENTS of'
,
r
,
[(
n
,
i
,
type
(
n
),
id
(
n
))
for
n
,
i
in
r
.
clients
],
file
=
sys
.
stderr
)
for
n
,
i
in
r
.
clients
],
file
=
sys
.
stderr
)
print
(
' NCLIENTS of'
,
r
,
[(
n
,
i
,
type
(
n
),
id
(
n
))
print
(
' NCLIENTS of'
,
r
,
[(
n
,
i
,
type
(
n
),
id
(
n
))
for
n
,
i
in
new_clients
],
file
=
sys
.
stderr
)
for
n
,
i
in
new_clients
],
file
=
sys
.
stderr
)
assert
not
set
(
r
.
clients
)
.
intersection
(
set
(
new_clients
))
assert
not
set
(
r
.
clients
)
.
intersection
(
set
(
new_clients
))
r
.
clients
+=
new_clients
r
.
clients
+=
new_clients
...
@@ -245,7 +245,7 @@ class FunctionGraph(utils.object2):
...
@@ -245,7 +245,7 @@ class FunctionGraph(utils.object2):
return
True
return
True
return
False
return
False
#
## import ##
#
#
import
#
def
__import_r__
(
self
,
variable
,
reason
):
def
__import_r__
(
self
,
variable
,
reason
):
global
NullType
global
NullType
if
NullType
is
None
:
if
NullType
is
None
:
...
@@ -279,9 +279,8 @@ class FunctionGraph(utils.object2):
...
@@ -279,9 +279,8 @@ class FunctionGraph(utils.object2):
if
hasattr
(
r
,
'fgraph'
)
and
r
.
fgraph
is
not
self
:
if
hasattr
(
r
,
'fgraph'
)
and
r
.
fgraph
is
not
self
:
raise
Exception
(
"
%
s is already owned by another fgraph"
%
r
)
raise
Exception
(
"
%
s is already owned by another fgraph"
%
r
)
if
(
r
.
owner
is
None
and
if
(
r
.
owner
is
None
and
not
isinstance
(
r
,
graph
.
Constant
)
and
not
isinstance
(
r
,
graph
.
Constant
)
and
r
not
in
self
.
inputs
):
r
not
in
self
.
inputs
):
# Verbose error message
# Verbose error message
# Show a complete chain of variables from the missing input to an output
# Show a complete chain of variables from the missing input to an output
if
config
.
exception_verbosity
==
'high'
:
if
config
.
exception_verbosity
==
'high'
:
...
@@ -373,7 +372,7 @@ class FunctionGraph(utils.object2):
...
@@ -373,7 +372,7 @@ class FunctionGraph(utils.object2):
assert
node
.
fgraph
is
self
assert
node
.
fgraph
is
self
self
.
execute_callbacks
(
'on_import'
,
node
,
reason
)
self
.
execute_callbacks
(
'on_import'
,
node
,
reason
)
#
## prune ##
#
#
prune
#
def
__prune_r__
(
self
,
variable
,
reason
=
None
):
def
__prune_r__
(
self
,
variable
,
reason
=
None
):
"""Should be called for variable that aren't used anymore:
"""Should be called for variable that aren't used anymore:
len(var.clients) == 0
len(var.clients) == 0
...
@@ -430,7 +429,7 @@ class FunctionGraph(utils.object2):
...
@@ -430,7 +429,7 @@ class FunctionGraph(utils.object2):
self
.
__remove_clients__
(
input
,
[(
apply_node
,
i
)],
reason
=
reason
)
self
.
__remove_clients__
(
input
,
[(
apply_node
,
i
)],
reason
=
reason
)
# self.__prune_r__(apply_node.inputs)
# self.__prune_r__(apply_node.inputs)
#
## change input ##
#
#
change input
#
def
change_input
(
self
,
node
,
i
,
new_r
,
reason
=
None
):
def
change_input
(
self
,
node
,
i
,
new_r
,
reason
=
None
):
"""WRITEME
"""WRITEME
Changes node.inputs[i] to new_r.
Changes node.inputs[i] to new_r.
...
@@ -475,7 +474,7 @@ class FunctionGraph(utils.object2):
...
@@ -475,7 +474,7 @@ class FunctionGraph(utils.object2):
if
prune
:
if
prune
:
self
.
__prune_r__
(
r
,
reason
=
reason
)
self
.
__prune_r__
(
r
,
reason
=
reason
)
#
## replace ##
#
#
replace
#
def
replace
(
self
,
r
,
new_r
,
reason
=
None
,
verbose
=
None
):
def
replace
(
self
,
r
,
new_r
,
reason
=
None
,
verbose
=
None
):
""" WRITEME
""" WRITEME
This is the main interface to manipulate the subgraph in FunctionGraph.
This is the main interface to manipulate the subgraph in FunctionGraph.
...
@@ -582,7 +581,7 @@ class FunctionGraph(utils.object2):
...
@@ -582,7 +581,7 @@ class FunctionGraph(utils.object2):
if
detach
is
not
None
:
if
detach
is
not
None
:
detach
(
self
)
detach
(
self
)
#
## callback utils ##
#
#
callback utils
#
def
execute_callbacks
(
self
,
name
,
*
args
,
**
kwargs
):
def
execute_callbacks
(
self
,
name
,
*
args
,
**
kwargs
):
"""WRITEME
"""WRITEME
Calls
Calls
...
@@ -618,7 +617,7 @@ class FunctionGraph(utils.object2):
...
@@ -618,7 +617,7 @@ class FunctionGraph(utils.object2):
d
[
feature
]
=
fn
(
*
args
)
d
[
feature
]
=
fn
(
*
args
)
return
d
return
d
#
## misc ##
#
#
misc
#
def
toposort
(
self
):
def
toposort
(
self
):
"""WRITEME
"""WRITEME
Returns an ordering of the graph's Apply nodes such that:
Returns an ordering of the graph's Apply nodes such that:
...
@@ -712,8 +711,8 @@ class FunctionGraph(utils.object2):
...
@@ -712,8 +711,8 @@ class FunctionGraph(utils.object2):
missing
,
excess
)
missing
,
excess
)
for
variable
in
variables
:
for
variable
in
variables
:
if
(
variable
.
owner
is
None
and
if
(
variable
.
owner
is
None
and
variable
not
in
self
.
inputs
and
variable
not
in
self
.
inputs
and
not
isinstance
(
variable
,
graph
.
Constant
)):
not
isinstance
(
variable
,
graph
.
Constant
)):
raise
Exception
(
"Undeclared input."
,
variable
)
raise
Exception
(
"Undeclared input."
,
variable
)
if
variable
.
fgraph
is
not
self
:
if
variable
.
fgraph
is
not
self
:
raise
Exception
(
"Variable should belong to the FunctionGraph."
,
raise
Exception
(
"Variable should belong to the FunctionGraph."
,
...
@@ -737,7 +736,7 @@ class FunctionGraph(utils.object2):
...
@@ -737,7 +736,7 @@ class FunctionGraph(utils.object2):
def
__repr__
(
self
):
def
__repr__
(
self
):
return
self
.
__str__
()
return
self
.
__str__
()
#
## clone ##
#
#
clone
#
def
clone
(
self
,
check_integrity
=
True
):
def
clone
(
self
,
check_integrity
=
True
):
"""WRITEME"""
"""WRITEME"""
return
self
.
clone_get_equiv
(
check_integrity
)[
0
]
return
self
.
clone_get_equiv
(
check_integrity
)[
0
]
...
...
theano/gof/link.py
浏览文件 @
cba9c812
...
@@ -7,14 +7,14 @@ import traceback
...
@@ -7,14 +7,14 @@ import traceback
import
numpy
import
numpy
import
theano
import
theano
from
theano.compat
import
PY3
,
izip
from
theano.compat
import
izip
from
six
import
reraise
from
six
import
reraise
from
six.moves
import
StringIO
from
six.moves
import
StringIO
from
theano.gof
import
utils
from
theano.gof
import
utils
from
theano.gof
import
graph
from
theano.gof
import
graph
from
theano.gof.type
import
Type
from
theano.gof.type
import
Type
from
.utils
import
MethodNotDefined
,
undef
from
.utils
import
undef
__excepthook
=
sys
.
excepthook
__excepthook
=
sys
.
excepthook
...
@@ -281,9 +281,9 @@ def raise_with_op(node, thunk=None, exc_info=None, storage_map=None):
...
@@ -281,9 +281,9 @@ def raise_with_op(node, thunk=None, exc_info=None, storage_map=None):
else
:
else
:
detailed_err_msg
+=
"
\n
"
detailed_err_msg
+=
"
\n
"
detailed_err_msg
+=
" TotalSize:
%
s Byte(s)
%.3
f GB
\n
"
%
(
detailed_err_msg
+=
" TotalSize:
%
s Byte(s)
%.3
f GB
\n
"
%
(
total_size
,
total_size
/
1024.
/
1024
/
1024
)
total_size
,
total_size
/
1024.
/
1024
/
1024
)
detailed_err_msg
+=
" TotalSize inputs:
%
s Byte(s)
%.3
f BG
\n
"
%
(
detailed_err_msg
+=
" TotalSize inputs:
%
s Byte(s)
%.3
f BG
\n
"
%
(
total_size_inputs
,
total_size_inputs
/
1024.
/
1024
/
1024
)
total_size_inputs
,
total_size_inputs
/
1024.
/
1024
/
1024
)
else
:
else
:
hints
.
append
(
hints
.
append
(
...
@@ -326,7 +326,7 @@ class Linker(object):
...
@@ -326,7 +326,7 @@ class Linker(object):
raise
utils
.
MethodNotDefined
(
"make_thunk"
,
type
(
self
),
raise
utils
.
MethodNotDefined
(
"make_thunk"
,
type
(
self
),
self
.
__class__
.
__name__
)
self
.
__class__
.
__name__
)
#
# DELETEME #
#
#
DELETEME
#
def
make_function
(
self
,
unpack_single
=
True
,
**
kwargs
):
def
make_function
(
self
,
unpack_single
=
True
,
**
kwargs
):
"""
"""
Returns a function that takes values corresponding to the inputs of the
Returns a function that takes values corresponding to the inputs of the
...
@@ -350,8 +350,8 @@ class Linker(object):
...
@@ -350,8 +350,8 @@ class Linker(object):
def
execute
(
*
args
):
def
execute
(
*
args
):
def
e_arity
(
takes
,
got
):
def
e_arity
(
takes
,
got
):
return
'Function call takes exactly
%
i
%
s (
%
i given)'
\
return
'Function call takes exactly
%
i
%
s (
%
i given)'
%
(
%
(
takes
,
[
'argument'
,
'arguments'
][
takes
>
1
],
got
)
takes
,
[
'argument'
,
'arguments'
][
takes
>
1
],
got
)
if
(
len
(
args
)
!=
len
(
inputs
)):
if
(
len
(
args
)
!=
len
(
inputs
)):
raise
TypeError
(
e_arity
(
len
(
inputs
),
len
(
args
)))
raise
TypeError
(
e_arity
(
len
(
inputs
),
len
(
args
)))
for
arg
,
variable
in
izip
(
args
,
inputs
):
for
arg
,
variable
in
izip
(
args
,
inputs
):
...
@@ -394,7 +394,7 @@ class Container(object):
...
@@ -394,7 +394,7 @@ class Container(object):
"""
"""
if
not
isinstance
(
storage
,
list
)
or
not
len
(
storage
)
>=
1
:
if
not
isinstance
(
storage
,
list
)
or
not
len
(
storage
)
>=
1
:
raise
TypeError
(
"storage must be a list of length at least one"
)
raise
TypeError
(
"storage must be a list of length at least one"
)
#self.r = r
#
self.r = r
if
isinstance
(
r
,
Type
):
if
isinstance
(
r
,
Type
):
self
.
type
=
r
self
.
type
=
r
else
:
else
:
...
@@ -454,12 +454,11 @@ class Container(object):
...
@@ -454,12 +454,11 @@ class Container(object):
deepcopy
(
self
.
strict
,
memo
=
memo
),
deepcopy
(
self
.
strict
,
memo
=
memo
),
deepcopy
(
self
.
allow_downcast
,
memo
=
memo
),
deepcopy
(
self
.
allow_downcast
,
memo
=
memo
),
deepcopy
(
self
.
name
,
memo
=
memo
),
deepcopy
(
self
.
name
,
memo
=
memo
),
)
)
# Work around NumPy deepcopy of ndarray with 0 dimention that
# Work around NumPy deepcopy of ndarray with 0 dimention that
# don't return an ndarray.
# don't return an ndarray.
if
(
r
.
storage
[
0
]
is
not
None
and
if
(
r
.
storage
[
0
]
is
not
None
and
not
self
.
type
.
is_valid_value
(
r
.
storage
[
0
])):
not
self
.
type
.
is_valid_value
(
r
.
storage
[
0
])):
assert
not
data_was_in_memo
assert
not
data_was_in_memo
assert
self
.
type
.
is_valid_value
(
self
.
storage
[
0
])
assert
self
.
type
.
is_valid_value
(
self
.
storage
[
0
])
# This should also work for read only container.
# This should also work for read only container.
...
@@ -672,7 +671,7 @@ class PerformLinker(LocalLinker):
...
@@ -672,7 +671,7 @@ class PerformLinker(LocalLinker):
no_recycling
=
[]
no_recycling
=
[]
if
self
.
fgraph
is
not
None
and
self
.
fgraph
is
not
fgraph
:
if
self
.
fgraph
is
not
None
and
self
.
fgraph
is
not
fgraph
:
return
type
(
self
)(
allow_gc
=
self
.
allow_gc
)
.
accept
(
fgraph
,
no_recycling
)
return
type
(
self
)(
allow_gc
=
self
.
allow_gc
)
.
accept
(
fgraph
,
no_recycling
)
#raise Exception("Cannot accept from a Linker that is already tied to another FunctionGraph.")
#
raise Exception("Cannot accept from a Linker that is already tied to another FunctionGraph.")
self
.
fgraph
=
fgraph
self
.
fgraph
=
fgraph
self
.
no_recycling
=
no_recycling
self
.
no_recycling
=
no_recycling
return
self
return
self
...
@@ -721,9 +720,12 @@ class PerformLinker(LocalLinker):
...
@@ -721,9 +720,12 @@ class PerformLinker(LocalLinker):
for
node
in
order
:
for
node
in
order
:
if
self
.
allow_gc
:
if
self
.
allow_gc
:
post_thunk_old_storage
.
append
([
storage_map
[
input
]
post_thunk_old_storage
.
append
(
for
input
in
node
.
inputs
[
storage_map
[
input
]
if
(
input
in
computed
)
and
(
input
not
in
fgraph
.
outputs
)
and
node
==
last_user
[
input
]])
for
input
in
node
.
inputs
if
(
input
in
computed
)
and
(
input
not
in
fgraph
.
outputs
)
and
(
node
==
last_user
[
input
])])
if
no_recycling
is
True
:
if
no_recycling
is
True
:
# True seems like some special code for *everything*?? -JB
# True seems like some special code for *everything*?? -JB
...
@@ -855,7 +857,7 @@ class WrapLinker(Linker):
...
@@ -855,7 +857,7 @@ class WrapLinker(Linker):
make_all
+=
[
l
.
make_all
(
**
kwargs
)
for
l
in
self
.
linkers
[
1
:]]
make_all
+=
[
l
.
make_all
(
**
kwargs
)
for
l
in
self
.
linkers
[
1
:]]
fns
,
input_lists
,
output_lists
,
thunk_lists
,
order_lists
\
fns
,
input_lists
,
output_lists
,
thunk_lists
,
order_lists
\
=
zip
(
*
make_all
)
=
zip
(
*
make_all
)
order_list0
=
order_lists
[
0
]
order_list0
=
order_lists
[
0
]
for
order_list
in
order_lists
[
1
:]:
for
order_list
in
order_lists
[
1
:]:
...
...
theano/gof/opt.py
浏览文件 @
cba9c812
差异被折叠。
点击展开。
theano/gof/utils.py
浏览文件 @
cba9c812
...
@@ -3,9 +3,11 @@ import linecache
...
@@ -3,9 +3,11 @@ import linecache
import
traceback
import
traceback
import
sys
import
sys
import
numpy
from
six
import
iteritems
from
six
import
iteritems
from
theano
import
config
from
theano
import
config
from
theano.compat
import
PY3
def
simple_extract_stack
(
f
=
None
,
limit
=
None
):
def
simple_extract_stack
(
f
=
None
,
limit
=
None
):
...
@@ -435,3 +437,31 @@ def remove(predicate, coll):
...
@@ -435,3 +437,31 @@ def remove(predicate, coll):
[1, 3]
[1, 3]
"""
"""
return
[
x
for
x
in
coll
if
not
predicate
(
x
)]
return
[
x
for
x
in
coll
if
not
predicate
(
x
)]
if
PY3
:
import
hashlib
def
hash_from_code
(
msg
):
# hashlib.md5() requires an object that supports buffer interface,
# but Python 3 (unicode) strings don't.
if
isinstance
(
msg
,
str
):
msg
=
msg
.
encode
()
# Python 3 does not like module names that start with
# a digit.
return
'm'
+
hashlib
.
md5
(
msg
)
.
hexdigest
()
else
:
import
hashlib
def
hash_from_code
(
msg
):
try
:
return
hashlib
.
md5
(
msg
)
.
hexdigest
()
except
TypeError
:
assert
isinstance
(
msg
,
numpy
.
ndarray
)
return
hashlib
.
md5
(
numpy
.
getbuffer
(
msg
))
.
hexdigest
()
def
hash_from_file
(
file_path
):
"""Return the MD5 hash of a file."""
return
hash_from_code
(
open
(
file_path
,
'rb'
)
.
read
())
theano/sandbox/cuda/nvcc_compiler.py
浏览文件 @
cba9c812
...
@@ -10,7 +10,7 @@ import numpy
...
@@ -10,7 +10,7 @@ import numpy
from
theano.compat
import
decode
,
decode_iter
from
theano.compat
import
decode
,
decode_iter
from
theano.gof
import
local_bitwidth
from
theano.gof
import
local_bitwidth
from
theano.gof.
cc
import
hash_from_file
from
theano.gof.
utils
import
hash_from_file
from
theano.gof.cmodule
import
(
std_libs
,
std_lib_dirs
,
from
theano.gof.cmodule
import
(
std_libs
,
std_lib_dirs
,
std_include_dirs
,
dlimport
,
std_include_dirs
,
dlimport
,
Compiler
,
Compiler
,
...
...
theano/sparse/utils.py
浏览文件 @
cba9c812
from
theano.gof.
cc
import
hash_from_code
from
theano.gof.
utils
import
hash_from_code
def
hash_from_sparse
(
data
):
def
hash_from_sparse
(
data
):
...
...
theano/tensor/utils.py
浏览文件 @
cba9c812
...
@@ -2,7 +2,7 @@ import numpy
...
@@ -2,7 +2,7 @@ import numpy
import
theano
import
theano
from
theano.compat
import
izip
from
theano.compat
import
izip
from
theano.gof.
cc
import
hash_from_code
from
theano.gof.
utils
import
hash_from_code
def
hash_from_ndarray
(
data
):
def
hash_from_ndarray
(
data
):
...
...
theano/tests/test_flake8.py
浏览文件 @
cba9c812
...
@@ -233,16 +233,10 @@ whitelist_flake8 = [
...
@@ -233,16 +233,10 @@ whitelist_flake8 = [
"sparse/sandbox/sp2.py"
,
"sparse/sandbox/sp2.py"
,
"sparse/sandbox/truedot.py"
,
"sparse/sandbox/truedot.py"
,
"sparse/sandbox/sp.py"
,
"sparse/sandbox/sp.py"
,
"gof/destroyhandler.py"
,
"gof/unify.py"
,
"gof/unify.py"
,
"gof/graph.py"
,
"gof/graph.py"
,
"gof/__init__.py"
,
"gof/__init__.py"
,
"gof/cc.py"
,
"gof/opt.py"
,
"gof/link.py"
,
"gof/fg.py"
,
"gof/op.py"
,
"gof/op.py"
,
"gof/cmodule.py"
,
"gof/tests/test_cmodule.py"
,
"gof/tests/test_cmodule.py"
,
"gof/tests/test_destroyhandler.py"
,
"gof/tests/test_destroyhandler.py"
,
"gof/tests/test_opt.py"
,
"gof/tests/test_opt.py"
,
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论