Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
c0b294ec
提交
c0b294ec
authored
6月 20, 2016
作者:
Frédéric Bastien
提交者:
GitHub
6月 20, 2016
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #4629 from chinnadhurai/ccw_4483_indent_fix
Ccw 4483 indent fix
上级
d9028c7b
99fcfdcb
显示空白字符变更
内嵌
并排
正在显示
15 个修改的文件
包含
244 行增加
和
278 行删除
+244
-278
callcache.py
theano/gof/callcache.py
+14
-0
cc.py
theano/gof/cc.py
+47
-45
cmodule.py
theano/gof/cmodule.py
+20
-4
compiledir.py
theano/gof/compiledir.py
+6
-0
destroyhandler.py
theano/gof/destroyhandler.py
+1
-0
fg.py
theano/gof/fg.py
+7
-19
graph.py
theano/gof/graph.py
+27
-18
link.py
theano/gof/link.py
+16
-14
op.py
theano/gof/op.py
+35
-1
equilibrium.py
theano/gof/sandbox/equilibrium.py
+0
-146
sched.py
theano/gof/sched.py
+3
-0
toolbox.py
theano/gof/toolbox.py
+35
-0
unify.py
theano/gof/unify.py
+24
-27
vm.py
theano/gof/vm.py
+9
-1
test_flake8.py
theano/tests/test_flake8.py
+0
-3
没有找到文件。
theano/gof/callcache.py
浏览文件 @
c0b294ec
...
...
@@ -17,12 +17,26 @@ class CallCache(object):
self
.
cache
=
{}
def
persist
(
self
,
filename
=
None
):
"""
Cache "filename" as a pickle file
"""
if
filename
is
None
:
filename
=
self
.
filename
with
open
(
filename
,
'w'
)
as
f
:
pickle
.
dump
(
self
.
cache
,
f
)
def
call
(
self
,
fn
,
args
=
(),
key
=
None
):
"""
Retrieve item from the cache(if available)
based on a key
Parameters:
----------
key
parameter to retrieve cache item
fn,args
key to retrieve if "key" is None
"""
if
key
is
None
:
key
=
(
fn
,
tuple
(
args
))
if
key
not
in
self
.
cache
:
...
...
theano/gof/cc.py
浏览文件 @
c0b294ec
...
...
@@ -61,8 +61,6 @@ def get_persistent_module_cache():
class
CodeBlock
:
"""
WRITEME
Represents a computation unit composed of declare, behavior, and cleanup.
The constructor initializes a L{CodeBlock} with templatized declare,
...
...
@@ -118,6 +116,12 @@ def failure_code_init(sub):
"""
Code for failure in the struct init.
Parameters:
----------
sub
Dictionary used to template the struct.
* failure_var -> must contain a variable name to use for
the failure code.
"""
return
'''{
if (!PyErr_Occurred()) {
...
...
@@ -131,10 +135,10 @@ def failure_code_init(sub):
def
code_gen
(
blocks
):
"""
WRITEME
From a list of L{CodeBlock} instances, returns a string
that executes them all in sequence. eg for C{(decl1, task1,
that executes them all in sequence.
Eg for C{(decl1, task1,
cleanup1)} and C{(decl2, task2, cleanup2)} the returned string
will be of the form:
...
...
@@ -149,6 +153,12 @@ def code_gen(blocks):
cleanup1
}
Parameters:
----------
blocks
List of CodeBlock instances such that
* declarations, behavior and cleanup are in the run()
method of the struct
"""
decl
=
""
head
=
""
...
...
@@ -162,8 +172,6 @@ def code_gen(blocks):
def
struct_gen
(
args
,
struct_builders
,
blocks
,
sub
):
"""
WRITEME
Generates a struct conforming to the following specifications:
Parameters
...
...
@@ -453,7 +461,7 @@ def get_c_sync(r, name, sub):
def
apply_policy
(
policy
,
r
,
name
,
sub
):
"""
WRITEME
Apply the list of policies to name.r,sub
Parameters
----------
...
...
@@ -478,7 +486,7 @@ def apply_policy(policy, r, name, sub):
def
struct_variable_codeblocks
(
variable
,
policies
,
id
,
symbol_table
,
sub
):
"""
WRITEME
Update "sub" dict and create two codeblocks with different failure modes
Parameters
----------
...
...
@@ -525,8 +533,6 @@ def struct_variable_codeblocks(variable, policies, id, symbol_table, sub):
class
CLinker
(
link
.
Linker
):
"""
WRITEME
Creates C code for an fgraph, compiles it and returns callables
through make_thunk and make_function that make use of the compiled
code.
...
...
@@ -544,7 +550,7 @@ class CLinker(link.Linker):
def
accept
(
self
,
fgraph
,
no_recycling
=
None
):
"""
WRITEME
Associate linker with fgraph
"""
if
no_recycling
is
None
:
...
...
@@ -559,8 +565,6 @@ class CLinker(link.Linker):
def
fetch_variables
(
self
):
"""
WRITEME
Fills the inputs, outputs, variables, orphans, temps and node_order
fields.
...
...
@@ -617,8 +621,6 @@ class CLinker(link.Linker):
def
code_gen
(
self
):
"""
WRITEME
Generates code for a struct that does the computation of the fgraph and
stores it in the struct_code field of the instance.
...
...
@@ -890,14 +892,9 @@ class CLinker(link.Linker):
def
support_code
(
self
):
"""
WRITEME
Returns a list of support code strings that are needed by
one or more Variables or Ops. The support code from Variables is
added before the support code from Ops.
This might contain duplicates.
one or more Variables or Ops.
The support code from Variables is added before the support code from Ops.This might contain duplicates.
"""
ret
=
[]
# generic support code
...
...
@@ -911,8 +908,6 @@ class CLinker(link.Linker):
def
compile_args
(
self
):
"""
WRITEME
Returns a list of compile args that are needed by one
or more Variables or Ops.
...
...
@@ -971,8 +966,6 @@ class CLinker(link.Linker):
def
headers
(
self
):
"""
WRITEME
Returns a list of headers that are needed by one
or more Types or Ops.
...
...
@@ -1032,8 +1025,6 @@ class CLinker(link.Linker):
def
header_dirs
(
self
):
"""
WRITEME
Returns a list of lib directories that are needed by one
or more Types or Ops.
...
...
@@ -1055,8 +1046,6 @@ class CLinker(link.Linker):
def
libraries
(
self
):
"""
WRITEME
Returns a list of libraries that are needed by one
or more Types or Ops.
...
...
@@ -1078,8 +1067,6 @@ class CLinker(link.Linker):
def
lib_dirs
(
self
):
"""
WRITEME
Returns a list of lib directories that are needed by one
or more Types or Ops.
...
...
@@ -1101,7 +1088,7 @@ class CLinker(link.Linker):
def
__compile__
(
self
,
input_storage
=
None
,
output_storage
=
None
,
storage_map
=
None
,
keep_lock
=
False
):
"""
WRITEME
"""
Compiles this linker's fgraph.
Parameters
...
...
@@ -1166,7 +1153,7 @@ class CLinker(link.Linker):
def
make_thunk
(
self
,
input_storage
=
None
,
output_storage
=
None
,
storage_map
=
None
,
keep_lock
=
False
):
"""
WRITEME
"""
Compiles this linker's fgraph and returns a function to perform the
computations, as well as lists of storage cells for both the inputs
and outputs.
...
...
@@ -1183,8 +1170,10 @@ class CLinker(link.Linker):
be allocated.
storage_map: dict that map variables to storages.
This is used when you need to customize the storage of
this thunk.
this thunk
keep_lock:
If True, we won't release the lock on the compiledir
at the end of this function call.
Returns: thunk, input_storage, output_storage
The return values can be used as follows:
...
...
@@ -1568,7 +1557,12 @@ class CLinker(link.Linker):
def
cthunk_factory
(
self
,
error_storage
,
in_storage
,
out_storage
,
storage_map
=
None
,
keep_lock
=
False
):
"""WRITEME
"""
Returns a thunk that points to an instance of a C struct that
can carry on the computation of this linker's fgraph
Parameters:
----------
error_storage -> list of length 3
in_storage -> list of lists of length 1, one per input
out_storage -> list of lists of length 1, one per output
...
...
@@ -1705,8 +1699,6 @@ class _CThunk(object):
class
OpWiseCLinker
(
link
.
LocalLinker
):
"""
WRITEME
Uses CLinker on the individual Ops that comprise an fgraph and loops
over them in Python. The variable is slower than a compiled version of
the whole fgraph, but saves on compilation time because small changes
...
...
@@ -1746,6 +1738,9 @@ class OpWiseCLinker(link.LocalLinker):
self
.
schedule
=
schedule
def
accept
(
self
,
fgraph
,
no_recycling
=
None
):
"""
Associate linker with fgraph
"""
if
no_recycling
is
None
:
no_recycling
=
[]
if
self
.
fgraph
is
not
None
and
self
.
fgraph
is
not
fgraph
:
...
...
@@ -1846,11 +1841,14 @@ class OpWiseCLinker(link.LocalLinker):
def
_default_checker
(
x
,
y
):
"""
WRITEME
Default checker for DualLinker. This checks that the
variables contain the same data using ==.
Parameters:
----------
x,y
the variables to compare data
"""
if
x
[
0
]
!=
y
[
0
]:
raise
Exception
(
"Output mismatch."
,
...
...
@@ -1859,8 +1857,6 @@ def _default_checker(x, y):
class
DualLinker
(
link
.
Linker
):
"""
WRITEME
Runs the fgraph in parallel using PerformLinker and CLinker.
The thunk/function produced by DualLinker uses PerformLinker as the
...
...
@@ -1902,6 +1898,9 @@ class DualLinker(link.Linker):
self
.
schedule
=
schedule
def
accept
(
self
,
fgraph
,
no_recycling
=
None
):
"""
Update/tie self with fgraph
"""
if
no_recycling
is
None
:
no_recycling
=
[]
if
self
.
fgraph
is
not
None
and
self
.
fgraph
is
not
fgraph
:
...
...
@@ -1912,7 +1911,10 @@ class DualLinker(link.Linker):
return
self
def
make_thunk
(
self
,
**
kwargs
):
"""
Compiles this linker's fgraph and returns a function to perform the
computations
"""
fgraph
=
self
.
fgraph
no_recycling
=
self
.
no_recycling
...
...
theano/gof/cmodule.py
浏览文件 @
c0b294ec
...
...
@@ -1474,10 +1474,25 @@ class ModuleCache(object):
def
_rmtree
(
parent
,
ignore_nocleanup
=
False
,
msg
=
''
,
level
=
logging
.
DEBUG
,
ignore_if_missing
=
False
):
# On NFS filesystems, it is impossible to delete a directory with open
# files in it. So instead, some commands in this file will respond to a
# failed rmtree() by touching a 'delete.me' file. This file is a message
# for a future process to try deleting the directory.
"""
On NFS filesystems, it is impossible to delete a directory with open
files in it.
So instead, some commands in this file will respond to a
failed rmtree() by touching a 'delete.me' file. This file is a message
for a future process to try deleting the directory.
Parameters:
----------
parent
Root node to start deleting from
ignore_nocleanup
Delete the tree if flag is TRUE
level
Python Logging level. Set to "DEBUG" by default
ignore_if_missing
If set to True, just return without any issue if parent is NULL
"""
if
ignore_if_missing
and
not
os
.
path
.
exists
(
parent
):
return
try
:
...
...
@@ -1504,6 +1519,7 @@ _module_cache = None
def
get_module_cache
(
dirname
,
init_args
=
None
):
"""
Create a new module_cache with the (k, v) pairs in this dictionary
Parameters
----------
...
...
theano/gof/compiledir.py
浏览文件 @
c0b294ec
...
...
@@ -94,6 +94,9 @@ def cleanup():
def
print_compiledir_content
():
"""
print list of
%
d compiled individual ops in the "theano.config.compiledir"
"""
max_key_file_size
=
1
*
1024
*
1024
# 1M
compiledir
=
theano
.
config
.
compiledir
...
...
@@ -178,6 +181,9 @@ def compiledir_purge():
def
basecompiledir_ls
():
"""
Print list of files in the "theano.config.base_compiledir"
"""
subdirs
=
[]
others
=
[]
for
f
in
os
.
listdir
(
config
.
base_compiledir
):
...
...
theano/gof/destroyhandler.py
浏览文件 @
c0b294ec
...
...
@@ -32,6 +32,7 @@ class ProtocolError(Exception):
def
_contains_cycle
(
fgraph
,
orderings
):
"""
Function to check if the given graph contains a cycle
Parameters
----------
...
...
theano/gof/fg.py
浏览文件 @
c0b294ec
...
...
@@ -66,7 +66,6 @@ class MissingInputError(Exception):
class
FunctionGraph
(
utils
.
object2
):
"""
WRITEME
A FunctionGraph represents a subgraph bound by a set of input variables and
a set of output variables, ie a subgraph that specifies a theano function.
The inputs list should contain all the inputs on which the outputs depend.
...
...
@@ -265,8 +264,6 @@ class FunctionGraph(utils.object2):
"""
Updates the list of clients of r with new_clients.
WRITEME
Parameters
----------
r
...
...
@@ -365,6 +362,11 @@ class FunctionGraph(utils.object2):
"""
Import variables to this FunctionGraph and also their apply_node,
if those nodes are not in this graph.
Parameters:
----------
reason
reason is the name of the optimization or operation in progress.
"""
global
NullType
if
NullType
is
None
:
...
...
@@ -438,8 +440,6 @@ class FunctionGraph(utils.object2):
"""
Changes node.inputs[i] to new_r.
WRITEME
new_r.type == old_r.type must be True, where old_r is the
current value of node.inputs[i] which we want to replace.
...
...
@@ -483,8 +483,6 @@ class FunctionGraph(utils.object2):
# replace #
def
replace
(
self
,
r
,
new_r
,
reason
=
None
,
verbose
=
None
):
"""
WRITEME
This is the main interface to manipulate the subgraph in FunctionGraph.
For every node that uses r as input, makes it use new_r instead.
...
...
@@ -540,7 +538,7 @@ class FunctionGraph(utils.object2):
def
replace_all
(
self
,
pairs
,
reason
=
None
):
"""
WRITEME
For every node that uses r as input, makes it use new_r instead
"""
for
r
,
new_r
in
pairs
:
...
...
@@ -578,8 +576,6 @@ class FunctionGraph(utils.object2):
def
remove_feature
(
self
,
feature
):
"""
WRITEME
Removes the feature from the graph.
Calls feature.on_detach(function_graph) if an on_detach method
...
...
@@ -598,8 +594,6 @@ class FunctionGraph(utils.object2):
# callback utils #
def
execute_callbacks
(
self
,
name
,
*
args
,
**
kwargs
):
"""
WRITEME
Calls
getattr(feature, name)(*args)
for each feature which has a method called after name.
...
...
@@ -621,8 +615,6 @@ class FunctionGraph(utils.object2):
def
collect_callbacks
(
self
,
name
,
*
args
):
"""
WRITEME
Returns a dictionary d such that:
d[feature] == getattr(feature, name)(*args)
For each feature which has a method called after name.
...
...
@@ -640,8 +632,6 @@ class FunctionGraph(utils.object2):
# misc #
def
toposort
(
self
):
"""
WRITEME
Return an ordering of the graph's Apply nodes such that:
- All the nodes of the inputs of a node are before that node.
- Satisfies the orderings provided by each feature that has
...
...
@@ -705,8 +695,6 @@ class FunctionGraph(utils.object2):
def
check_integrity
(
self
):
"""
WRITEME
Call this for a diagnosis if things go awry.
"""
...
...
@@ -766,7 +754,7 @@ class FunctionGraph(utils.object2):
# clone #
def
clone
(
self
,
check_integrity
=
True
):
"""
WRITEME
Clone the graph and get a memo( a dict )that map old node to new node
"""
return
self
.
clone_get_equiv
(
check_integrity
)[
0
]
...
...
theano/gof/graph.py
浏览文件 @
c0b294ec
...
...
@@ -701,7 +701,15 @@ def inputs(variable_list, blockers=None):
def
variables_and_orphans
(
i
,
o
):
"""
WRITEME
Extract list of variables between i and o nodes via
dfs traversal and chooses the orphans among them
Parameters
----------
i : list
Input variables.
o : list
Output variables.
"""
def
expand
(
r
):
...
...
@@ -716,21 +724,21 @@ def variables_and_orphans(i, o):
def
ops
(
i
,
o
):
"""
WRITEME
Set of Ops contained within the subgraph between i and o
Parameters
----------
i : list
Input
L{Variable}
s.
Input
variable
s.
o : list
Output
L{Variable}
s.
Output
variable
s.
Returns
-------
object
The set of ops that are contained within the subgraph that lies
between i and o, including the owners of the
L{Variable}
s in o and
intermediary ops between i and o, but not the owners of the
L{Variable}
s
between i and o, including the owners of the
variable
s in o and
intermediary ops between i and o, but not the owners of the
variable
s
in i.
"""
...
...
@@ -745,14 +753,14 @@ def ops(i, o):
def
variables
(
i
,
o
):
"""
WRITEME
Extracts list of variables within input and output nodes via dfs travesal
Parameters
----------
i : list
Input
L{Variable}
s.
Input
variable
s.
o : list
Output
L{Variable}
s.
Output
variable
s.
Returns
-------
...
...
@@ -767,14 +775,15 @@ def variables(i, o):
def
orphans
(
i
,
o
):
"""
WRITEME
Extracts list of variables within input and output nodes
via dfs travesal and returns the orphans among them
Parameters
----------
i : list
Input
L{Variable}
s.
Input
Variable
s.
o : list
Output
L{Variable}
s.
Output
Variable
s.
Returns
-------
...
...
@@ -797,9 +806,9 @@ def clone(i, o, copy_inputs=True):
Parameters
----------
i : list
Input
L{Variable}
s.
Input
Variable
s.
o : list
Output
L{Variable}
s.
Output
Variable
s.
copy_inputs : bool
If True, the inputs will be copied (defaults to True).
...
...
@@ -959,7 +968,7 @@ def general_toposort(r_out, deps, debug_print=False,
def
io_toposort
(
inputs
,
outputs
,
orderings
=
None
,
clients
=
None
):
"""
WRITEME
Perform topological sort from input and output nodes
Parameters
----------
...
...
@@ -1218,8 +1227,8 @@ def op_as_string(i, op,
leaf_formatter
=
default_leaf_formatter
,
node_formatter
=
default_node_formatter
):
"""
WRITEME
Op to return a string representation of the subgraph
between i and o
"""
strs
=
as_string
(
i
,
op
.
inputs
,
leaf_formatter
,
node_formatter
)
return
node_formatter
(
op
,
strs
)
...
...
@@ -1229,7 +1238,7 @@ def as_string(i, o,
leaf_formatter
=
default_leaf_formatter
,
node_formatter
=
default_node_formatter
):
"""
WRITEME
Returns a string representation of the subgraph between i and o
Parameters
----------
...
...
theano/gof/link.py
浏览文件 @
c0b294ec
...
...
@@ -52,16 +52,24 @@ def log_thunk_trace(value, f=sys.stderr):
def
thunk_hook
(
type
,
value
,
trace
):
"""
WRITEME
This function is meant to replace excepthook and do some
special work if the exception value has a __thunk_trace__
field. In that case, it retrieves the field, which should
field.
In that case, it retrieves the field, which should
contain a trace as returned by L{traceback.extract_stack},
and prints it out on L{stderr}.
The normal excepthook is then called.
Parameters:
----------
type
Exception class
value
Exception instance
trace
Traceback object
Notes
-----
This hook replaced by nosetests, so it does not run in nose tests.
...
...
@@ -680,8 +688,6 @@ def streamline(fgraph, thunks, order, post_thunk_old_storage=None,
class
LocalLinker
(
Linker
):
"""
WRITEME
Useful base class for L{Linker}s which keep all nodes in the graph, and run
a thunk associated with each node.
...
...
@@ -707,7 +713,7 @@ class LocalLinker(Linker):
def
gc_helper
(
node_list
):
"""
Return the set of Variable instances which are computed by node_list.
Parameters
----------
node_list
...
...
@@ -743,8 +749,6 @@ def gc_helper(node_list):
class
PerformLinker
(
LocalLinker
):
"""
WRITEME
Basic L{Linker} subclass that calls the perform method on each L{Op} in
the L{FunctionGraph} in the order given by L{Linker.schedule}.
...
...
@@ -764,8 +768,7 @@ class PerformLinker(LocalLinker):
Parameters
----------
fgraph
A PerformLinker can have accepted one FunctionGraph instance at a
time.
A PerformLinker can have accepted one FunctionGraph instance at a time.
no_recycling
WRITEME
...
...
@@ -786,13 +789,14 @@ class PerformLinker(LocalLinker):
def
make_all
(
self
,
input_storage
=
None
,
output_storage
=
None
,
storage_map
=
None
):
"""
Returns Function to run all nodes, list of input containers, list of outputs
Parameters
----------
input_storage
WRITEME
list of storages corresponding to fgraph.inputs
output_storage
WRITEME
list of storages corresponding to fgraph.outputs
Returns
-------
...
...
@@ -879,8 +883,6 @@ def add_clear_storage(f, computed, storage_map):
class
WrapLinker
(
Linker
):
"""
WRITEME
This class makes it easier to run several L{LocalLinker}s in parallel, and
offers some control over how each thunk is run.
...
...
theano/gof/op.py
浏览文件 @
c0b294ec
...
...
@@ -791,6 +791,9 @@ class Op(utils.object2, PureOp, CLinkerOp):
self
.
_op_use_c_code
=
use_c_code
def
_props
(
self
):
"""
Tuple of properties of all attributes
"""
return
tuple
(
getattr
(
self
,
a
)
for
a
in
self
.
__props__
)
def
_props_dict
(
self
):
...
...
@@ -924,6 +927,9 @@ class Op(utils.object2, PureOp, CLinkerOp):
def
make_thunk
(
self
,
node
,
storage_map
,
compute_map
,
no_recycling
):
"""
This function must return a thunk, that is a zero-arguments
function that encapsulates the computation to be performed
by this op on the arguments of the node.
Parameters
----------
...
...
@@ -974,7 +980,9 @@ class Op(utils.object2, PureOp, CLinkerOp):
return
self
.
make_py_thunk
(
node
,
storage_map
,
compute_map
,
no_recycling
)
def
make_node
(
self
,
*
inputs
):
"""
Create a "apply" nodes for the inputs in that order.
"""
if
not
hasattr
(
self
,
'itypes'
):
raise
NotImplementedError
(
"You can either define itypes and otypes,
\
or implement make_node"
)
...
...
@@ -1058,6 +1066,10 @@ def debug_error_message(msg):
def
debug_assert
(
condition
,
msg
=
None
):
"""
Customized assert with options to ignore the assert
with just a warning
"""
if
msg
is
None
:
msg
=
'debug_assert failed'
if
not
condition
:
...
...
@@ -1165,12 +1177,18 @@ class OpenMPOp(Op):
self
.
openmp
=
False
def
c_compile_args
(
self
):
"""
Return the compilation arg "fopenmp" if openMP is supported
"""
self
.
update_self_openmp
()
if
self
.
openmp
:
return
[
'-fopenmp'
]
return
[]
def
c_headers
(
self
):
"""
Return the header file name "omp.h" if openMP is supported
"""
self
.
update_self_openmp
()
if
self
.
openmp
:
return
[
"omp.h"
]
...
...
@@ -1178,6 +1196,9 @@ class OpenMPOp(Op):
@staticmethod
def
test_gxx_support
():
"""
Check if openMP is supported
"""
code
=
"""
#include <omp.h>
int main( int argc, const char* argv[] )
...
...
@@ -1313,6 +1334,9 @@ class COp(Op):
'and specify the func_name'
)
def
load_c_code
(
self
):
"""
Loads the c code to perform the Op
"""
self
.
func_codes
=
[]
for
func_file
in
self
.
func_files
:
with
open
(
func_file
,
'r'
)
as
f
:
...
...
@@ -1391,6 +1415,9 @@ class COp(Op):
return
hash
(
tuple
(
self
.
func_codes
))
def
c_init_code
(
self
):
"""
Get the code section for init_code
"""
if
'init_code'
in
self
.
code_sections
:
return
[
self
.
code_sections
[
'init_code'
]]
else
:
...
...
@@ -1500,6 +1527,10 @@ class COp(Op):
undef_macros
.
append
(
"#undef OUTPUT_
%
d"
,
(
i
,))
def
c_init_code_struct
(
self
,
node
,
name
,
sub
):
"""
Stitches all the macros and "init_code" together
"""
if
'init_code_struct'
in
self
.
code_sections
:
op_code
=
self
.
code_sections
[
'init_code_struct'
]
...
...
@@ -1554,6 +1585,9 @@ class COp(Op):
'c_code'
,
type
(
self
),
type
(
self
)
.
__name__
)
def
c_code_cleanup
(
self
,
node
,
name
,
inputs
,
outputs
,
sub
):
"""
Stitches all the macros and "code_cleanup" together
"""
if
'code_cleanup'
in
self
.
code_sections
:
op_code
=
self
.
code_sections
[
'code_cleanup'
]
...
...
theano/gof/sandbox/equilibrium.py
deleted
100644 → 0
浏览文件 @
d9028c7b
from
__future__
import
absolute_import
,
print_function
,
division
from
six.moves
import
reduce
from
six
import
string_types
if
0
:
class
_EquilibriumOptimizer
(
NavigatorOptimizer
):
def
__init__
(
self
,
local_optimizers
,
failure_callback
=
None
,
max_depth
=
None
,
max_use_ratio
=
None
):
super
(
EquilibriumOptimizer
,
self
)
.
__init__
(
None
,
ignore_newtrees
=
False
,
failure_callback
=
failure_callback
)
self
.
local_optimizers
=
local_optimizers
self
.
max_depth
=
max_depth
self
.
max_use_ratio
=
max_use_ratio
self
.
tracks
=
defaultdict
(
list
)
self
.
tracks0
=
defaultdict
(
list
)
max_depth
=
0
for
lopt
in
local_optimizers
:
tracks
=
lopt
.
tracks
()
for
track
in
tracks
:
max_depth
=
max
(
max_depth
,
len
(
track
))
if
self
.
max_depth
is
not
None
and
max_depth
>
self
.
max_depth
:
raise
ValueError
(
'One of the local optimizers exceeds the maximal depth.'
)
for
i
,
op
in
enumerate
(
track
):
if
i
==
0
:
self
.
tracks0
[
op
]
.
append
((
track
,
i
,
lopt
))
self
.
tracks
[
op
]
.
append
((
track
,
i
,
lopt
))
def
fetch_tracks
(
self
,
op
):
return
self
.
tracks
[
op
]
+
self
.
tracks
[
None
]
def
fetch_tracks0
(
self
,
op
):
return
self
.
tracks0
[
op
]
+
self
.
tracks0
[
None
]
def
backtrack
(
self
,
node
,
tasks
):
candidates
=
self
.
fetch_tracks
(
node
.
op
)
tracks
=
[]
def
filter
(
node
,
depth
):
new_candidates
=
[]
for
candidate
in
candidates
:
track
,
i
,
lopt
=
candidate
if
i
<
depth
:
pass
elif
track
[
i
-
depth
]
in
(
None
,
node
.
op
):
if
i
==
depth
:
tasks
[
node
]
.
append
(
lopt
)
else
:
tracks
.
append
(
candidate
)
else
:
new_candidates
.
append
(
candidate
)
return
new_candidates
depth
=
0
nodes
=
[
node
]
while
candidates
:
for
node
in
nodes
:
candidates
=
list
(
filter
(
node
,
depth
))
depth
+=
1
_nodes
=
nodes
nodes
=
reduce
(
list
.
__iadd__
,
[
reduce
(
list
.
__iadd__
,
[[
n
for
n
,
i
in
out
.
clients
if
not
isinstance
(
n
,
string_types
)]
for
out
in
node
.
outputs
],
[])
for
node
in
nodes
],
[])
candidates
=
tracks
tracks
=
[]
def
apply
(
self
,
fgraph
):
tasks
=
defaultdict
(
list
)
if
self
.
max_use_ratio
is
not
None
:
max_uses
=
self
.
max_use_ratio
*
len
(
fgraph
.
apply_nodes
)
runs
=
defaultdict
(
int
)
else
:
runs
=
None
def
importer
(
node
):
# print 'IMPORTING', node
self
.
backtrack
(
node
,
tasks
)
def
pruner
(
node
):
try
:
del
tasks
[
node
]
except
KeyError
:
pass
def
chin
(
node
,
i
,
r
,
new_r
):
if
new_r
.
owner
and
not
r
.
clients
:
self
.
backtrack
(
new_r
.
owner
,
tasks
)
# # == NOT IDEAL == #
# for node in fgraph.apply_nodes:
# importer(node)
for
node
in
fgraph
.
toposort
():
tasks
[
node
]
.
extend
(
lopt
for
track
,
i
,
lopt
in
self
.
fetch_tracks0
(
node
.
op
))
u
=
self
.
attach_updater
(
fgraph
,
importer
,
pruner
,
chin
)
print
(
'KEYS'
,
[
hash
(
t
)
for
t
in
tasks
.
keys
()])
while
tasks
:
for
node
in
tasks
:
todo
=
tasks
.
pop
(
node
)
break
for
lopt
in
todo
:
if
runs
is
not
None
and
runs
[
lopt
]
>=
max_uses
:
print
(
'Warning: optimization exceeded its maximal use ratio:
%
s,
%
s'
%
(
lopt
,
max_uses
),
file
=
sys
.
stderr
)
continue
success
=
self
.
process_node
(
fgraph
,
node
,
lopt
)
if
success
:
if
runs
is
not
None
:
runs
[
lopt
]
+=
1
break
self
.
detach_updater
(
fgraph
,
u
)
# def match(self, node, candidates):
# candidates[:] = [candidate
# for candidate in candidates
# if candidate.current.op is None or candidate.current.op == node.op]
# for candidate in candidates:
# if candidate.current.inputs is not None:
# for in1, in2 in zip(candidate.current.inputs, node.inputs):
# if isinstance(in1, string_types):
# candidate.match[in1] = in2
# for client in node.clients:
# op = node.op
# patterns = self.pattern_base[(depth, op)].union(self.pattern_base[(depth, WILDCARD)])
# if not patterns:
# return patterns
# return self.match(node, depth + 1).intersection(patterns)
# def backtrack(self, node, q):
# for node2, i in node.clients:
# op2 = node2.op
theano/gof/sched.py
浏览文件 @
c0b294ec
...
...
@@ -268,6 +268,9 @@ def sort_schedule_fn(*cmps):
def
key_to_cmp
(
key
):
"""
comparator function based on "key" function
"""
def
key_cmp
(
a
,
b
):
return
cmp
(
key
(
a
),
key
(
b
))
return
key_cmp
theano/gof/toolbox.py
浏览文件 @
c0b294ec
...
...
@@ -114,10 +114,20 @@ class Feature(object):
class
Bookkeeper
(
Feature
):
def
on_attach
(
self
,
fgraph
):
"""
Called by FunctionGraph.attach_feature, the method that attaches
the feature to the FunctionGraph. Since this is called after the
FunctionGraph is initially populated, this is where you should
run checks on the initial contents of the FunctionGraph.
"""
for
node
in
graph
.
io_toposort
(
fgraph
.
inputs
,
fgraph
.
outputs
):
self
.
on_import
(
fgraph
,
node
,
"on_attach"
)
def
on_detach
(
self
,
fgraph
):
"""
Should remove any dynamically added functionality
that it installed into the function_graph
"""
for
node
in
graph
.
io_toposort
(
fgraph
.
inputs
,
fgraph
.
outputs
):
self
.
on_prune
(
fgraph
,
node
,
'Bookkeeper.detach'
)
...
...
@@ -178,6 +188,10 @@ class History(Feature):
fgraph
.
revert
=
partial
(
self
.
revert
,
fgraph
)
def
on_detach
(
self
,
fgraph
):
"""
Should remove any dynamically added functionality
that it installed into the function_graph
"""
del
fgraph
.
checkpoint
del
fgraph
.
revert
del
self
.
history
[
fgraph
]
...
...
@@ -223,10 +237,19 @@ class Validator(Feature):
fgraph
.
consistent
=
partial
(
self
.
consistent_
,
fgraph
)
def
on_detach
(
self
,
fgraph
):
"""
Should remove any dynamically added functionality
that it installed into the function_graph
"""
del
fgraph
.
validate
del
fgraph
.
consistent
def
validate_
(
self
,
fgraph
):
"""
If the caller is replace_all_validate, just raise the
exception. replace_all_validate will print out the
verbose output. Or it has to be done here before raise.
"""
t0
=
time
.
time
()
try
:
ret
=
fgraph
.
execute_callbacks
(
'validate'
)
...
...
@@ -289,6 +312,10 @@ class ReplaceValidate(History, Validator):
self
.
replace_all_validate_remove
,
fgraph
)
def
on_detach
(
self
,
fgraph
):
"""
Should remove any dynamically added functionality
that it installed into the function_graph
"""
History
.
on_detach
(
self
,
fgraph
)
Validator
.
on_detach
(
self
,
fgraph
)
del
self
.
_nodes_removed
...
...
@@ -412,6 +439,10 @@ class NodeFinder(Bookkeeper):
Bookkeeper
.
on_attach
(
self
,
fgraph
)
def
on_detach
(
self
,
fgraph
):
"""
Should remove any dynamically added functionality
that it installed into the function_graph
"""
if
self
.
fgraph
is
not
fgraph
:
raise
Exception
(
"This NodeFinder instance was not attached to the"
" provided fgraph."
)
...
...
@@ -461,6 +492,10 @@ class PrintListener(Feature):
print
(
"-- attaching to: "
,
fgraph
)
def
on_detach
(
self
,
fgraph
):
"""
Should remove any dynamically added functionality
that it installed into the function_graph
"""
if
self
.
active
:
print
(
"-- detaching from: "
,
fgraph
)
...
...
theano/gof/unify.py
浏览文件 @
c0b294ec
...
...
@@ -15,7 +15,6 @@ from copy import copy
from
functools
import
partial
from
theano.gof.utils
import
ANY_TYPE
,
comm_guard
,
FALL_THROUGH
,
iteritems
################################
...
...
@@ -135,8 +134,6 @@ class Unification:
"""
This class represents a possible unification of a group of variables
with each other or with tangible values.
Parameters
----------
inplace : bool
...
...
@@ -229,7 +226,7 @@ def unify_walk(a, b, U):
return
False
@comm_guard
(
FreeVariable
,
ANY_TYPE
)
@comm_guard
(
FreeVariable
,
ANY_TYPE
)
# noqa
def
unify_walk
(
fv
,
o
,
U
):
"""
FreeV is unified to BoundVariable(other_object).
...
...
@@ -239,7 +236,7 @@ def unify_walk(fv, o, U):
return
U
.
merge
(
v
,
fv
)
@comm_guard
(
BoundVariable
,
ANY_TYPE
)
@comm_guard
(
BoundVariable
,
ANY_TYPE
)
# noqa
def
unify_walk
(
bv
,
o
,
U
):
"""
The unification succeed iff BV.value == other_object.
...
...
@@ -251,7 +248,7 @@ def unify_walk(bv, o, U):
return
False
@comm_guard
(
OrVariable
,
ANY_TYPE
)
@comm_guard
(
OrVariable
,
ANY_TYPE
)
# noqa
def
unify_walk
(
ov
,
o
,
U
):
"""
The unification succeeds iff other_object in OrV.options.
...
...
@@ -264,7 +261,7 @@ def unify_walk(ov, o, U):
return
False
@comm_guard
(
NotVariable
,
ANY_TYPE
)
@comm_guard
(
NotVariable
,
ANY_TYPE
)
# noqa
def
unify_walk
(
nv
,
o
,
U
):
"""
The unification succeeds iff other_object not in NV.not_options.
...
...
@@ -277,7 +274,7 @@ def unify_walk(nv, o, U):
return
U
.
merge
(
v
,
nv
)
@comm_guard
(
FreeVariable
,
Variable
)
@comm_guard
(
FreeVariable
,
Variable
)
# noqa
def
unify_walk
(
fv
,
v
,
U
):
"""
Both variables are unified.
...
...
@@ -287,7 +284,7 @@ def unify_walk(fv, v, U):
return
U
.
merge
(
v
,
fv
)
@comm_guard
(
BoundVariable
,
Variable
)
@comm_guard
(
BoundVariable
,
Variable
)
# noqa
def
unify_walk
(
bv
,
v
,
U
):
"""
V is unified to BV.value.
...
...
@@ -296,13 +293,13 @@ def unify_walk(bv, v, U):
return
unify_walk
(
v
,
bv
.
value
,
U
)
@comm_guard
(
OrVariable
,
OrVariable
)
@comm_guard
(
OrVariable
,
OrVariable
)
# noqa
def
unify_walk
(
a
,
b
,
U
):
"""
OrV(list1) == OrV(list2) == OrV(intersection(list1, list2))
"""
opt
=
intersection
(
a
.
options
,
b
.
options
)
opt
=
a
.
options
.
intersection
(
b
.
options
)
if
not
opt
:
return
False
elif
len
(
opt
)
==
1
:
...
...
@@ -312,18 +309,18 @@ def unify_walk(a, b, U):
return
U
.
merge
(
v
,
a
,
b
)
@comm_guard
(
NotVariable
,
NotVariable
)
@comm_guard
(
NotVariable
,
NotVariable
)
# noqa
def
unify_walk
(
a
,
b
,
U
):
"""
NV(list1) == NV(list2) == NV(union(list1, list2))
"""
opt
=
union
(
a
.
not_options
,
b
.
not_options
)
opt
=
a
.
not_options
.
union
(
b
.
not_options
)
v
=
NotVariable
(
"?"
,
opt
)
return
U
.
merge
(
v
,
a
,
b
)
@comm_guard
(
OrVariable
,
NotVariable
)
@comm_guard
(
OrVariable
,
NotVariable
)
# noqa
def
unify_walk
(
o
,
n
,
U
):
"""
OrV(list1) == NV(list2) == OrV(list1
\
list2)
...
...
@@ -339,7 +336,7 @@ def unify_walk(o, n, U):
return
U
.
merge
(
v
,
o
,
n
)
@comm_guard
(
VariableInList
,
(
list
,
tuple
))
@comm_guard
(
VariableInList
,
(
list
,
tuple
))
# noqa
def
unify_walk
(
vil
,
l
,
U
):
"""
Unifies VIL's inner Variable to OrV(list).
...
...
@@ -350,7 +347,7 @@ def unify_walk(vil, l, U):
return
unify_walk
(
v
,
ov
,
U
)
@comm_guard
((
list
,
tuple
),
(
list
,
tuple
))
@comm_guard
((
list
,
tuple
),
(
list
,
tuple
))
# noqa
def
unify_walk
(
l1
,
l2
,
U
):
"""
Tries to unify each corresponding pair of elements from l1 and l2.
...
...
@@ -365,7 +362,7 @@ def unify_walk(l1, l2, U):
return
U
@comm_guard
(
dict
,
dict
)
@comm_guard
(
dict
,
dict
)
# noqa
def
unify_walk
(
d1
,
d2
,
U
):
"""
Tries to unify values of corresponding keys.
...
...
@@ -379,7 +376,7 @@ def unify_walk(d1, d2, U):
return
U
@comm_guard
(
ANY_TYPE
,
ANY_TYPE
)
@comm_guard
(
ANY_TYPE
,
ANY_TYPE
)
# noqa
def
unify_walk
(
a
,
b
,
U
):
"""
Checks for the existence of the __unify_walk__ method for one of
...
...
@@ -394,7 +391,7 @@ def unify_walk(a, b, U):
return
FALL_THROUGH
@comm_guard
(
Variable
,
ANY_TYPE
)
@comm_guard
(
Variable
,
ANY_TYPE
)
# noqa
def
unify_walk
(
v
,
o
,
U
):
"""
This simply checks if the Var has an unification in U and uses it
...
...
@@ -429,27 +426,27 @@ def unify_merge(a, b, U):
return
a
@comm_guard
(
Variable
,
ANY_TYPE
)
@comm_guard
(
Variable
,
ANY_TYPE
)
# noqa
def
unify_merge
(
v
,
o
,
U
):
return
v
@comm_guard
(
BoundVariable
,
ANY_TYPE
)
@comm_guard
(
BoundVariable
,
ANY_TYPE
)
# noqa
def
unify_merge
(
bv
,
o
,
U
):
return
bv
.
value
@comm_guard
(
VariableInList
,
(
list
,
tuple
))
@comm_guard
(
VariableInList
,
(
list
,
tuple
))
# noqa
def
unify_merge
(
vil
,
l
,
U
):
return
[
unify_merge
(
x
,
x
,
U
)
for
x
in
l
]
@comm_guard
((
list
,
tuple
),
(
list
,
tuple
))
@comm_guard
((
list
,
tuple
),
(
list
,
tuple
))
# noqa
def
unify_merge
(
l1
,
l2
,
U
):
return
[
unify_merge
(
x1
,
x2
,
U
)
for
x1
,
x2
in
zip
(
l1
,
l2
)]
@comm_guard
(
dict
,
dict
)
@comm_guard
(
dict
,
dict
)
# noqa
def
unify_merge
(
d1
,
d2
,
U
):
d
=
d1
.
__class__
()
for
k1
,
v1
in
iteritems
(
d1
):
...
...
@@ -463,12 +460,12 @@ def unify_merge(d1, d2, U):
return
d
@comm_guard
(
FVar
,
ANY_TYPE
)
@comm_guard
(
FVar
,
ANY_TYPE
)
# noqa
def
unify_merge
(
vs
,
o
,
U
):
return
vs
(
U
)
@comm_guard
(
ANY_TYPE
,
ANY_TYPE
)
@comm_guard
(
ANY_TYPE
,
ANY_TYPE
)
# noqa
def
unify_merge
(
a
,
b
,
U
):
if
(
not
isinstance
(
a
,
Variable
)
and
not
isinstance
(
b
,
Variable
)
and
...
...
@@ -478,7 +475,7 @@ def unify_merge(a, b, U):
return
FALL_THROUGH
@comm_guard
(
Variable
,
ANY_TYPE
)
@comm_guard
(
Variable
,
ANY_TYPE
)
# noqa
def
unify_merge
(
v
,
o
,
U
):
"""
This simply checks if the Var has an unification in U and uses it
...
...
theano/gof/vm.py
浏览文件 @
c0b294ec
...
...
@@ -27,6 +27,9 @@ logger = logging.getLogger(__name__)
def
calculate_reallocate_info
(
order
,
fgraph
,
storage_map
,
compute_map_re
,
dependencies
):
"""
WRITEME : explain the parameters
"""
reallocated_info
=
{}
viewed_by
=
{}
for
var
in
fgraph
.
variables
:
...
...
@@ -189,7 +192,9 @@ class VM(object):
raise
NotImplementedError
(
'override me'
)
def
update_profile
(
self
,
profile
):
# accumulate into the profile object
"""
Accumulate into the profile object
"""
for
node
,
thunk
,
t
,
c
in
zip
(
self
.
nodes
,
self
.
thunks
,
self
.
call_times
,
self
.
call_counts
):
profile
.
apply_time
.
setdefault
(
node
,
0.0
)
...
...
@@ -723,6 +728,9 @@ class VM_Linker(link.LocalLinker):
def
accept
(
self
,
fgraph
,
no_recycling
=
None
):
"""
Check if fgraph is the first FunctionGraph that has ever been
associated to self, else, create a new VM_Linker
associated to fgraph
Parameters
----------
...
...
theano/tests/test_flake8.py
浏览文件 @
c0b294ec
...
...
@@ -126,13 +126,10 @@ whitelist_flake8 = [
"sparse/sandbox/sp2.py"
,
"sparse/sandbox/truedot.py"
,
"sparse/sandbox/sp.py"
,
"gof/unify.py"
,
"gof/__init__.py"
,
"gof/sandbox/equilibrium.py"
,
"d3viz/__init__.py"
,
"d3viz/tests/__init__.py"
,
"gof/tests/__init__.py"
,
]
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论