Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
c0b294ec
提交
c0b294ec
authored
6月 20, 2016
作者:
Frédéric Bastien
提交者:
GitHub
6月 20, 2016
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #4629 from chinnadhurai/ccw_4483_indent_fix
Ccw 4483 indent fix
上级
d9028c7b
99fcfdcb
隐藏空白字符变更
内嵌
并排
正在显示
15 个修改的文件
包含
245 行增加
和
279 行删除
+245
-279
callcache.py
theano/gof/callcache.py
+14
-0
cc.py
theano/gof/cc.py
+47
-45
cmodule.py
theano/gof/cmodule.py
+20
-4
compiledir.py
theano/gof/compiledir.py
+6
-0
destroyhandler.py
theano/gof/destroyhandler.py
+1
-0
fg.py
theano/gof/fg.py
+7
-19
graph.py
theano/gof/graph.py
+27
-18
link.py
theano/gof/link.py
+16
-14
op.py
theano/gof/op.py
+35
-1
equilibrium.py
theano/gof/sandbox/equilibrium.py
+0
-146
sched.py
theano/gof/sched.py
+3
-0
toolbox.py
theano/gof/toolbox.py
+35
-0
unify.py
theano/gof/unify.py
+25
-28
vm.py
theano/gof/vm.py
+9
-1
test_flake8.py
theano/tests/test_flake8.py
+0
-3
没有找到文件。
theano/gof/callcache.py
浏览文件 @
c0b294ec
...
@@ -17,12 +17,26 @@ class CallCache(object):
...
@@ -17,12 +17,26 @@ class CallCache(object):
self
.
cache
=
{}
self
.
cache
=
{}
def
persist
(
self
,
filename
=
None
):
def
persist
(
self
,
filename
=
None
):
"""
Cache "filename" as a pickle file
"""
if
filename
is
None
:
if
filename
is
None
:
filename
=
self
.
filename
filename
=
self
.
filename
with
open
(
filename
,
'w'
)
as
f
:
with
open
(
filename
,
'w'
)
as
f
:
pickle
.
dump
(
self
.
cache
,
f
)
pickle
.
dump
(
self
.
cache
,
f
)
def
call
(
self
,
fn
,
args
=
(),
key
=
None
):
def
call
(
self
,
fn
,
args
=
(),
key
=
None
):
"""
Retrieve item from the cache(if available)
based on a key
Parameters:
----------
key
parameter to retrieve cache item
fn,args
key to retrieve if "key" is None
"""
if
key
is
None
:
if
key
is
None
:
key
=
(
fn
,
tuple
(
args
))
key
=
(
fn
,
tuple
(
args
))
if
key
not
in
self
.
cache
:
if
key
not
in
self
.
cache
:
...
...
theano/gof/cc.py
浏览文件 @
c0b294ec
...
@@ -61,8 +61,6 @@ def get_persistent_module_cache():
...
@@ -61,8 +61,6 @@ def get_persistent_module_cache():
class
CodeBlock
:
class
CodeBlock
:
"""
"""
WRITEME
Represents a computation unit composed of declare, behavior, and cleanup.
Represents a computation unit composed of declare, behavior, and cleanup.
The constructor initializes a L{CodeBlock} with templatized declare,
The constructor initializes a L{CodeBlock} with templatized declare,
...
@@ -118,6 +116,12 @@ def failure_code_init(sub):
...
@@ -118,6 +116,12 @@ def failure_code_init(sub):
"""
"""
Code for failure in the struct init.
Code for failure in the struct init.
Parameters:
----------
sub
Dictionary used to template the struct.
* failure_var -> must contain a variable name to use for
the failure code.
"""
"""
return
'''{
return
'''{
if (!PyErr_Occurred()) {
if (!PyErr_Occurred()) {
...
@@ -131,10 +135,10 @@ def failure_code_init(sub):
...
@@ -131,10 +135,10 @@ def failure_code_init(sub):
def
code_gen
(
blocks
):
def
code_gen
(
blocks
):
"""
"""
WRITEME
From a list of L{CodeBlock} instances, returns a string
From a list of L{CodeBlock} instances, returns a string
that executes them all in sequence. eg for C{(decl1, task1,
that executes them all in sequence.
Eg for C{(decl1, task1,
cleanup1)} and C{(decl2, task2, cleanup2)} the returned string
cleanup1)} and C{(decl2, task2, cleanup2)} the returned string
will be of the form:
will be of the form:
...
@@ -149,6 +153,12 @@ def code_gen(blocks):
...
@@ -149,6 +153,12 @@ def code_gen(blocks):
cleanup1
cleanup1
}
}
Parameters:
----------
blocks
List of CodeBlock instances such that
* declarations, behavior and cleanup are in the run()
method of the struct
"""
"""
decl
=
""
decl
=
""
head
=
""
head
=
""
...
@@ -162,8 +172,6 @@ def code_gen(blocks):
...
@@ -162,8 +172,6 @@ def code_gen(blocks):
def
struct_gen
(
args
,
struct_builders
,
blocks
,
sub
):
def
struct_gen
(
args
,
struct_builders
,
blocks
,
sub
):
"""
"""
WRITEME
Generates a struct conforming to the following specifications:
Generates a struct conforming to the following specifications:
Parameters
Parameters
...
@@ -453,7 +461,7 @@ def get_c_sync(r, name, sub):
...
@@ -453,7 +461,7 @@ def get_c_sync(r, name, sub):
def
apply_policy
(
policy
,
r
,
name
,
sub
):
def
apply_policy
(
policy
,
r
,
name
,
sub
):
"""
"""
WRITEME
Apply the list of policies to name.r,sub
Parameters
Parameters
----------
----------
...
@@ -478,7 +486,7 @@ def apply_policy(policy, r, name, sub):
...
@@ -478,7 +486,7 @@ def apply_policy(policy, r, name, sub):
def
struct_variable_codeblocks
(
variable
,
policies
,
id
,
symbol_table
,
sub
):
def
struct_variable_codeblocks
(
variable
,
policies
,
id
,
symbol_table
,
sub
):
"""
"""
WRITEME
Update "sub" dict and create two codeblocks with different failure modes
Parameters
Parameters
----------
----------
...
@@ -525,8 +533,6 @@ def struct_variable_codeblocks(variable, policies, id, symbol_table, sub):
...
@@ -525,8 +533,6 @@ def struct_variable_codeblocks(variable, policies, id, symbol_table, sub):
class
CLinker
(
link
.
Linker
):
class
CLinker
(
link
.
Linker
):
"""
"""
WRITEME
Creates C code for an fgraph, compiles it and returns callables
Creates C code for an fgraph, compiles it and returns callables
through make_thunk and make_function that make use of the compiled
through make_thunk and make_function that make use of the compiled
code.
code.
...
@@ -544,7 +550,7 @@ class CLinker(link.Linker):
...
@@ -544,7 +550,7 @@ class CLinker(link.Linker):
def
accept
(
self
,
fgraph
,
no_recycling
=
None
):
def
accept
(
self
,
fgraph
,
no_recycling
=
None
):
"""
"""
WRITEME
Associate linker with fgraph
"""
"""
if
no_recycling
is
None
:
if
no_recycling
is
None
:
...
@@ -559,8 +565,6 @@ class CLinker(link.Linker):
...
@@ -559,8 +565,6 @@ class CLinker(link.Linker):
def
fetch_variables
(
self
):
def
fetch_variables
(
self
):
"""
"""
WRITEME
Fills the inputs, outputs, variables, orphans, temps and node_order
Fills the inputs, outputs, variables, orphans, temps and node_order
fields.
fields.
...
@@ -617,8 +621,6 @@ class CLinker(link.Linker):
...
@@ -617,8 +621,6 @@ class CLinker(link.Linker):
def
code_gen
(
self
):
def
code_gen
(
self
):
"""
"""
WRITEME
Generates code for a struct that does the computation of the fgraph and
Generates code for a struct that does the computation of the fgraph and
stores it in the struct_code field of the instance.
stores it in the struct_code field of the instance.
...
@@ -890,14 +892,9 @@ class CLinker(link.Linker):
...
@@ -890,14 +892,9 @@ class CLinker(link.Linker):
def
support_code
(
self
):
def
support_code
(
self
):
"""
"""
WRITEME
Returns a list of support code strings that are needed by
Returns a list of support code strings that are needed by
one or more Variables or Ops. The support code from Variables is
one or more Variables or Ops.
added before the support code from Ops.
The support code from Variables is added before the support code from Ops.This might contain duplicates.
This might contain duplicates.
"""
"""
ret
=
[]
ret
=
[]
# generic support code
# generic support code
...
@@ -911,8 +908,6 @@ class CLinker(link.Linker):
...
@@ -911,8 +908,6 @@ class CLinker(link.Linker):
def
compile_args
(
self
):
def
compile_args
(
self
):
"""
"""
WRITEME
Returns a list of compile args that are needed by one
Returns a list of compile args that are needed by one
or more Variables or Ops.
or more Variables or Ops.
...
@@ -971,8 +966,6 @@ class CLinker(link.Linker):
...
@@ -971,8 +966,6 @@ class CLinker(link.Linker):
def
headers
(
self
):
def
headers
(
self
):
"""
"""
WRITEME
Returns a list of headers that are needed by one
Returns a list of headers that are needed by one
or more Types or Ops.
or more Types or Ops.
...
@@ -1032,8 +1025,6 @@ class CLinker(link.Linker):
...
@@ -1032,8 +1025,6 @@ class CLinker(link.Linker):
def
header_dirs
(
self
):
def
header_dirs
(
self
):
"""
"""
WRITEME
Returns a list of lib directories that are needed by one
Returns a list of lib directories that are needed by one
or more Types or Ops.
or more Types or Ops.
...
@@ -1055,8 +1046,6 @@ class CLinker(link.Linker):
...
@@ -1055,8 +1046,6 @@ class CLinker(link.Linker):
def
libraries
(
self
):
def
libraries
(
self
):
"""
"""
WRITEME
Returns a list of libraries that are needed by one
Returns a list of libraries that are needed by one
or more Types or Ops.
or more Types or Ops.
...
@@ -1078,8 +1067,6 @@ class CLinker(link.Linker):
...
@@ -1078,8 +1067,6 @@ class CLinker(link.Linker):
def
lib_dirs
(
self
):
def
lib_dirs
(
self
):
"""
"""
WRITEME
Returns a list of lib directories that are needed by one
Returns a list of lib directories that are needed by one
or more Types or Ops.
or more Types or Ops.
...
@@ -1101,7 +1088,7 @@ class CLinker(link.Linker):
...
@@ -1101,7 +1088,7 @@ class CLinker(link.Linker):
def
__compile__
(
self
,
input_storage
=
None
,
output_storage
=
None
,
def
__compile__
(
self
,
input_storage
=
None
,
output_storage
=
None
,
storage_map
=
None
,
keep_lock
=
False
):
storage_map
=
None
,
keep_lock
=
False
):
"""
WRITEME
"""
Compiles this linker's fgraph.
Compiles this linker's fgraph.
Parameters
Parameters
...
@@ -1166,7 +1153,7 @@ class CLinker(link.Linker):
...
@@ -1166,7 +1153,7 @@ class CLinker(link.Linker):
def
make_thunk
(
self
,
input_storage
=
None
,
output_storage
=
None
,
def
make_thunk
(
self
,
input_storage
=
None
,
output_storage
=
None
,
storage_map
=
None
,
keep_lock
=
False
):
storage_map
=
None
,
keep_lock
=
False
):
"""
WRITEME
"""
Compiles this linker's fgraph and returns a function to perform the
Compiles this linker's fgraph and returns a function to perform the
computations, as well as lists of storage cells for both the inputs
computations, as well as lists of storage cells for both the inputs
and outputs.
and outputs.
...
@@ -1183,8 +1170,10 @@ class CLinker(link.Linker):
...
@@ -1183,8 +1170,10 @@ class CLinker(link.Linker):
be allocated.
be allocated.
storage_map: dict that map variables to storages.
storage_map: dict that map variables to storages.
This is used when you need to customize the storage of
This is used when you need to customize the storage of
this thunk.
this thunk
keep_lock:
If True, we won't release the lock on the compiledir
at the end of this function call.
Returns: thunk, input_storage, output_storage
Returns: thunk, input_storage, output_storage
The return values can be used as follows:
The return values can be used as follows:
...
@@ -1568,7 +1557,12 @@ class CLinker(link.Linker):
...
@@ -1568,7 +1557,12 @@ class CLinker(link.Linker):
def
cthunk_factory
(
self
,
error_storage
,
in_storage
,
out_storage
,
def
cthunk_factory
(
self
,
error_storage
,
in_storage
,
out_storage
,
storage_map
=
None
,
keep_lock
=
False
):
storage_map
=
None
,
keep_lock
=
False
):
"""WRITEME
"""
Returns a thunk that points to an instance of a C struct that
can carry on the computation of this linker's fgraph
Parameters:
----------
error_storage -> list of length 3
error_storage -> list of length 3
in_storage -> list of lists of length 1, one per input
in_storage -> list of lists of length 1, one per input
out_storage -> list of lists of length 1, one per output
out_storage -> list of lists of length 1, one per output
...
@@ -1705,8 +1699,6 @@ class _CThunk(object):
...
@@ -1705,8 +1699,6 @@ class _CThunk(object):
class
OpWiseCLinker
(
link
.
LocalLinker
):
class
OpWiseCLinker
(
link
.
LocalLinker
):
"""
"""
WRITEME
Uses CLinker on the individual Ops that comprise an fgraph and loops
Uses CLinker on the individual Ops that comprise an fgraph and loops
over them in Python. The variable is slower than a compiled version of
over them in Python. The variable is slower than a compiled version of
the whole fgraph, but saves on compilation time because small changes
the whole fgraph, but saves on compilation time because small changes
...
@@ -1746,6 +1738,9 @@ class OpWiseCLinker(link.LocalLinker):
...
@@ -1746,6 +1738,9 @@ class OpWiseCLinker(link.LocalLinker):
self
.
schedule
=
schedule
self
.
schedule
=
schedule
def
accept
(
self
,
fgraph
,
no_recycling
=
None
):
def
accept
(
self
,
fgraph
,
no_recycling
=
None
):
"""
Associate linker with fgraph
"""
if
no_recycling
is
None
:
if
no_recycling
is
None
:
no_recycling
=
[]
no_recycling
=
[]
if
self
.
fgraph
is
not
None
and
self
.
fgraph
is
not
fgraph
:
if
self
.
fgraph
is
not
None
and
self
.
fgraph
is
not
fgraph
:
...
@@ -1846,11 +1841,14 @@ class OpWiseCLinker(link.LocalLinker):
...
@@ -1846,11 +1841,14 @@ class OpWiseCLinker(link.LocalLinker):
def
_default_checker
(
x
,
y
):
def
_default_checker
(
x
,
y
):
"""
"""
WRITEME
Default checker for DualLinker. This checks that the
Default checker for DualLinker. This checks that the
variables contain the same data using ==.
variables contain the same data using ==.
Parameters:
----------
x,y
the variables to compare data
"""
"""
if
x
[
0
]
!=
y
[
0
]:
if
x
[
0
]
!=
y
[
0
]:
raise
Exception
(
"Output mismatch."
,
raise
Exception
(
"Output mismatch."
,
...
@@ -1859,8 +1857,6 @@ def _default_checker(x, y):
...
@@ -1859,8 +1857,6 @@ def _default_checker(x, y):
class
DualLinker
(
link
.
Linker
):
class
DualLinker
(
link
.
Linker
):
"""
"""
WRITEME
Runs the fgraph in parallel using PerformLinker and CLinker.
Runs the fgraph in parallel using PerformLinker and CLinker.
The thunk/function produced by DualLinker uses PerformLinker as the
The thunk/function produced by DualLinker uses PerformLinker as the
...
@@ -1902,6 +1898,9 @@ class DualLinker(link.Linker):
...
@@ -1902,6 +1898,9 @@ class DualLinker(link.Linker):
self
.
schedule
=
schedule
self
.
schedule
=
schedule
def
accept
(
self
,
fgraph
,
no_recycling
=
None
):
def
accept
(
self
,
fgraph
,
no_recycling
=
None
):
"""
Update/tie self with fgraph
"""
if
no_recycling
is
None
:
if
no_recycling
is
None
:
no_recycling
=
[]
no_recycling
=
[]
if
self
.
fgraph
is
not
None
and
self
.
fgraph
is
not
fgraph
:
if
self
.
fgraph
is
not
None
and
self
.
fgraph
is
not
fgraph
:
...
@@ -1912,7 +1911,10 @@ class DualLinker(link.Linker):
...
@@ -1912,7 +1911,10 @@ class DualLinker(link.Linker):
return
self
return
self
def
make_thunk
(
self
,
**
kwargs
):
def
make_thunk
(
self
,
**
kwargs
):
"""
Compiles this linker's fgraph and returns a function to perform the
computations
"""
fgraph
=
self
.
fgraph
fgraph
=
self
.
fgraph
no_recycling
=
self
.
no_recycling
no_recycling
=
self
.
no_recycling
...
...
theano/gof/cmodule.py
浏览文件 @
c0b294ec
...
@@ -1474,10 +1474,25 @@ class ModuleCache(object):
...
@@ -1474,10 +1474,25 @@ class ModuleCache(object):
def
_rmtree
(
parent
,
ignore_nocleanup
=
False
,
msg
=
''
,
level
=
logging
.
DEBUG
,
def
_rmtree
(
parent
,
ignore_nocleanup
=
False
,
msg
=
''
,
level
=
logging
.
DEBUG
,
ignore_if_missing
=
False
):
ignore_if_missing
=
False
):
# On NFS filesystems, it is impossible to delete a directory with open
"""
# files in it. So instead, some commands in this file will respond to a
On NFS filesystems, it is impossible to delete a directory with open
# failed rmtree() by touching a 'delete.me' file. This file is a message
files in it.
# for a future process to try deleting the directory.
So instead, some commands in this file will respond to a
failed rmtree() by touching a 'delete.me' file. This file is a message
for a future process to try deleting the directory.
Parameters:
----------
parent
Root node to start deleting from
ignore_nocleanup
Delete the tree if flag is TRUE
level
Python Logging level. Set to "DEBUG" by default
ignore_if_missing
If set to True, just return without any issue if parent is NULL
"""
if
ignore_if_missing
and
not
os
.
path
.
exists
(
parent
):
if
ignore_if_missing
and
not
os
.
path
.
exists
(
parent
):
return
return
try
:
try
:
...
@@ -1504,6 +1519,7 @@ _module_cache = None
...
@@ -1504,6 +1519,7 @@ _module_cache = None
def
get_module_cache
(
dirname
,
init_args
=
None
):
def
get_module_cache
(
dirname
,
init_args
=
None
):
"""
"""
Create a new module_cache with the (k, v) pairs in this dictionary
Parameters
Parameters
----------
----------
...
...
theano/gof/compiledir.py
浏览文件 @
c0b294ec
...
@@ -94,6 +94,9 @@ def cleanup():
...
@@ -94,6 +94,9 @@ def cleanup():
def
print_compiledir_content
():
def
print_compiledir_content
():
"""
print list of
%
d compiled individual ops in the "theano.config.compiledir"
"""
max_key_file_size
=
1
*
1024
*
1024
# 1M
max_key_file_size
=
1
*
1024
*
1024
# 1M
compiledir
=
theano
.
config
.
compiledir
compiledir
=
theano
.
config
.
compiledir
...
@@ -178,6 +181,9 @@ def compiledir_purge():
...
@@ -178,6 +181,9 @@ def compiledir_purge():
def
basecompiledir_ls
():
def
basecompiledir_ls
():
"""
Print list of files in the "theano.config.base_compiledir"
"""
subdirs
=
[]
subdirs
=
[]
others
=
[]
others
=
[]
for
f
in
os
.
listdir
(
config
.
base_compiledir
):
for
f
in
os
.
listdir
(
config
.
base_compiledir
):
...
...
theano/gof/destroyhandler.py
浏览文件 @
c0b294ec
...
@@ -32,6 +32,7 @@ class ProtocolError(Exception):
...
@@ -32,6 +32,7 @@ class ProtocolError(Exception):
def
_contains_cycle
(
fgraph
,
orderings
):
def
_contains_cycle
(
fgraph
,
orderings
):
"""
"""
Function to check if the given graph contains a cycle
Parameters
Parameters
----------
----------
...
...
theano/gof/fg.py
浏览文件 @
c0b294ec
...
@@ -66,7 +66,6 @@ class MissingInputError(Exception):
...
@@ -66,7 +66,6 @@ class MissingInputError(Exception):
class
FunctionGraph
(
utils
.
object2
):
class
FunctionGraph
(
utils
.
object2
):
"""
"""
WRITEME
A FunctionGraph represents a subgraph bound by a set of input variables and
A FunctionGraph represents a subgraph bound by a set of input variables and
a set of output variables, ie a subgraph that specifies a theano function.
a set of output variables, ie a subgraph that specifies a theano function.
The inputs list should contain all the inputs on which the outputs depend.
The inputs list should contain all the inputs on which the outputs depend.
...
@@ -265,8 +264,6 @@ class FunctionGraph(utils.object2):
...
@@ -265,8 +264,6 @@ class FunctionGraph(utils.object2):
"""
"""
Updates the list of clients of r with new_clients.
Updates the list of clients of r with new_clients.
WRITEME
Parameters
Parameters
----------
----------
r
r
...
@@ -365,6 +362,11 @@ class FunctionGraph(utils.object2):
...
@@ -365,6 +362,11 @@ class FunctionGraph(utils.object2):
"""
"""
Import variables to this FunctionGraph and also their apply_node,
Import variables to this FunctionGraph and also their apply_node,
if those nodes are not in this graph.
if those nodes are not in this graph.
Parameters:
----------
reason
reason is the name of the optimization or operation in progress.
"""
"""
global
NullType
global
NullType
if
NullType
is
None
:
if
NullType
is
None
:
...
@@ -438,8 +440,6 @@ class FunctionGraph(utils.object2):
...
@@ -438,8 +440,6 @@ class FunctionGraph(utils.object2):
"""
"""
Changes node.inputs[i] to new_r.
Changes node.inputs[i] to new_r.
WRITEME
new_r.type == old_r.type must be True, where old_r is the
new_r.type == old_r.type must be True, where old_r is the
current value of node.inputs[i] which we want to replace.
current value of node.inputs[i] which we want to replace.
...
@@ -483,8 +483,6 @@ class FunctionGraph(utils.object2):
...
@@ -483,8 +483,6 @@ class FunctionGraph(utils.object2):
# replace #
# replace #
def
replace
(
self
,
r
,
new_r
,
reason
=
None
,
verbose
=
None
):
def
replace
(
self
,
r
,
new_r
,
reason
=
None
,
verbose
=
None
):
"""
"""
WRITEME
This is the main interface to manipulate the subgraph in FunctionGraph.
This is the main interface to manipulate the subgraph in FunctionGraph.
For every node that uses r as input, makes it use new_r instead.
For every node that uses r as input, makes it use new_r instead.
...
@@ -540,7 +538,7 @@ class FunctionGraph(utils.object2):
...
@@ -540,7 +538,7 @@ class FunctionGraph(utils.object2):
def
replace_all
(
self
,
pairs
,
reason
=
None
):
def
replace_all
(
self
,
pairs
,
reason
=
None
):
"""
"""
WRITEME
For every node that uses r as input, makes it use new_r instead
"""
"""
for
r
,
new_r
in
pairs
:
for
r
,
new_r
in
pairs
:
...
@@ -578,8 +576,6 @@ class FunctionGraph(utils.object2):
...
@@ -578,8 +576,6 @@ class FunctionGraph(utils.object2):
def
remove_feature
(
self
,
feature
):
def
remove_feature
(
self
,
feature
):
"""
"""
WRITEME
Removes the feature from the graph.
Removes the feature from the graph.
Calls feature.on_detach(function_graph) if an on_detach method
Calls feature.on_detach(function_graph) if an on_detach method
...
@@ -598,8 +594,6 @@ class FunctionGraph(utils.object2):
...
@@ -598,8 +594,6 @@ class FunctionGraph(utils.object2):
# callback utils #
# callback utils #
def
execute_callbacks
(
self
,
name
,
*
args
,
**
kwargs
):
def
execute_callbacks
(
self
,
name
,
*
args
,
**
kwargs
):
"""
"""
WRITEME
Calls
Calls
getattr(feature, name)(*args)
getattr(feature, name)(*args)
for each feature which has a method called after name.
for each feature which has a method called after name.
...
@@ -621,8 +615,6 @@ class FunctionGraph(utils.object2):
...
@@ -621,8 +615,6 @@ class FunctionGraph(utils.object2):
def
collect_callbacks
(
self
,
name
,
*
args
):
def
collect_callbacks
(
self
,
name
,
*
args
):
"""
"""
WRITEME
Returns a dictionary d such that:
Returns a dictionary d such that:
d[feature] == getattr(feature, name)(*args)
d[feature] == getattr(feature, name)(*args)
For each feature which has a method called after name.
For each feature which has a method called after name.
...
@@ -640,8 +632,6 @@ class FunctionGraph(utils.object2):
...
@@ -640,8 +632,6 @@ class FunctionGraph(utils.object2):
# misc #
# misc #
def
toposort
(
self
):
def
toposort
(
self
):
"""
"""
WRITEME
Return an ordering of the graph's Apply nodes such that:
Return an ordering of the graph's Apply nodes such that:
- All the nodes of the inputs of a node are before that node.
- All the nodes of the inputs of a node are before that node.
- Satisfies the orderings provided by each feature that has
- Satisfies the orderings provided by each feature that has
...
@@ -705,8 +695,6 @@ class FunctionGraph(utils.object2):
...
@@ -705,8 +695,6 @@ class FunctionGraph(utils.object2):
def
check_integrity
(
self
):
def
check_integrity
(
self
):
"""
"""
WRITEME
Call this for a diagnosis if things go awry.
Call this for a diagnosis if things go awry.
"""
"""
...
@@ -766,7 +754,7 @@ class FunctionGraph(utils.object2):
...
@@ -766,7 +754,7 @@ class FunctionGraph(utils.object2):
# clone #
# clone #
def
clone
(
self
,
check_integrity
=
True
):
def
clone
(
self
,
check_integrity
=
True
):
"""
"""
WRITEME
Clone the graph and get a memo( a dict )that map old node to new node
"""
"""
return
self
.
clone_get_equiv
(
check_integrity
)[
0
]
return
self
.
clone_get_equiv
(
check_integrity
)[
0
]
...
...
theano/gof/graph.py
浏览文件 @
c0b294ec
...
@@ -701,7 +701,15 @@ def inputs(variable_list, blockers=None):
...
@@ -701,7 +701,15 @@ def inputs(variable_list, blockers=None):
def
variables_and_orphans
(
i
,
o
):
def
variables_and_orphans
(
i
,
o
):
"""
"""
WRITEME
Extract list of variables between i and o nodes via
dfs traversal and chooses the orphans among them
Parameters
----------
i : list
Input variables.
o : list
Output variables.
"""
"""
def
expand
(
r
):
def
expand
(
r
):
...
@@ -716,21 +724,21 @@ def variables_and_orphans(i, o):
...
@@ -716,21 +724,21 @@ def variables_and_orphans(i, o):
def
ops
(
i
,
o
):
def
ops
(
i
,
o
):
"""
"""
WRITEME
Set of Ops contained within the subgraph between i and o
Parameters
Parameters
----------
----------
i : list
i : list
Input
L{Variable}
s.
Input
variable
s.
o : list
o : list
Output
L{Variable}
s.
Output
variable
s.
Returns
Returns
-------
-------
object
object
The set of ops that are contained within the subgraph that lies
The set of ops that are contained within the subgraph that lies
between i and o, including the owners of the
L{Variable}
s in o and
between i and o, including the owners of the
variable
s in o and
intermediary ops between i and o, but not the owners of the
L{Variable}
s
intermediary ops between i and o, but not the owners of the
variable
s
in i.
in i.
"""
"""
...
@@ -745,14 +753,14 @@ def ops(i, o):
...
@@ -745,14 +753,14 @@ def ops(i, o):
def
variables
(
i
,
o
):
def
variables
(
i
,
o
):
"""
"""
WRITEME
Extracts list of variables within input and output nodes via dfs travesal
Parameters
Parameters
----------
----------
i : list
i : list
Input
L{Variable}
s.
Input
variable
s.
o : list
o : list
Output
L{Variable}
s.
Output
variable
s.
Returns
Returns
-------
-------
...
@@ -767,14 +775,15 @@ def variables(i, o):
...
@@ -767,14 +775,15 @@ def variables(i, o):
def
orphans
(
i
,
o
):
def
orphans
(
i
,
o
):
"""
"""
WRITEME
Extracts list of variables within input and output nodes
via dfs travesal and returns the orphans among them
Parameters
Parameters
----------
----------
i : list
i : list
Input
L{Variable}
s.
Input
Variable
s.
o : list
o : list
Output
L{Variable}
s.
Output
Variable
s.
Returns
Returns
-------
-------
...
@@ -797,9 +806,9 @@ def clone(i, o, copy_inputs=True):
...
@@ -797,9 +806,9 @@ def clone(i, o, copy_inputs=True):
Parameters
Parameters
----------
----------
i : list
i : list
Input
L{Variable}
s.
Input
Variable
s.
o : list
o : list
Output
L{Variable}
s.
Output
Variable
s.
copy_inputs : bool
copy_inputs : bool
If True, the inputs will be copied (defaults to True).
If True, the inputs will be copied (defaults to True).
...
@@ -959,7 +968,7 @@ def general_toposort(r_out, deps, debug_print=False,
...
@@ -959,7 +968,7 @@ def general_toposort(r_out, deps, debug_print=False,
def
io_toposort
(
inputs
,
outputs
,
orderings
=
None
,
clients
=
None
):
def
io_toposort
(
inputs
,
outputs
,
orderings
=
None
,
clients
=
None
):
"""
"""
WRITEME
Perform topological sort from input and output nodes
Parameters
Parameters
----------
----------
...
@@ -1218,8 +1227,8 @@ def op_as_string(i, op,
...
@@ -1218,8 +1227,8 @@ def op_as_string(i, op,
leaf_formatter
=
default_leaf_formatter
,
leaf_formatter
=
default_leaf_formatter
,
node_formatter
=
default_node_formatter
):
node_formatter
=
default_node_formatter
):
"""
"""
WRITEME
Op to return a string representation of the subgraph
between i and o
"""
"""
strs
=
as_string
(
i
,
op
.
inputs
,
leaf_formatter
,
node_formatter
)
strs
=
as_string
(
i
,
op
.
inputs
,
leaf_formatter
,
node_formatter
)
return
node_formatter
(
op
,
strs
)
return
node_formatter
(
op
,
strs
)
...
@@ -1229,7 +1238,7 @@ def as_string(i, o,
...
@@ -1229,7 +1238,7 @@ def as_string(i, o,
leaf_formatter
=
default_leaf_formatter
,
leaf_formatter
=
default_leaf_formatter
,
node_formatter
=
default_node_formatter
):
node_formatter
=
default_node_formatter
):
"""
"""
WRITEME
Returns a string representation of the subgraph between i and o
Parameters
Parameters
----------
----------
...
...
theano/gof/link.py
浏览文件 @
c0b294ec
...
@@ -52,16 +52,24 @@ def log_thunk_trace(value, f=sys.stderr):
...
@@ -52,16 +52,24 @@ def log_thunk_trace(value, f=sys.stderr):
def
thunk_hook
(
type
,
value
,
trace
):
def
thunk_hook
(
type
,
value
,
trace
):
"""
"""
WRITEME
This function is meant to replace excepthook and do some
This function is meant to replace excepthook and do some
special work if the exception value has a __thunk_trace__
special work if the exception value has a __thunk_trace__
field. In that case, it retrieves the field, which should
field.
In that case, it retrieves the field, which should
contain a trace as returned by L{traceback.extract_stack},
contain a trace as returned by L{traceback.extract_stack},
and prints it out on L{stderr}.
and prints it out on L{stderr}.
The normal excepthook is then called.
The normal excepthook is then called.
Parameters:
----------
type
Exception class
value
Exception instance
trace
Traceback object
Notes
Notes
-----
-----
This hook replaced by nosetests, so it does not run in nose tests.
This hook replaced by nosetests, so it does not run in nose tests.
...
@@ -680,8 +688,6 @@ def streamline(fgraph, thunks, order, post_thunk_old_storage=None,
...
@@ -680,8 +688,6 @@ def streamline(fgraph, thunks, order, post_thunk_old_storage=None,
class
LocalLinker
(
Linker
):
class
LocalLinker
(
Linker
):
"""
"""
WRITEME
Useful base class for L{Linker}s which keep all nodes in the graph, and run
Useful base class for L{Linker}s which keep all nodes in the graph, and run
a thunk associated with each node.
a thunk associated with each node.
...
@@ -707,7 +713,7 @@ class LocalLinker(Linker):
...
@@ -707,7 +713,7 @@ class LocalLinker(Linker):
def
gc_helper
(
node_list
):
def
gc_helper
(
node_list
):
"""
"""
Return the set of Variable instances which are computed by node_list.
Parameters
Parameters
----------
----------
node_list
node_list
...
@@ -743,8 +749,6 @@ def gc_helper(node_list):
...
@@ -743,8 +749,6 @@ def gc_helper(node_list):
class
PerformLinker
(
LocalLinker
):
class
PerformLinker
(
LocalLinker
):
"""
"""
WRITEME
Basic L{Linker} subclass that calls the perform method on each L{Op} in
Basic L{Linker} subclass that calls the perform method on each L{Op} in
the L{FunctionGraph} in the order given by L{Linker.schedule}.
the L{FunctionGraph} in the order given by L{Linker.schedule}.
...
@@ -764,8 +768,7 @@ class PerformLinker(LocalLinker):
...
@@ -764,8 +768,7 @@ class PerformLinker(LocalLinker):
Parameters
Parameters
----------
----------
fgraph
fgraph
A PerformLinker can have accepted one FunctionGraph instance at a
A PerformLinker can have accepted one FunctionGraph instance at a time.
time.
no_recycling
no_recycling
WRITEME
WRITEME
...
@@ -786,13 +789,14 @@ class PerformLinker(LocalLinker):
...
@@ -786,13 +789,14 @@ class PerformLinker(LocalLinker):
def
make_all
(
self
,
input_storage
=
None
,
output_storage
=
None
,
storage_map
=
None
):
def
make_all
(
self
,
input_storage
=
None
,
output_storage
=
None
,
storage_map
=
None
):
"""
"""
Returns Function to run all nodes, list of input containers, list of outputs
Parameters
Parameters
----------
----------
input_storage
input_storage
WRITEME
list of storages corresponding to fgraph.inputs
output_storage
output_storage
WRITEME
list of storages corresponding to fgraph.outputs
Returns
Returns
-------
-------
...
@@ -879,8 +883,6 @@ def add_clear_storage(f, computed, storage_map):
...
@@ -879,8 +883,6 @@ def add_clear_storage(f, computed, storage_map):
class
WrapLinker
(
Linker
):
class
WrapLinker
(
Linker
):
"""
"""
WRITEME
This class makes it easier to run several L{LocalLinker}s in parallel, and
This class makes it easier to run several L{LocalLinker}s in parallel, and
offers some control over how each thunk is run.
offers some control over how each thunk is run.
...
...
theano/gof/op.py
浏览文件 @
c0b294ec
...
@@ -791,6 +791,9 @@ class Op(utils.object2, PureOp, CLinkerOp):
...
@@ -791,6 +791,9 @@ class Op(utils.object2, PureOp, CLinkerOp):
self
.
_op_use_c_code
=
use_c_code
self
.
_op_use_c_code
=
use_c_code
def
_props
(
self
):
def
_props
(
self
):
"""
Tuple of properties of all attributes
"""
return
tuple
(
getattr
(
self
,
a
)
for
a
in
self
.
__props__
)
return
tuple
(
getattr
(
self
,
a
)
for
a
in
self
.
__props__
)
def
_props_dict
(
self
):
def
_props_dict
(
self
):
...
@@ -924,6 +927,9 @@ class Op(utils.object2, PureOp, CLinkerOp):
...
@@ -924,6 +927,9 @@ class Op(utils.object2, PureOp, CLinkerOp):
def
make_thunk
(
self
,
node
,
storage_map
,
compute_map
,
no_recycling
):
def
make_thunk
(
self
,
node
,
storage_map
,
compute_map
,
no_recycling
):
"""
"""
This function must return a thunk, that is a zero-arguments
function that encapsulates the computation to be performed
by this op on the arguments of the node.
Parameters
Parameters
----------
----------
...
@@ -974,7 +980,9 @@ class Op(utils.object2, PureOp, CLinkerOp):
...
@@ -974,7 +980,9 @@ class Op(utils.object2, PureOp, CLinkerOp):
return
self
.
make_py_thunk
(
node
,
storage_map
,
compute_map
,
no_recycling
)
return
self
.
make_py_thunk
(
node
,
storage_map
,
compute_map
,
no_recycling
)
def
make_node
(
self
,
*
inputs
):
def
make_node
(
self
,
*
inputs
):
"""
Create a "apply" nodes for the inputs in that order.
"""
if
not
hasattr
(
self
,
'itypes'
):
if
not
hasattr
(
self
,
'itypes'
):
raise
NotImplementedError
(
"You can either define itypes and otypes,
\
raise
NotImplementedError
(
"You can either define itypes and otypes,
\
or implement make_node"
)
or implement make_node"
)
...
@@ -1058,6 +1066,10 @@ def debug_error_message(msg):
...
@@ -1058,6 +1066,10 @@ def debug_error_message(msg):
def
debug_assert
(
condition
,
msg
=
None
):
def
debug_assert
(
condition
,
msg
=
None
):
"""
Customized assert with options to ignore the assert
with just a warning
"""
if
msg
is
None
:
if
msg
is
None
:
msg
=
'debug_assert failed'
msg
=
'debug_assert failed'
if
not
condition
:
if
not
condition
:
...
@@ -1165,12 +1177,18 @@ class OpenMPOp(Op):
...
@@ -1165,12 +1177,18 @@ class OpenMPOp(Op):
self
.
openmp
=
False
self
.
openmp
=
False
def
c_compile_args
(
self
):
def
c_compile_args
(
self
):
"""
Return the compilation arg "fopenmp" if openMP is supported
"""
self
.
update_self_openmp
()
self
.
update_self_openmp
()
if
self
.
openmp
:
if
self
.
openmp
:
return
[
'-fopenmp'
]
return
[
'-fopenmp'
]
return
[]
return
[]
def
c_headers
(
self
):
def
c_headers
(
self
):
"""
Return the header file name "omp.h" if openMP is supported
"""
self
.
update_self_openmp
()
self
.
update_self_openmp
()
if
self
.
openmp
:
if
self
.
openmp
:
return
[
"omp.h"
]
return
[
"omp.h"
]
...
@@ -1178,6 +1196,9 @@ class OpenMPOp(Op):
...
@@ -1178,6 +1196,9 @@ class OpenMPOp(Op):
@staticmethod
@staticmethod
def
test_gxx_support
():
def
test_gxx_support
():
"""
Check if openMP is supported
"""
code
=
"""
code
=
"""
#include <omp.h>
#include <omp.h>
int main( int argc, const char* argv[] )
int main( int argc, const char* argv[] )
...
@@ -1313,6 +1334,9 @@ class COp(Op):
...
@@ -1313,6 +1334,9 @@ class COp(Op):
'and specify the func_name'
)
'and specify the func_name'
)
def
load_c_code
(
self
):
def
load_c_code
(
self
):
"""
Loads the c code to perform the Op
"""
self
.
func_codes
=
[]
self
.
func_codes
=
[]
for
func_file
in
self
.
func_files
:
for
func_file
in
self
.
func_files
:
with
open
(
func_file
,
'r'
)
as
f
:
with
open
(
func_file
,
'r'
)
as
f
:
...
@@ -1391,6 +1415,9 @@ class COp(Op):
...
@@ -1391,6 +1415,9 @@ class COp(Op):
return
hash
(
tuple
(
self
.
func_codes
))
return
hash
(
tuple
(
self
.
func_codes
))
def
c_init_code
(
self
):
def
c_init_code
(
self
):
"""
Get the code section for init_code
"""
if
'init_code'
in
self
.
code_sections
:
if
'init_code'
in
self
.
code_sections
:
return
[
self
.
code_sections
[
'init_code'
]]
return
[
self
.
code_sections
[
'init_code'
]]
else
:
else
:
...
@@ -1500,6 +1527,10 @@ class COp(Op):
...
@@ -1500,6 +1527,10 @@ class COp(Op):
undef_macros
.
append
(
"#undef OUTPUT_
%
d"
,
(
i
,))
undef_macros
.
append
(
"#undef OUTPUT_
%
d"
,
(
i
,))
def
c_init_code_struct
(
self
,
node
,
name
,
sub
):
def
c_init_code_struct
(
self
,
node
,
name
,
sub
):
"""
Stitches all the macros and "init_code" together
"""
if
'init_code_struct'
in
self
.
code_sections
:
if
'init_code_struct'
in
self
.
code_sections
:
op_code
=
self
.
code_sections
[
'init_code_struct'
]
op_code
=
self
.
code_sections
[
'init_code_struct'
]
...
@@ -1554,6 +1585,9 @@ class COp(Op):
...
@@ -1554,6 +1585,9 @@ class COp(Op):
'c_code'
,
type
(
self
),
type
(
self
)
.
__name__
)
'c_code'
,
type
(
self
),
type
(
self
)
.
__name__
)
def
c_code_cleanup
(
self
,
node
,
name
,
inputs
,
outputs
,
sub
):
def
c_code_cleanup
(
self
,
node
,
name
,
inputs
,
outputs
,
sub
):
"""
Stitches all the macros and "code_cleanup" together
"""
if
'code_cleanup'
in
self
.
code_sections
:
if
'code_cleanup'
in
self
.
code_sections
:
op_code
=
self
.
code_sections
[
'code_cleanup'
]
op_code
=
self
.
code_sections
[
'code_cleanup'
]
...
...
theano/gof/sandbox/equilibrium.py
deleted
100644 → 0
浏览文件 @
d9028c7b
from
__future__
import
absolute_import
,
print_function
,
division
from
six.moves
import
reduce
from
six
import
string_types
if
0
:
class
_EquilibriumOptimizer
(
NavigatorOptimizer
):
def
__init__
(
self
,
local_optimizers
,
failure_callback
=
None
,
max_depth
=
None
,
max_use_ratio
=
None
):
super
(
EquilibriumOptimizer
,
self
)
.
__init__
(
None
,
ignore_newtrees
=
False
,
failure_callback
=
failure_callback
)
self
.
local_optimizers
=
local_optimizers
self
.
max_depth
=
max_depth
self
.
max_use_ratio
=
max_use_ratio
self
.
tracks
=
defaultdict
(
list
)
self
.
tracks0
=
defaultdict
(
list
)
max_depth
=
0
for
lopt
in
local_optimizers
:
tracks
=
lopt
.
tracks
()
for
track
in
tracks
:
max_depth
=
max
(
max_depth
,
len
(
track
))
if
self
.
max_depth
is
not
None
and
max_depth
>
self
.
max_depth
:
raise
ValueError
(
'One of the local optimizers exceeds the maximal depth.'
)
for
i
,
op
in
enumerate
(
track
):
if
i
==
0
:
self
.
tracks0
[
op
]
.
append
((
track
,
i
,
lopt
))
self
.
tracks
[
op
]
.
append
((
track
,
i
,
lopt
))
def
fetch_tracks
(
self
,
op
):
return
self
.
tracks
[
op
]
+
self
.
tracks
[
None
]
def
fetch_tracks0
(
self
,
op
):
return
self
.
tracks0
[
op
]
+
self
.
tracks0
[
None
]
def
backtrack
(
self
,
node
,
tasks
):
candidates
=
self
.
fetch_tracks
(
node
.
op
)
tracks
=
[]
def
filter
(
node
,
depth
):
new_candidates
=
[]
for
candidate
in
candidates
:
track
,
i
,
lopt
=
candidate
if
i
<
depth
:
pass
elif
track
[
i
-
depth
]
in
(
None
,
node
.
op
):
if
i
==
depth
:
tasks
[
node
]
.
append
(
lopt
)
else
:
tracks
.
append
(
candidate
)
else
:
new_candidates
.
append
(
candidate
)
return
new_candidates
depth
=
0
nodes
=
[
node
]
while
candidates
:
for
node
in
nodes
:
candidates
=
list
(
filter
(
node
,
depth
))
depth
+=
1
_nodes
=
nodes
nodes
=
reduce
(
list
.
__iadd__
,
[
reduce
(
list
.
__iadd__
,
[[
n
for
n
,
i
in
out
.
clients
if
not
isinstance
(
n
,
string_types
)]
for
out
in
node
.
outputs
],
[])
for
node
in
nodes
],
[])
candidates
=
tracks
tracks
=
[]
def
apply
(
self
,
fgraph
):
tasks
=
defaultdict
(
list
)
if
self
.
max_use_ratio
is
not
None
:
max_uses
=
self
.
max_use_ratio
*
len
(
fgraph
.
apply_nodes
)
runs
=
defaultdict
(
int
)
else
:
runs
=
None
def
importer
(
node
):
# print 'IMPORTING', node
self
.
backtrack
(
node
,
tasks
)
def
pruner
(
node
):
try
:
del
tasks
[
node
]
except
KeyError
:
pass
def
chin
(
node
,
i
,
r
,
new_r
):
if
new_r
.
owner
and
not
r
.
clients
:
self
.
backtrack
(
new_r
.
owner
,
tasks
)
# # == NOT IDEAL == #
# for node in fgraph.apply_nodes:
# importer(node)
for
node
in
fgraph
.
toposort
():
tasks
[
node
]
.
extend
(
lopt
for
track
,
i
,
lopt
in
self
.
fetch_tracks0
(
node
.
op
))
u
=
self
.
attach_updater
(
fgraph
,
importer
,
pruner
,
chin
)
print
(
'KEYS'
,
[
hash
(
t
)
for
t
in
tasks
.
keys
()])
while
tasks
:
for
node
in
tasks
:
todo
=
tasks
.
pop
(
node
)
break
for
lopt
in
todo
:
if
runs
is
not
None
and
runs
[
lopt
]
>=
max_uses
:
print
(
'Warning: optimization exceeded its maximal use ratio:
%
s,
%
s'
%
(
lopt
,
max_uses
),
file
=
sys
.
stderr
)
continue
success
=
self
.
process_node
(
fgraph
,
node
,
lopt
)
if
success
:
if
runs
is
not
None
:
runs
[
lopt
]
+=
1
break
self
.
detach_updater
(
fgraph
,
u
)
# def match(self, node, candidates):
# candidates[:] = [candidate
# for candidate in candidates
# if candidate.current.op is None or candidate.current.op == node.op]
# for candidate in candidates:
# if candidate.current.inputs is not None:
# for in1, in2 in zip(candidate.current.inputs, node.inputs):
# if isinstance(in1, string_types):
# candidate.match[in1] = in2
# for client in node.clients:
# op = node.op
# patterns = self.pattern_base[(depth, op)].union(self.pattern_base[(depth, WILDCARD)])
# if not patterns:
# return patterns
# return self.match(node, depth + 1).intersection(patterns)
# def backtrack(self, node, q):
# for node2, i in node.clients:
# op2 = node2.op
theano/gof/sched.py
浏览文件 @
c0b294ec
...
@@ -268,6 +268,9 @@ def sort_schedule_fn(*cmps):
...
@@ -268,6 +268,9 @@ def sort_schedule_fn(*cmps):
def
key_to_cmp
(
key
):
def
key_to_cmp
(
key
):
"""
comparator function based on "key" function
"""
def
key_cmp
(
a
,
b
):
def
key_cmp
(
a
,
b
):
return
cmp
(
key
(
a
),
key
(
b
))
return
cmp
(
key
(
a
),
key
(
b
))
return
key_cmp
return
key_cmp
theano/gof/toolbox.py
浏览文件 @
c0b294ec
...
@@ -114,10 +114,20 @@ class Feature(object):
...
@@ -114,10 +114,20 @@ class Feature(object):
class
Bookkeeper
(
Feature
):
class
Bookkeeper
(
Feature
):
def
on_attach
(
self
,
fgraph
):
def
on_attach
(
self
,
fgraph
):
"""
Called by FunctionGraph.attach_feature, the method that attaches
the feature to the FunctionGraph. Since this is called after the
FunctionGraph is initially populated, this is where you should
run checks on the initial contents of the FunctionGraph.
"""
for
node
in
graph
.
io_toposort
(
fgraph
.
inputs
,
fgraph
.
outputs
):
for
node
in
graph
.
io_toposort
(
fgraph
.
inputs
,
fgraph
.
outputs
):
self
.
on_import
(
fgraph
,
node
,
"on_attach"
)
self
.
on_import
(
fgraph
,
node
,
"on_attach"
)
def
on_detach
(
self
,
fgraph
):
def
on_detach
(
self
,
fgraph
):
"""
Should remove any dynamically added functionality
that it installed into the function_graph
"""
for
node
in
graph
.
io_toposort
(
fgraph
.
inputs
,
fgraph
.
outputs
):
for
node
in
graph
.
io_toposort
(
fgraph
.
inputs
,
fgraph
.
outputs
):
self
.
on_prune
(
fgraph
,
node
,
'Bookkeeper.detach'
)
self
.
on_prune
(
fgraph
,
node
,
'Bookkeeper.detach'
)
...
@@ -178,6 +188,10 @@ class History(Feature):
...
@@ -178,6 +188,10 @@ class History(Feature):
fgraph
.
revert
=
partial
(
self
.
revert
,
fgraph
)
fgraph
.
revert
=
partial
(
self
.
revert
,
fgraph
)
def
on_detach
(
self
,
fgraph
):
def
on_detach
(
self
,
fgraph
):
"""
Should remove any dynamically added functionality
that it installed into the function_graph
"""
del
fgraph
.
checkpoint
del
fgraph
.
checkpoint
del
fgraph
.
revert
del
fgraph
.
revert
del
self
.
history
[
fgraph
]
del
self
.
history
[
fgraph
]
...
@@ -223,10 +237,19 @@ class Validator(Feature):
...
@@ -223,10 +237,19 @@ class Validator(Feature):
fgraph
.
consistent
=
partial
(
self
.
consistent_
,
fgraph
)
fgraph
.
consistent
=
partial
(
self
.
consistent_
,
fgraph
)
def
on_detach
(
self
,
fgraph
):
def
on_detach
(
self
,
fgraph
):
"""
Should remove any dynamically added functionality
that it installed into the function_graph
"""
del
fgraph
.
validate
del
fgraph
.
validate
del
fgraph
.
consistent
del
fgraph
.
consistent
def
validate_
(
self
,
fgraph
):
def
validate_
(
self
,
fgraph
):
"""
If the caller is replace_all_validate, just raise the
exception. replace_all_validate will print out the
verbose output. Or it has to be done here before raise.
"""
t0
=
time
.
time
()
t0
=
time
.
time
()
try
:
try
:
ret
=
fgraph
.
execute_callbacks
(
'validate'
)
ret
=
fgraph
.
execute_callbacks
(
'validate'
)
...
@@ -289,6 +312,10 @@ class ReplaceValidate(History, Validator):
...
@@ -289,6 +312,10 @@ class ReplaceValidate(History, Validator):
self
.
replace_all_validate_remove
,
fgraph
)
self
.
replace_all_validate_remove
,
fgraph
)
def
on_detach
(
self
,
fgraph
):
def
on_detach
(
self
,
fgraph
):
"""
Should remove any dynamically added functionality
that it installed into the function_graph
"""
History
.
on_detach
(
self
,
fgraph
)
History
.
on_detach
(
self
,
fgraph
)
Validator
.
on_detach
(
self
,
fgraph
)
Validator
.
on_detach
(
self
,
fgraph
)
del
self
.
_nodes_removed
del
self
.
_nodes_removed
...
@@ -412,6 +439,10 @@ class NodeFinder(Bookkeeper):
...
@@ -412,6 +439,10 @@ class NodeFinder(Bookkeeper):
Bookkeeper
.
on_attach
(
self
,
fgraph
)
Bookkeeper
.
on_attach
(
self
,
fgraph
)
def
on_detach
(
self
,
fgraph
):
def
on_detach
(
self
,
fgraph
):
"""
Should remove any dynamically added functionality
that it installed into the function_graph
"""
if
self
.
fgraph
is
not
fgraph
:
if
self
.
fgraph
is
not
fgraph
:
raise
Exception
(
"This NodeFinder instance was not attached to the"
raise
Exception
(
"This NodeFinder instance was not attached to the"
" provided fgraph."
)
" provided fgraph."
)
...
@@ -461,6 +492,10 @@ class PrintListener(Feature):
...
@@ -461,6 +492,10 @@ class PrintListener(Feature):
print
(
"-- attaching to: "
,
fgraph
)
print
(
"-- attaching to: "
,
fgraph
)
def
on_detach
(
self
,
fgraph
):
def
on_detach
(
self
,
fgraph
):
"""
Should remove any dynamically added functionality
that it installed into the function_graph
"""
if
self
.
active
:
if
self
.
active
:
print
(
"-- detaching from: "
,
fgraph
)
print
(
"-- detaching from: "
,
fgraph
)
...
...
theano/gof/unify.py
浏览文件 @
c0b294ec
...
@@ -4,7 +4,7 @@ can be "unified" if there exists an assignment to all unification variables
...
@@ -4,7 +4,7 @@ can be "unified" if there exists an assignment to all unification variables
such that the two expressions are equal.
such that the two expressions are equal.
For instance, [5, A, B] and [A, C, 9] can be unified if A=C=5 and B=9,
For instance, [5, A, B] and [A, C, 9] can be unified if A=C=5 and B=9,
yielding [5, 5, 9].
yielding [5, 5, 9].
[5, [A, B]] and [A, [1, 2]] cannot be unified because there is no value for A
[5, [A, B]] and [A, [1, 2]] cannot be unified because there is no value for A
that satisfies the constraints. That's useful for pattern matching.
that satisfies the constraints. That's useful for pattern matching.
...
@@ -15,7 +15,6 @@ from copy import copy
...
@@ -15,7 +15,6 @@ from copy import copy
from
functools
import
partial
from
functools
import
partial
from
theano.gof.utils
import
ANY_TYPE
,
comm_guard
,
FALL_THROUGH
,
iteritems
from
theano.gof.utils
import
ANY_TYPE
,
comm_guard
,
FALL_THROUGH
,
iteritems
################################
################################
...
@@ -135,8 +134,6 @@ class Unification:
...
@@ -135,8 +134,6 @@ class Unification:
"""
"""
This class represents a possible unification of a group of variables
This class represents a possible unification of a group of variables
with each other or with tangible values.
with each other or with tangible values.
Parameters
Parameters
----------
----------
inplace : bool
inplace : bool
...
@@ -229,7 +226,7 @@ def unify_walk(a, b, U):
...
@@ -229,7 +226,7 @@ def unify_walk(a, b, U):
return
False
return
False
@comm_guard
(
FreeVariable
,
ANY_TYPE
)
@comm_guard
(
FreeVariable
,
ANY_TYPE
)
# noqa
def
unify_walk
(
fv
,
o
,
U
):
def
unify_walk
(
fv
,
o
,
U
):
"""
"""
FreeV is unified to BoundVariable(other_object).
FreeV is unified to BoundVariable(other_object).
...
@@ -239,7 +236,7 @@ def unify_walk(fv, o, U):
...
@@ -239,7 +236,7 @@ def unify_walk(fv, o, U):
return
U
.
merge
(
v
,
fv
)
return
U
.
merge
(
v
,
fv
)
@comm_guard
(
BoundVariable
,
ANY_TYPE
)
@comm_guard
(
BoundVariable
,
ANY_TYPE
)
# noqa
def
unify_walk
(
bv
,
o
,
U
):
def
unify_walk
(
bv
,
o
,
U
):
"""
"""
The unification succeed iff BV.value == other_object.
The unification succeed iff BV.value == other_object.
...
@@ -251,7 +248,7 @@ def unify_walk(bv, o, U):
...
@@ -251,7 +248,7 @@ def unify_walk(bv, o, U):
return
False
return
False
@comm_guard
(
OrVariable
,
ANY_TYPE
)
@comm_guard
(
OrVariable
,
ANY_TYPE
)
# noqa
def
unify_walk
(
ov
,
o
,
U
):
def
unify_walk
(
ov
,
o
,
U
):
"""
"""
The unification succeeds iff other_object in OrV.options.
The unification succeeds iff other_object in OrV.options.
...
@@ -264,7 +261,7 @@ def unify_walk(ov, o, U):
...
@@ -264,7 +261,7 @@ def unify_walk(ov, o, U):
return
False
return
False
@comm_guard
(
NotVariable
,
ANY_TYPE
)
@comm_guard
(
NotVariable
,
ANY_TYPE
)
# noqa
def
unify_walk
(
nv
,
o
,
U
):
def
unify_walk
(
nv
,
o
,
U
):
"""
"""
The unification succeeds iff other_object not in NV.not_options.
The unification succeeds iff other_object not in NV.not_options.
...
@@ -277,7 +274,7 @@ def unify_walk(nv, o, U):
...
@@ -277,7 +274,7 @@ def unify_walk(nv, o, U):
return
U
.
merge
(
v
,
nv
)
return
U
.
merge
(
v
,
nv
)
@comm_guard
(
FreeVariable
,
Variable
)
@comm_guard
(
FreeVariable
,
Variable
)
# noqa
def
unify_walk
(
fv
,
v
,
U
):
def
unify_walk
(
fv
,
v
,
U
):
"""
"""
Both variables are unified.
Both variables are unified.
...
@@ -287,7 +284,7 @@ def unify_walk(fv, v, U):
...
@@ -287,7 +284,7 @@ def unify_walk(fv, v, U):
return
U
.
merge
(
v
,
fv
)
return
U
.
merge
(
v
,
fv
)
@comm_guard
(
BoundVariable
,
Variable
)
@comm_guard
(
BoundVariable
,
Variable
)
# noqa
def
unify_walk
(
bv
,
v
,
U
):
def
unify_walk
(
bv
,
v
,
U
):
"""
"""
V is unified to BV.value.
V is unified to BV.value.
...
@@ -296,13 +293,13 @@ def unify_walk(bv, v, U):
...
@@ -296,13 +293,13 @@ def unify_walk(bv, v, U):
return
unify_walk
(
v
,
bv
.
value
,
U
)
return
unify_walk
(
v
,
bv
.
value
,
U
)
@comm_guard
(
OrVariable
,
OrVariable
)
@comm_guard
(
OrVariable
,
OrVariable
)
# noqa
def
unify_walk
(
a
,
b
,
U
):
def
unify_walk
(
a
,
b
,
U
):
"""
"""
OrV(list1) == OrV(list2) == OrV(intersection(list1, list2))
OrV(list1) == OrV(list2) == OrV(intersection(list1, list2))
"""
"""
opt
=
intersection
(
a
.
options
,
b
.
options
)
opt
=
a
.
options
.
intersection
(
b
.
options
)
if
not
opt
:
if
not
opt
:
return
False
return
False
elif
len
(
opt
)
==
1
:
elif
len
(
opt
)
==
1
:
...
@@ -312,18 +309,18 @@ def unify_walk(a, b, U):
...
@@ -312,18 +309,18 @@ def unify_walk(a, b, U):
return
U
.
merge
(
v
,
a
,
b
)
return
U
.
merge
(
v
,
a
,
b
)
@comm_guard
(
NotVariable
,
NotVariable
)
@comm_guard
(
NotVariable
,
NotVariable
)
# noqa
def
unify_walk
(
a
,
b
,
U
):
def
unify_walk
(
a
,
b
,
U
):
"""
"""
NV(list1) == NV(list2) == NV(union(list1, list2))
NV(list1) == NV(list2) == NV(union(list1, list2))
"""
"""
opt
=
union
(
a
.
not_options
,
b
.
not_options
)
opt
=
a
.
not_options
.
union
(
b
.
not_options
)
v
=
NotVariable
(
"?"
,
opt
)
v
=
NotVariable
(
"?"
,
opt
)
return
U
.
merge
(
v
,
a
,
b
)
return
U
.
merge
(
v
,
a
,
b
)
@comm_guard
(
OrVariable
,
NotVariable
)
@comm_guard
(
OrVariable
,
NotVariable
)
# noqa
def
unify_walk
(
o
,
n
,
U
):
def
unify_walk
(
o
,
n
,
U
):
"""
"""
OrV(list1) == NV(list2) == OrV(list1
\
list2)
OrV(list1) == NV(list2) == OrV(list1
\
list2)
...
@@ -339,7 +336,7 @@ def unify_walk(o, n, U):
...
@@ -339,7 +336,7 @@ def unify_walk(o, n, U):
return
U
.
merge
(
v
,
o
,
n
)
return
U
.
merge
(
v
,
o
,
n
)
@comm_guard
(
VariableInList
,
(
list
,
tuple
))
@comm_guard
(
VariableInList
,
(
list
,
tuple
))
# noqa
def
unify_walk
(
vil
,
l
,
U
):
def
unify_walk
(
vil
,
l
,
U
):
"""
"""
Unifies VIL's inner Variable to OrV(list).
Unifies VIL's inner Variable to OrV(list).
...
@@ -350,7 +347,7 @@ def unify_walk(vil, l, U):
...
@@ -350,7 +347,7 @@ def unify_walk(vil, l, U):
return
unify_walk
(
v
,
ov
,
U
)
return
unify_walk
(
v
,
ov
,
U
)
@comm_guard
((
list
,
tuple
),
(
list
,
tuple
))
@comm_guard
((
list
,
tuple
),
(
list
,
tuple
))
# noqa
def
unify_walk
(
l1
,
l2
,
U
):
def
unify_walk
(
l1
,
l2
,
U
):
"""
"""
Tries to unify each corresponding pair of elements from l1 and l2.
Tries to unify each corresponding pair of elements from l1 and l2.
...
@@ -365,7 +362,7 @@ def unify_walk(l1, l2, U):
...
@@ -365,7 +362,7 @@ def unify_walk(l1, l2, U):
return
U
return
U
@comm_guard
(
dict
,
dict
)
@comm_guard
(
dict
,
dict
)
# noqa
def
unify_walk
(
d1
,
d2
,
U
):
def
unify_walk
(
d1
,
d2
,
U
):
"""
"""
Tries to unify values of corresponding keys.
Tries to unify values of corresponding keys.
...
@@ -379,7 +376,7 @@ def unify_walk(d1, d2, U):
...
@@ -379,7 +376,7 @@ def unify_walk(d1, d2, U):
return
U
return
U
@comm_guard
(
ANY_TYPE
,
ANY_TYPE
)
@comm_guard
(
ANY_TYPE
,
ANY_TYPE
)
# noqa
def
unify_walk
(
a
,
b
,
U
):
def
unify_walk
(
a
,
b
,
U
):
"""
"""
Checks for the existence of the __unify_walk__ method for one of
Checks for the existence of the __unify_walk__ method for one of
...
@@ -394,7 +391,7 @@ def unify_walk(a, b, U):
...
@@ -394,7 +391,7 @@ def unify_walk(a, b, U):
return
FALL_THROUGH
return
FALL_THROUGH
@comm_guard
(
Variable
,
ANY_TYPE
)
@comm_guard
(
Variable
,
ANY_TYPE
)
# noqa
def
unify_walk
(
v
,
o
,
U
):
def
unify_walk
(
v
,
o
,
U
):
"""
"""
This simply checks if the Var has an unification in U and uses it
This simply checks if the Var has an unification in U and uses it
...
@@ -429,27 +426,27 @@ def unify_merge(a, b, U):
...
@@ -429,27 +426,27 @@ def unify_merge(a, b, U):
return
a
return
a
@comm_guard
(
Variable
,
ANY_TYPE
)
@comm_guard
(
Variable
,
ANY_TYPE
)
# noqa
def
unify_merge
(
v
,
o
,
U
):
def
unify_merge
(
v
,
o
,
U
):
return
v
return
v
@comm_guard
(
BoundVariable
,
ANY_TYPE
)
@comm_guard
(
BoundVariable
,
ANY_TYPE
)
# noqa
def
unify_merge
(
bv
,
o
,
U
):
def
unify_merge
(
bv
,
o
,
U
):
return
bv
.
value
return
bv
.
value
@comm_guard
(
VariableInList
,
(
list
,
tuple
))
@comm_guard
(
VariableInList
,
(
list
,
tuple
))
# noqa
def
unify_merge
(
vil
,
l
,
U
):
def
unify_merge
(
vil
,
l
,
U
):
return
[
unify_merge
(
x
,
x
,
U
)
for
x
in
l
]
return
[
unify_merge
(
x
,
x
,
U
)
for
x
in
l
]
@comm_guard
((
list
,
tuple
),
(
list
,
tuple
))
@comm_guard
((
list
,
tuple
),
(
list
,
tuple
))
# noqa
def
unify_merge
(
l1
,
l2
,
U
):
def
unify_merge
(
l1
,
l2
,
U
):
return
[
unify_merge
(
x1
,
x2
,
U
)
for
x1
,
x2
in
zip
(
l1
,
l2
)]
return
[
unify_merge
(
x1
,
x2
,
U
)
for
x1
,
x2
in
zip
(
l1
,
l2
)]
@comm_guard
(
dict
,
dict
)
@comm_guard
(
dict
,
dict
)
# noqa
def
unify_merge
(
d1
,
d2
,
U
):
def
unify_merge
(
d1
,
d2
,
U
):
d
=
d1
.
__class__
()
d
=
d1
.
__class__
()
for
k1
,
v1
in
iteritems
(
d1
):
for
k1
,
v1
in
iteritems
(
d1
):
...
@@ -463,12 +460,12 @@ def unify_merge(d1, d2, U):
...
@@ -463,12 +460,12 @@ def unify_merge(d1, d2, U):
return
d
return
d
@comm_guard
(
FVar
,
ANY_TYPE
)
@comm_guard
(
FVar
,
ANY_TYPE
)
# noqa
def
unify_merge
(
vs
,
o
,
U
):
def
unify_merge
(
vs
,
o
,
U
):
return
vs
(
U
)
return
vs
(
U
)
@comm_guard
(
ANY_TYPE
,
ANY_TYPE
)
@comm_guard
(
ANY_TYPE
,
ANY_TYPE
)
# noqa
def
unify_merge
(
a
,
b
,
U
):
def
unify_merge
(
a
,
b
,
U
):
if
(
not
isinstance
(
a
,
Variable
)
and
if
(
not
isinstance
(
a
,
Variable
)
and
not
isinstance
(
b
,
Variable
)
and
not
isinstance
(
b
,
Variable
)
and
...
@@ -478,7 +475,7 @@ def unify_merge(a, b, U):
...
@@ -478,7 +475,7 @@ def unify_merge(a, b, U):
return
FALL_THROUGH
return
FALL_THROUGH
@comm_guard
(
Variable
,
ANY_TYPE
)
@comm_guard
(
Variable
,
ANY_TYPE
)
# noqa
def
unify_merge
(
v
,
o
,
U
):
def
unify_merge
(
v
,
o
,
U
):
"""
"""
This simply checks if the Var has an unification in U and uses it
This simply checks if the Var has an unification in U and uses it
...
...
theano/gof/vm.py
浏览文件 @
c0b294ec
...
@@ -27,6 +27,9 @@ logger = logging.getLogger(__name__)
...
@@ -27,6 +27,9 @@ logger = logging.getLogger(__name__)
def
calculate_reallocate_info
(
order
,
fgraph
,
storage_map
,
compute_map_re
,
def
calculate_reallocate_info
(
order
,
fgraph
,
storage_map
,
compute_map_re
,
dependencies
):
dependencies
):
"""
WRITEME : explain the parameters
"""
reallocated_info
=
{}
reallocated_info
=
{}
viewed_by
=
{}
viewed_by
=
{}
for
var
in
fgraph
.
variables
:
for
var
in
fgraph
.
variables
:
...
@@ -189,7 +192,9 @@ class VM(object):
...
@@ -189,7 +192,9 @@ class VM(object):
raise
NotImplementedError
(
'override me'
)
raise
NotImplementedError
(
'override me'
)
def
update_profile
(
self
,
profile
):
def
update_profile
(
self
,
profile
):
# accumulate into the profile object
"""
Accumulate into the profile object
"""
for
node
,
thunk
,
t
,
c
in
zip
(
self
.
nodes
,
self
.
thunks
,
for
node
,
thunk
,
t
,
c
in
zip
(
self
.
nodes
,
self
.
thunks
,
self
.
call_times
,
self
.
call_counts
):
self
.
call_times
,
self
.
call_counts
):
profile
.
apply_time
.
setdefault
(
node
,
0.0
)
profile
.
apply_time
.
setdefault
(
node
,
0.0
)
...
@@ -723,6 +728,9 @@ class VM_Linker(link.LocalLinker):
...
@@ -723,6 +728,9 @@ class VM_Linker(link.LocalLinker):
def
accept
(
self
,
fgraph
,
no_recycling
=
None
):
def
accept
(
self
,
fgraph
,
no_recycling
=
None
):
"""
"""
Check if fgraph is the first FunctionGraph that has ever been
associated to self, else, create a new VM_Linker
associated to fgraph
Parameters
Parameters
----------
----------
...
...
theano/tests/test_flake8.py
浏览文件 @
c0b294ec
...
@@ -126,13 +126,10 @@ whitelist_flake8 = [
...
@@ -126,13 +126,10 @@ whitelist_flake8 = [
"sparse/sandbox/sp2.py"
,
"sparse/sandbox/sp2.py"
,
"sparse/sandbox/truedot.py"
,
"sparse/sandbox/truedot.py"
,
"sparse/sandbox/sp.py"
,
"sparse/sandbox/sp.py"
,
"gof/unify.py"
,
"gof/__init__.py"
,
"gof/__init__.py"
,
"gof/sandbox/equilibrium.py"
,
"d3viz/__init__.py"
,
"d3viz/__init__.py"
,
"d3viz/tests/__init__.py"
,
"d3viz/tests/__init__.py"
,
"gof/tests/__init__.py"
,
"gof/tests/__init__.py"
,
]
]
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论