Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
7f3bfb23
提交
7f3bfb23
authored
9月 12, 2008
作者:
James Bergstra
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
added WRITEME many places
上级
77c55988
隐藏空白字符变更
内嵌
并排
正在显示
8 个修改的文件
包含
162 行增加
和
67 行删除
+162
-67
_test_destroyhandler.py
gof/_test_destroyhandler.py
+0
-1
cc.py
gof/cc.py
+27
-19
destroyhandler.py
gof/destroyhandler.py
+7
-1
env.py
gof/env.py
+22
-14
graph.py
gof/graph.py
+62
-14
link.py
gof/link.py
+15
-6
op.py
gof/op.py
+1
-0
opt.py
gof/opt.py
+28
-12
没有找到文件。
gof/_test_destroyhandler.py
浏览文件 @
7f3bfb23
...
@@ -7,7 +7,6 @@ from graph import Result, Apply
...
@@ -7,7 +7,6 @@ from graph import Result, Apply
from
op
import
Op
from
op
import
Op
from
opt
import
*
from
opt
import
*
from
ext
import
*
import
destroyhandler
import
destroyhandler
from
env
import
Env
,
InconsistencyError
from
env
import
Env
,
InconsistencyError
from
toolbox
import
ReplaceValidate
from
toolbox
import
ReplaceValidate
...
...
gof/cc.py
浏览文件 @
7f3bfb23
...
@@ -68,7 +68,7 @@ def get_compiledir():
...
@@ -68,7 +68,7 @@ def get_compiledir():
class
CodeBlock
:
class
CodeBlock
:
"""
"""
WRITEME
Represents a computation unit composed of declare, behavior, and cleanup.
Represents a computation unit composed of declare, behavior, and cleanup.
@ivar declare: C code that declares variables for use by the computation
@ivar declare: C code that declares variables for use by the computation
@ivar behavior: C code that performs the computation
@ivar behavior: C code that performs the computation
...
@@ -94,11 +94,12 @@ class CodeBlock:
...
@@ -94,11 +94,12 @@ class CodeBlock:
def
failure_code
(
sub
):
def
failure_code
(
sub
):
"""WRITEME"""
return
"{
%(failure_var)
s =
%(id)
s; goto __label_
%(id)
i;}"
%
sub
return
"{
%(failure_var)
s =
%(id)
s; goto __label_
%(id)
i;}"
%
sub
def
code_gen
(
blocks
):
def
code_gen
(
blocks
):
"""
"""
WRITEME
From a list of L{CodeBlock} instances, returns a string that executes them
From a list of L{CodeBlock} instances, returns a string that executes them
all in sequence. eg for C{(decl1, task1, cleanup1)} and C{(decl2, task2, cleanup2)}
all in sequence. eg for C{(decl1, task1, cleanup1)} and C{(decl2, task2, cleanup2)}
the returned string will be of the form::
the returned string will be of the form::
...
@@ -126,7 +127,7 @@ def code_gen(blocks):
...
@@ -126,7 +127,7 @@ def code_gen(blocks):
def
struct_gen
(
args
,
struct_builders
,
blocks
,
sub
):
def
struct_gen
(
args
,
struct_builders
,
blocks
,
sub
):
"""
"""
WRITEME
Generates a struct conforming to the following specifications:
Generates a struct conforming to the following specifications:
* args -> all of the PyObject* type, stored in the struct
* args -> all of the PyObject* type, stored in the struct
they represent the storage and must be length 1 python lists.
they represent the storage and must be length 1 python lists.
...
@@ -253,16 +254,18 @@ def struct_gen(args, struct_builders, blocks, sub):
...
@@ -253,16 +254,18 @@ def struct_gen(args, struct_builders, blocks, sub):
# with handling of the py_<name> variable.
# with handling of the py_<name> variable.
def
get_nothing
(
r
,
name
,
sub
):
def
get_nothing
(
r
,
name
,
sub
):
""
""
"WRITEME"""
return
""
return
""
def
get_c_declare
(
r
,
name
,
sub
):
def
get_c_declare
(
r
,
name
,
sub
):
"""WRITEME"""
pre
=
"""
pre
=
"""
PyObject* py_
%(name)
s;
PyObject* py_
%(name)
s;
"""
%
locals
()
"""
%
locals
()
return
pre
+
r
.
type
.
c_declare
(
name
,
sub
)
return
pre
+
r
.
type
.
c_declare
(
name
,
sub
)
def
get_c_init
(
r
,
name
,
sub
):
def
get_c_init
(
r
,
name
,
sub
):
"""WRITEME"""
pre
=
""
"""
pre
=
""
"""
py_
%(name)
s = Py_None;
py_
%(name)
s = Py_None;
Py_XINCREF(py_
%(name)
s);
Py_XINCREF(py_
%(name)
s);
...
@@ -270,6 +273,7 @@ def get_c_init(r, name, sub):
...
@@ -270,6 +273,7 @@ def get_c_init(r, name, sub):
return
pre
+
r
.
type
.
c_init
(
name
,
sub
)
return
pre
+
r
.
type
.
c_init
(
name
,
sub
)
def
get_c_extract
(
r
,
name
,
sub
):
def
get_c_extract
(
r
,
name
,
sub
):
"""WRITEME"""
pre
=
"""
pre
=
"""
py_
%(name)
s = PyList_GET_ITEM(storage_
%(name)
s, 0);
py_
%(name)
s = PyList_GET_ITEM(storage_
%(name)
s, 0);
Py_XINCREF(py_
%(name)
s);
Py_XINCREF(py_
%(name)
s);
...
@@ -277,12 +281,14 @@ def get_c_extract(r, name, sub):
...
@@ -277,12 +281,14 @@ def get_c_extract(r, name, sub):
return
pre
+
r
.
type
.
c_extract
(
name
,
sub
)
return
pre
+
r
.
type
.
c_extract
(
name
,
sub
)
def
get_c_cleanup
(
r
,
name
,
sub
):
def
get_c_cleanup
(
r
,
name
,
sub
):
"""WRITEME"""
post
=
"""
post
=
"""
Py_XDECREF(py_
%(name)
s);
Py_XDECREF(py_
%(name)
s);
"""
%
locals
()
"""
%
locals
()
return
r
.
type
.
c_cleanup
(
name
,
sub
)
+
post
return
r
.
type
.
c_cleanup
(
name
,
sub
)
+
post
def
get_c_sync
(
r
,
name
,
sub
):
def
get_c_sync
(
r
,
name
,
sub
):
"""WRITEME"""
return
"""
return
"""
if (!
%(failure_var)
s) {
if (!
%(failure_var)
s) {
%(sync)
s
%(sync)
s
...
@@ -294,7 +300,7 @@ def get_c_sync(r, name, sub):
...
@@ -294,7 +300,7 @@ def get_c_sync(r, name, sub):
"""
%
dict
(
sync
=
r
.
type
.
c_sync
(
name
,
sub
),
name
=
name
,
**
sub
)
"""
%
dict
(
sync
=
r
.
type
.
c_sync
(
name
,
sub
),
name
=
name
,
**
sub
)
def
apply_policy
(
policy
,
r
,
name
,
sub
):
def
apply_policy
(
policy
,
r
,
name
,
sub
):
"""
"""
WRITEME
@param policy: list of functions that map a L{Result} to a string, or a single such function
@param policy: list of functions that map a L{Result} to a string, or a single such function
@type r: L{Result}
@type r: L{Result}
@return: C{policy[0](r) + policy[1](r) + ...}
@return: C{policy[0](r) + policy[1](r) + ...}
...
@@ -309,7 +315,7 @@ def apply_policy(policy, r, name, sub):
...
@@ -309,7 +315,7 @@ def apply_policy(policy, r, name, sub):
def
struct_result_codeblocks
(
result
,
policies
,
id
,
symbol_table
,
sub
):
def
struct_result_codeblocks
(
result
,
policies
,
id
,
symbol_table
,
sub
):
"""
"""
WRITEME
result -> a Result
result -> a Result
policies -> a pair of tuples ((declare_policy, behavior_policy, cleanup_policy), -- at construction
policies -> a pair of tuples ((declare_policy, behavior_policy, cleanup_policy), -- at construction
(declare_policy, behavior_policy, cleanup_policy)) -- at execution
(declare_policy, behavior_policy, cleanup_policy)) -- at execution
...
@@ -339,7 +345,7 @@ def struct_result_codeblocks(result, policies, id, symbol_table, sub):
...
@@ -339,7 +345,7 @@ def struct_result_codeblocks(result, policies, id, symbol_table, sub):
class
CLinker
(
link
.
Linker
):
class
CLinker
(
link
.
Linker
):
"""
"""
WRITEME
Creates C code for an env, compiles it and returns callables
Creates C code for an env, compiles it and returns callables
through make_thunk and make_function that make use of the compiled
through make_thunk and make_function that make use of the compiled
...
@@ -354,6 +360,7 @@ class CLinker(link.Linker):
...
@@ -354,6 +360,7 @@ class CLinker(link.Linker):
self
.
env
=
None
self
.
env
=
None
def
accept
(
self
,
env
,
no_recycling
=
[]):
def
accept
(
self
,
env
,
no_recycling
=
[]):
"""WRITEME"""
if
self
.
env
is
not
None
and
self
.
env
is
not
env
:
if
self
.
env
is
not
None
and
self
.
env
is
not
env
:
return
type
(
self
)()
.
accept
(
env
,
no_recycling
)
return
type
(
self
)()
.
accept
(
env
,
no_recycling
)
#raise Exception("Cannot accept from a Linker that is already tied to another Env.")
#raise Exception("Cannot accept from a Linker that is already tied to another Env.")
...
@@ -363,7 +370,7 @@ class CLinker(link.Linker):
...
@@ -363,7 +370,7 @@ class CLinker(link.Linker):
return
self
return
self
def
fetch_results
(
self
):
def
fetch_results
(
self
):
"""
"""
WRITEME
Fills the inputs, outputs, results, orphans, temps and node_order fields.
Fills the inputs, outputs, results, orphans, temps and node_order fields.
"""
"""
env
=
self
.
env
env
=
self
.
env
...
@@ -376,7 +383,7 @@ class CLinker(link.Linker):
...
@@ -376,7 +383,7 @@ class CLinker(link.Linker):
self
.
node_order
=
env
.
toposort
()
self
.
node_order
=
env
.
toposort
()
def
code_gen
(
self
):
def
code_gen
(
self
):
"""
"""
WRITEME
Generates code for a struct that does the computation of the env and
Generates code for a struct that does the computation of the env and
stores it in the struct_code field of the instance.
stores it in the struct_code field of the instance.
...
@@ -542,7 +549,7 @@ class CLinker(link.Linker):
...
@@ -542,7 +549,7 @@ class CLinker(link.Linker):
return
self
.
struct_code
return
self
.
struct_code
def
support_code
(
self
):
def
support_code
(
self
):
"""
"""
WRITEME
Returns a list of support code strings that are needed by
Returns a list of support code strings that are needed by
one or more Results or Ops. The support code from Results is
one or more Results or Ops. The support code from Results is
added before the support code from Ops.
added before the support code from Ops.
...
@@ -556,7 +563,7 @@ class CLinker(link.Linker):
...
@@ -556,7 +563,7 @@ class CLinker(link.Linker):
return
ret
return
ret
def
compile_args
(
self
):
def
compile_args
(
self
):
"""
"""
WRITEME
Returns a list of compile args that are needed by one
Returns a list of compile args that are needed by one
or more Results or Ops.
or more Results or Ops.
...
@@ -569,7 +576,7 @@ class CLinker(link.Linker):
...
@@ -569,7 +576,7 @@ class CLinker(link.Linker):
return
ret
return
ret
def
headers
(
self
):
def
headers
(
self
):
"""
"""
WRITEME
Returns a list of headers that are needed by one
Returns a list of headers that are needed by one
or more Results or Ops.
or more Results or Ops.
...
@@ -582,7 +589,7 @@ class CLinker(link.Linker):
...
@@ -582,7 +589,7 @@ class CLinker(link.Linker):
return
ret
return
ret
def
libraries
(
self
):
def
libraries
(
self
):
"""
"""
WRITEME
Returns a list of libraries that are needed by one
Returns a list of libraries that are needed by one
or more Results or Ops.
or more Results or Ops.
...
@@ -595,7 +602,7 @@ class CLinker(link.Linker):
...
@@ -595,7 +602,7 @@ class CLinker(link.Linker):
return
ret
return
ret
def
__compile__
(
self
,
input_storage
=
None
,
output_storage
=
None
):
def
__compile__
(
self
,
input_storage
=
None
,
output_storage
=
None
):
"""
"""
WRITEME
Compiles this linker's env.
Compiles this linker's env.
@type input_storage: list or None
@type input_storage: list or None
...
@@ -629,7 +636,7 @@ class CLinker(link.Linker):
...
@@ -629,7 +636,7 @@ class CLinker(link.Linker):
error_storage
error_storage
def
make_thunk
(
self
,
input_storage
=
None
,
output_storage
=
None
):
def
make_thunk
(
self
,
input_storage
=
None
,
output_storage
=
None
):
"""
"""
WRITEME
Compiles this linker's env and returns a function to perform the
Compiles this linker's env and returns a function to perform the
computations, as well as lists of storage cells for both the
computations, as well as lists of storage cells for both the
inputs and outputs.
inputs and outputs.
...
@@ -655,7 +662,7 @@ class CLinker(link.Linker):
...
@@ -655,7 +662,7 @@ class CLinker(link.Linker):
return
_execute
(
cthunk
,
self
.
init_tasks
,
self
.
tasks
,
error_storage
),
in_storage
,
out_storage
return
_execute
(
cthunk
,
self
.
init_tasks
,
self
.
tasks
,
error_storage
),
in_storage
,
out_storage
def
cthunk_factory
(
self
,
error_storage
,
in_storage
,
out_storage
):
def
cthunk_factory
(
self
,
error_storage
,
in_storage
,
out_storage
):
"""
"""
WRITEME
error_storage -> list of length 3
error_storage -> list of length 3
in_storage -> list of lists of length 1, one per input
in_storage -> list of lists of length 1, one per input
out_storage -> list of lists of length 1, one per output
out_storage -> list of lists of length 1, one per output
...
@@ -754,6 +761,7 @@ class CLinker(link.Linker):
...
@@ -754,6 +761,7 @@ class CLinker(link.Linker):
def
_execute
(
cthunk
,
init_tasks
,
tasks
,
error_storage
):
def
_execute
(
cthunk
,
init_tasks
,
tasks
,
error_storage
):
"""WRITEME"""
def
find_task
(
failure_code
):
def
find_task
(
failure_code
):
"""
"""
Maps a failure code to the task that is associated to it.
Maps a failure code to the task that is associated to it.
...
@@ -782,7 +790,7 @@ def _execute(cthunk, init_tasks, tasks, error_storage):
...
@@ -782,7 +790,7 @@ def _execute(cthunk, init_tasks, tasks, error_storage):
class
OpWiseCLinker
(
link
.
LocalLinker
):
class
OpWiseCLinker
(
link
.
LocalLinker
):
"""
"""
WRITEME
Uses CLinker on the individual Ops that comprise an env and loops
Uses CLinker on the individual Ops that comprise an env and loops
over them in Python. The result is slower than a compiled version of
over them in Python. The result is slower than a compiled version of
the whole env, but saves on compilation time because small changes
the whole env, but saves on compilation time because small changes
...
@@ -881,7 +889,7 @@ class OpWiseCLinker(link.LocalLinker):
...
@@ -881,7 +889,7 @@ class OpWiseCLinker(link.LocalLinker):
def
_default_checker
(
x
,
y
):
def
_default_checker
(
x
,
y
):
"""
"""
WRITEME
Default checker for DualLinker. This checks that the
Default checker for DualLinker. This checks that the
results contain the same data using ==.
results contain the same data using ==.
"""
"""
...
@@ -889,7 +897,7 @@ def _default_checker(x, y):
...
@@ -889,7 +897,7 @@ def _default_checker(x, y):
raise
Exception
(
"Output mismatch."
,
{
'performlinker'
:
x
[
0
],
'clinker'
:
y
[
0
]})
raise
Exception
(
"Output mismatch."
,
{
'performlinker'
:
x
[
0
],
'clinker'
:
y
[
0
]})
class
DualLinker
(
link
.
Linker
):
class
DualLinker
(
link
.
Linker
):
"""
"""
WRITEME
Runs the env in parallel using PerformLinker and CLinker.
Runs the env in parallel using PerformLinker and CLinker.
The thunk/function produced by DualLinker uses PerformLinker as the
The thunk/function produced by DualLinker uses PerformLinker as the
...
...
gof/destroyhandler.py
浏览文件 @
7f3bfb23
"""WRITEME"""
from
collections
import
defaultdict
from
collections
import
defaultdict
import
toolbox
import
toolbox
...
@@ -5,9 +6,12 @@ import graph
...
@@ -5,9 +6,12 @@ import graph
from
env
import
InconsistencyError
from
env
import
InconsistencyError
class
ProtocolError
(
Exception
):
pass
class
ProtocolError
(
Exception
):
"""WRITEME"""
pass
class
DestroyHandler
(
toolbox
.
Bookkeeper
):
class
DestroyHandler
(
toolbox
.
Bookkeeper
):
"""WRITEME"""
def
__init__
(
self
):
def
__init__
(
self
):
self
.
map
=
{}
self
.
map
=
{}
...
@@ -36,6 +40,8 @@ class DestroyHandler(toolbox.Bookkeeper):
...
@@ -36,6 +40,8 @@ class DestroyHandler(toolbox.Bookkeeper):
class
DestroyHandlerHelper2
(
toolbox
.
Bookkeeper
):
class
DestroyHandlerHelper2
(
toolbox
.
Bookkeeper
):
"""WRITEME"""
def
__init__
(
self
):
def
__init__
(
self
):
self
.
env
=
None
self
.
env
=
None
...
...
gof/env.py
浏览文件 @
7f3bfb23
"""WRITEME"""
from
copy
import
copy
from
copy
import
copy
import
graph
import
graph
import
utils
import
utils
...
@@ -15,7 +17,7 @@ class InconsistencyError(Exception):
...
@@ -15,7 +17,7 @@ class InconsistencyError(Exception):
class
Env
(
utils
.
object2
):
class
Env
(
utils
.
object2
):
"""
"""
WRITEME
An Env represents a subgraph bound by a set of input results and a
An Env represents a subgraph bound by a set of input results and a
set of output results. The inputs list should contain all the inputs
set of output results. The inputs list should contain all the inputs
on which the outputs depend. Results of type Value or Constant are
on which the outputs depend. Results of type Value or Constant are
...
@@ -35,6 +37,8 @@ class Env(utils.object2):
...
@@ -35,6 +37,8 @@ class Env(utils.object2):
"""
"""
Create an Env which operates on the subgraph bound by the inputs and outputs
Create an Env which operates on the subgraph bound by the inputs and outputs
sets.
sets.
WRITEME
"""
"""
self
.
_features
=
[]
self
.
_features
=
[]
...
@@ -79,7 +83,7 @@ class Env(utils.object2):
...
@@ -79,7 +83,7 @@ class Env(utils.object2):
node
.
deps
=
{}
node
.
deps
=
{}
def
disown
(
self
):
def
disown
(
self
):
"""
"""
WRITEME
Cleans up all of this Env's nodes and results so they are not
Cleans up all of this Env's nodes and results so they are not
associated with this Env anymore.
associated with this Env anymore.
...
@@ -104,11 +108,12 @@ class Env(utils.object2):
...
@@ -104,11 +108,12 @@ class Env(utils.object2):
### clients ###
### clients ###
def
clients
(
self
,
r
):
def
clients
(
self
,
r
):
"Set of all the (node, i) pairs such that node.inputs[i] is r."
"""WRITEME
Set of all the (node, i) pairs such that node.inputs[i] is r."""
return
r
.
clients
return
r
.
clients
def
__add_clients__
(
self
,
r
,
new_clients
):
def
__add_clients__
(
self
,
r
,
new_clients
):
"""
"""
WRITEME
r -> result
r -> result
new_clients -> list of (node, i) pairs such that node.inputs[i] is r.
new_clients -> list of (node, i) pairs such that node.inputs[i] is r.
...
@@ -117,7 +122,7 @@ class Env(utils.object2):
...
@@ -117,7 +122,7 @@ class Env(utils.object2):
r
.
clients
+=
new_clients
r
.
clients
+=
new_clients
def
__remove_clients__
(
self
,
r
,
clients_to_remove
,
prune
=
True
):
def
__remove_clients__
(
self
,
r
,
clients_to_remove
,
prune
=
True
):
"""
"""
WRITEME
r -> result
r -> result
clients_to_remove -> list of (op, i) pairs such that node.inputs[i] is not r anymore.
clients_to_remove -> list of (op, i) pairs such that node.inputs[i] is not r anymore.
...
@@ -213,7 +218,7 @@ class Env(utils.object2):
...
@@ -213,7 +218,7 @@ class Env(utils.object2):
### change input ###
### change input ###
def
change_input
(
self
,
node
,
i
,
new_r
):
def
change_input
(
self
,
node
,
i
,
new_r
):
"""
"""
WRITEME
Changes node.inputs[i] to new_r.
Changes node.inputs[i] to new_r.
new_r.type == old_r.type must be True, where old_r is the
new_r.type == old_r.type must be True, where old_r is the
...
@@ -246,7 +251,7 @@ class Env(utils.object2):
...
@@ -246,7 +251,7 @@ class Env(utils.object2):
### replace ###
### replace ###
def
replace
(
self
,
r
,
new_r
):
def
replace
(
self
,
r
,
new_r
):
"""
"""
WRITEME
This is the main interface to manipulate the subgraph in Env.
This is the main interface to manipulate the subgraph in Env.
For every node that uses r as input, makes it use new_r instead.
For every node that uses r as input, makes it use new_r instead.
"""
"""
...
@@ -264,6 +269,7 @@ class Env(utils.object2):
...
@@ -264,6 +269,7 @@ class Env(utils.object2):
self
.
change_input
(
node
,
i
,
new_r
)
self
.
change_input
(
node
,
i
,
new_r
)
def
replace_all
(
self
,
pairs
):
def
replace_all
(
self
,
pairs
):
"""WRITEME"""
for
r
,
new_r
in
pairs
:
for
r
,
new_r
in
pairs
:
self
.
replace
(
r
,
new_r
)
self
.
replace
(
r
,
new_r
)
...
@@ -271,7 +277,7 @@ class Env(utils.object2):
...
@@ -271,7 +277,7 @@ class Env(utils.object2):
### features ###
### features ###
def
extend
(
self
,
feature
):
def
extend
(
self
,
feature
):
"""
"""
WRITEME
Adds a feature to this env. The feature may define one
Adds a feature to this env. The feature may define one
or more of the following methods:
or more of the following methods:
...
@@ -310,7 +316,7 @@ class Env(utils.object2):
...
@@ -310,7 +316,7 @@ class Env(utils.object2):
self
.
_features
.
append
(
feature
)
self
.
_features
.
append
(
feature
)
def
remove_feature
(
self
,
feature
):
def
remove_feature
(
self
,
feature
):
"""
"""
WRITEME
Removes the feature from the graph.
Removes the feature from the graph.
Calls feature.on_detach(env) if an on_detach method is defined.
Calls feature.on_detach(env) if an on_detach method is defined.
...
@@ -327,7 +333,7 @@ class Env(utils.object2):
...
@@ -327,7 +333,7 @@ class Env(utils.object2):
### callback utils ###
### callback utils ###
def
execute_callbacks
(
self
,
name
,
*
args
):
def
execute_callbacks
(
self
,
name
,
*
args
):
"""
"""
WRITEME
Calls
Calls
getattr(feature, name)(*args)
getattr(feature, name)(*args)
for each feature which has a method called after name.
for each feature which has a method called after name.
...
@@ -340,7 +346,7 @@ class Env(utils.object2):
...
@@ -340,7 +346,7 @@ class Env(utils.object2):
fn
(
self
,
*
args
)
fn
(
self
,
*
args
)
def
collect_callbacks
(
self
,
name
,
*
args
):
def
collect_callbacks
(
self
,
name
,
*
args
):
"""
"""
WRITEME
Returns a dictionary d such that:
Returns a dictionary d such that:
d[feature] == getattr(feature, name)(*args)
d[feature] == getattr(feature, name)(*args)
For each feature which has a method called after name.
For each feature which has a method called after name.
...
@@ -358,7 +364,7 @@ class Env(utils.object2):
...
@@ -358,7 +364,7 @@ class Env(utils.object2):
### misc ###
### misc ###
def
toposort
(
self
):
def
toposort
(
self
):
"""
"""
WRITEME
Returns an ordering of the graph's Apply nodes such that:
Returns an ordering of the graph's Apply nodes such that:
- All the nodes of the inputs of a node are before that node.
- All the nodes of the inputs of a node are before that node.
- Satisfies the orderings provided by each feature that has
- Satisfies the orderings provided by each feature that has
...
@@ -379,7 +385,7 @@ class Env(utils.object2):
...
@@ -379,7 +385,7 @@ class Env(utils.object2):
return
order
return
order
def
nclients
(
self
,
r
):
def
nclients
(
self
,
r
):
"
Same as len(self.clients(r)).
"
"
""WRITEME Same as len(self.clients(r)).""
"
return
len
(
self
.
clients
(
r
))
return
len
(
self
.
clients
(
r
))
# def edge(self, r):
# def edge(self, r):
...
@@ -395,7 +401,7 @@ class Env(utils.object2):
...
@@ -395,7 +401,7 @@ class Env(utils.object2):
# return node.inputs
# return node.inputs
def
check_integrity
(
self
):
def
check_integrity
(
self
):
"""
"""
WRITEME
Call this for a diagnosis if things go awry.
Call this for a diagnosis if things go awry.
"""
"""
nodes
=
graph
.
ops
(
self
.
inputs
,
self
.
outputs
)
nodes
=
graph
.
ops
(
self
.
inputs
,
self
.
outputs
)
...
@@ -438,9 +444,11 @@ class Env(utils.object2):
...
@@ -438,9 +444,11 @@ class Env(utils.object2):
### clone ###
### clone ###
def
clone
(
self
):
def
clone
(
self
):
"""WRITEME"""
return
self
.
clone_get_equiv
()[
0
]
return
self
.
clone_get_equiv
()[
0
]
def
clone_get_equiv
(
self
):
def
clone_get_equiv
(
self
):
"""WRITEME"""
equiv
=
graph
.
clone_get_equiv
(
self
.
inputs
,
self
.
outputs
)
equiv
=
graph
.
clone_get_equiv
(
self
.
inputs
,
self
.
outputs
)
self
.
check_integrity
()
self
.
check_integrity
()
e
=
Env
([
equiv
[
i
]
for
i
in
self
.
inputs
],
e
=
Env
([
equiv
[
i
]
for
i
in
self
.
inputs
],
...
...
gof/graph.py
浏览文件 @
7f3bfb23
"""Node classes (Apply, Result) and expression graph algorithms."""
from
copy
import
copy
from
copy
import
copy
from
collections
import
deque
from
collections
import
deque
...
@@ -137,11 +137,32 @@ class Apply(utils.object2):
...
@@ -137,11 +137,32 @@ class Apply(utils.object2):
class
Result
(
utils
.
object2
):
class
Result
(
utils
.
object2
):
"""
"""
Represents the result of some computation (pointed to by its owner field),
A variable in a theano expression graph.
or an input to the graph (if owner is None)
A Result which is the output of a symbolic computation has a reference to the Apply
instance to which it belongs (property: owner) and the position of itself in the owner's
output list (property: index).
A Result which is not the output of a symbolic computation will have an owner == None.
"""
"""
#__slots__ = ['type', 'owner', 'index', 'name']
#__slots__ = ['type', 'owner', 'index', 'name']
def
__init__
(
self
,
type
,
owner
=
None
,
index
=
None
,
name
=
None
):
def
__init__
(
self
,
type
,
owner
=
None
,
index
=
None
,
name
=
None
):
"""Initialize type, owner, index, name.
@type type: a Type instance
@param type: the type governs the kind of data that can be associated with this
variable
@type owner: None or Apply instance
@param owner: the Apply instance which computes the value for this variable
@type index: None or int
@param index: the position of this Result in owner.outputs
@type name: None or str
@param name: a string for pretty-printing and debugging
"""
self
.
tag
=
utils
.
scratchpad
()
self
.
tag
=
utils
.
scratchpad
()
self
.
type
=
type
self
.
type
=
type
if
owner
is
not
None
and
not
isinstance
(
owner
,
Apply
):
if
owner
is
not
None
and
not
isinstance
(
owner
,
Apply
):
...
@@ -167,6 +188,14 @@ class Result(utils.object2):
...
@@ -167,6 +188,14 @@ class Result(utils.object2):
def
__repr__
(
self
):
def
__repr__
(
self
):
return
str
(
self
)
return
str
(
self
)
def
clone
(
self
):
def
clone
(
self
):
"""Return a new Result like self.
@rtype: Result instance
@return: a new Result instance (or subclass instance) with no owner or index.
@note: tags are copied to the returned instance.
@note: name is copied to the returned instance.
"""
#return copy(self)
#return copy(self)
cp
=
self
.
__class__
(
self
.
type
,
None
,
None
,
self
.
name
)
cp
=
self
.
__class__
(
self
.
type
,
None
,
None
,
self
.
name
)
cp
.
tag
=
copy
(
self
.
tag
)
cp
.
tag
=
copy
(
self
.
tag
)
...
@@ -174,13 +203,18 @@ class Result(utils.object2):
...
@@ -174,13 +203,18 @@ class Result(utils.object2):
class
Value
(
Result
):
class
Value
(
Result
):
"""
"""
Result with a data field. The data field is filtered by what is
Result with a default 'data' field.
The data field is filtered by what is
provided in the constructor for the Value's type field.
provided in the constructor for the Value's type field.
Its owner field is always None.
Its owner field is always None.
"""
"""
#__slots__ = ['data']
#__slots__ = ['data']
def
__init__
(
self
,
type
,
data
,
name
=
None
):
def
__init__
(
self
,
type
,
data
,
name
=
None
):
"""Initialize self.
WRITEME
"""
Result
.
__init__
(
self
,
type
,
None
,
None
,
name
)
Result
.
__init__
(
self
,
type
,
None
,
None
,
name
)
self
.
data
=
type
.
filter
(
data
)
self
.
data
=
type
.
filter
(
data
)
def
__str__
(
self
):
def
__str__
(
self
):
...
@@ -188,6 +222,7 @@ class Value(Result):
...
@@ -188,6 +222,7 @@ class Value(Result):
return
self
.
name
return
self
.
name
return
"<"
+
str
(
self
.
data
)
+
">"
#+ "::" + str(self.type)
return
"<"
+
str
(
self
.
data
)
+
">"
#+ "::" + str(self.type)
def
clone
(
self
):
def
clone
(
self
):
"""WRITEME"""
return
self
.
__class__
(
self
.
type
,
copy
(
self
.
data
),
self
.
name
)
return
self
.
__class__
(
self
.
type
,
copy
(
self
.
data
),
self
.
name
)
def
__set_owner
(
self
,
value
):
def
__set_owner
(
self
,
value
):
if
value
is
not
None
:
if
value
is
not
None
:
...
@@ -218,7 +253,7 @@ def stack_search(start, expand, mode='bfs', build_inv = False):
...
@@ -218,7 +253,7 @@ def stack_search(start, expand, mode='bfs', build_inv = False):
"""Search through L{Result}s, either breadth- or depth-first
"""Search through L{Result}s, either breadth- or depth-first
@type start: deque
@type start: deque
@param start: search from these nodes
@param start: search from these nodes
@type explore:
function
@type explore:
callable
@param explore: when we get to a node, add explore(node) to the list of
@param explore: when we get to a node, add explore(node) to the list of
nodes to visit. This function should return a list, or None
nodes to visit. This function should return a list, or None
@rtype: list of L{Result}
@rtype: list of L{Result}
...
@@ -256,7 +291,8 @@ def stack_search(start, expand, mode='bfs', build_inv = False):
...
@@ -256,7 +291,8 @@ def stack_search(start, expand, mode='bfs', build_inv = False):
def
inputs
(
result_list
):
def
inputs
(
result_list
):
"""
"""Return the inputs required to compute the given Results.
@type result_list: list of L{Result}
@type result_list: list of L{Result}
@param result_list: output L{Result}s (from which to search backward through owners)
@param result_list: output L{Result}s (from which to search backward through owners)
@returns: the list of L{Result}s with no owner, in the order found by a
@returns: the list of L{Result}s with no owner, in the order found by a
...
@@ -275,7 +311,7 @@ def inputs(result_list):
...
@@ -275,7 +311,7 @@ def inputs(result_list):
def
results_and_orphans
(
i
,
o
):
def
results_and_orphans
(
i
,
o
):
"""
"""
WRITEME
"""
"""
def
expand
(
r
):
def
expand
(
r
):
if
r
.
owner
and
r
not
in
i
:
if
r
.
owner
and
r
not
in
i
:
...
@@ -288,7 +324,8 @@ def results_and_orphans(i, o):
...
@@ -288,7 +324,8 @@ def results_and_orphans(i, o):
def
ops
(
i
,
o
):
def
ops
(
i
,
o
):
"""
""" WRITEME
@type i: list
@type i: list
@param i: input L{Result}s
@param i: input L{Result}s
@type o: list
@type o: list
...
@@ -309,7 +346,8 @@ def ops(i, o):
...
@@ -309,7 +346,8 @@ def ops(i, o):
def
results
(
i
,
o
):
def
results
(
i
,
o
):
"""
""" WRITEME
@type i: list
@type i: list
@param i: input L{Result}s
@param i: input L{Result}s
@type o: list
@type o: list
...
@@ -323,7 +361,8 @@ def results(i, o):
...
@@ -323,7 +361,8 @@ def results(i, o):
def
orphans
(
i
,
o
):
def
orphans
(
i
,
o
):
"""
""" WRITEME
@type i: list
@type i: list
@param i: input L{Result}s
@param i: input L{Result}s
@type o: list
@type o: list
...
@@ -339,7 +378,8 @@ def orphans(i, o):
...
@@ -339,7 +378,8 @@ def orphans(i, o):
def
clone
(
i
,
o
,
copy_inputs
=
True
):
def
clone
(
i
,
o
,
copy_inputs
=
True
):
"""
""" WRITEME
@type i: list
@type i: list
@param i: input L{Result}s
@param i: input L{Result}s
@type o: list
@type o: list
...
@@ -355,7 +395,8 @@ def clone(i, o, copy_inputs = True):
...
@@ -355,7 +395,8 @@ def clone(i, o, copy_inputs = True):
def
clone_get_equiv
(
i
,
o
,
copy_inputs_and_orphans
=
True
):
def
clone_get_equiv
(
i
,
o
,
copy_inputs_and_orphans
=
True
):
"""
""" WRITEME
@type i: list
@type i: list
@param i: input L{Result}s
@param i: input L{Result}s
@type o: list
@type o: list
...
@@ -400,7 +441,8 @@ def clone_get_equiv(i, o, copy_inputs_and_orphans = True):
...
@@ -400,7 +441,8 @@ def clone_get_equiv(i, o, copy_inputs_and_orphans = True):
return
d
return
d
def
general_toposort
(
r_out
,
deps
,
debug_print
=
False
):
def
general_toposort
(
r_out
,
deps
,
debug_print
=
False
):
"""
""" WRITEME
@note: deps(i) should behave like a pure function (no funny business with
@note: deps(i) should behave like a pure function (no funny business with
internal state)
internal state)
...
@@ -446,6 +488,8 @@ def general_toposort(r_out, deps, debug_print = False):
...
@@ -446,6 +488,8 @@ def general_toposort(r_out, deps, debug_print = False):
def
io_toposort
(
i
,
o
,
orderings
=
{}):
def
io_toposort
(
i
,
o
,
orderings
=
{}):
"""WRITEME
"""
iset
=
set
(
i
)
iset
=
set
(
i
)
def
deps
(
obj
):
def
deps
(
obj
):
rval
=
[]
rval
=
[]
...
@@ -470,6 +514,7 @@ default_node_formatter = lambda op, argstrings: "%s(%s)" % (op.op,
...
@@ -470,6 +514,7 @@ default_node_formatter = lambda op, argstrings: "%s(%s)" % (op.op,
def
op_as_string
(
i
,
op
,
def
op_as_string
(
i
,
op
,
leaf_formatter
=
default_leaf_formatter
,
leaf_formatter
=
default_leaf_formatter
,
node_formatter
=
default_node_formatter
):
node_formatter
=
default_node_formatter
):
"""WRITEME"""
strs
=
as_string
(
i
,
op
.
inputs
,
leaf_formatter
,
node_formatter
)
strs
=
as_string
(
i
,
op
.
inputs
,
leaf_formatter
,
node_formatter
)
return
node_formatter
(
op
,
strs
)
return
node_formatter
(
op
,
strs
)
...
@@ -477,7 +522,8 @@ def op_as_string(i, op,
...
@@ -477,7 +522,8 @@ def op_as_string(i, op,
def
as_string
(
i
,
o
,
def
as_string
(
i
,
o
,
leaf_formatter
=
default_leaf_formatter
,
leaf_formatter
=
default_leaf_formatter
,
node_formatter
=
default_node_formatter
):
node_formatter
=
default_node_formatter
):
"""
"""WRITEME
@type i: list
@type i: list
@param i: input L{Result}s
@param i: input L{Result}s
@type o: list
@type o: list
...
@@ -549,6 +595,8 @@ def view_roots(r):
...
@@ -549,6 +595,8 @@ def view_roots(r):
"""
"""
Utility function that returns the leaves of a search through
Utility function that returns the leaves of a search through
consecutive view_map()s.
consecutive view_map()s.
WRITEME
"""
"""
owner
=
r
.
owner
owner
=
r
.
owner
if
owner
is
not
None
:
if
owner
is
not
None
:
...
...
gof/link.py
浏览文件 @
7f3bfb23
"""WRITEME"""
import
utils
import
utils
import
graph
import
graph
...
@@ -8,7 +8,7 @@ from copy import copy
...
@@ -8,7 +8,7 @@ from copy import copy
__excepthook
=
sys
.
excepthook
__excepthook
=
sys
.
excepthook
def
thunk_hook
(
type
,
value
,
trace
):
def
thunk_hook
(
type
,
value
,
trace
):
"""
"""
WRITEME
This function is meant to replace excepthook and do some
This function is meant to replace excepthook and do some
special work if the exception value has a __thunk_trace__
special work if the exception value has a __thunk_trace__
field. In that case, it retrieves the field, which should
field. In that case, it retrieves the field, which should
...
@@ -32,6 +32,7 @@ sys.excepthook = thunk_hook
...
@@ -32,6 +32,7 @@ sys.excepthook = thunk_hook
def
raise_with_op
(
op
,
exc_info
=
None
):
def
raise_with_op
(
op
,
exc_info
=
None
):
"""WRITEME"""
if
exc_info
is
None
:
if
exc_info
is
None
:
exc_info
=
sys
.
exc_info
()
exc_info
=
sys
.
exc_info
()
exc_type
,
exc_value
,
exc_trace
=
exc_info
exc_type
,
exc_value
,
exc_trace
=
exc_info
...
@@ -45,6 +46,7 @@ def raise_with_op(op, exc_info = None):
...
@@ -45,6 +46,7 @@ def raise_with_op(op, exc_info = None):
class
Linker
(
object
):
class
Linker
(
object
):
"""WRITEME"""
def
make_thunk
(
self
):
def
make_thunk
(
self
):
"""
"""
...
@@ -108,6 +110,7 @@ class Linker(object):
...
@@ -108,6 +110,7 @@ class Linker(object):
class
Filter
(
object
):
class
Filter
(
object
):
"""WRITEME"""
def
__init__
(
self
,
r
,
storage
,
readonly
=
False
,
strict
=
False
,
trace
=
()):
def
__init__
(
self
,
r
,
storage
,
readonly
=
False
,
strict
=
False
,
trace
=
()):
self
.
r
=
r
self
.
r
=
r
self
.
type
=
r
.
type
self
.
type
=
r
.
type
...
@@ -134,6 +137,7 @@ class Filter(object):
...
@@ -134,6 +137,7 @@ class Filter(object):
def
map_storage
(
env
,
order
,
input_storage
,
output_storage
):
def
map_storage
(
env
,
order
,
input_storage
,
output_storage
):
"""WRITEME"""
if
input_storage
is
None
:
if
input_storage
is
None
:
input_storage
=
[[
None
]
for
input
in
env
.
inputs
]
input_storage
=
[[
None
]
for
input
in
env
.
inputs
]
else
:
else
:
...
@@ -165,6 +169,7 @@ def map_storage(env, order, input_storage, output_storage):
...
@@ -165,6 +169,7 @@ def map_storage(env, order, input_storage, output_storage):
def
streamline
(
env
,
thunks
,
order
,
no_recycling
=
[],
profiler
=
None
):
def
streamline
(
env
,
thunks
,
order
,
no_recycling
=
[],
profiler
=
None
):
"""WRITEME"""
def
clear
():
def
clear
():
for
thunk
in
thunks
:
for
thunk
in
thunks
:
for
output
in
thunk
.
outputs
:
for
output
in
thunk
.
outputs
:
...
@@ -191,7 +196,7 @@ def streamline(env, thunks, order, no_recycling = [], profiler = None):
...
@@ -191,7 +196,7 @@ def streamline(env, thunks, order, no_recycling = [], profiler = None):
return
f
return
f
class
LocalLinker
(
Linker
):
class
LocalLinker
(
Linker
):
"""
"""
WRITEME
Useful base class for L{Linker}s which keep all nodes in the graph, and run a
Useful base class for L{Linker}s which keep all nodes in the graph, and run a
thunk associated with each node.
thunk associated with each node.
"""
"""
...
@@ -214,7 +219,7 @@ class LocalLinker(Linker):
...
@@ -214,7 +219,7 @@ class LocalLinker(Linker):
class
PerformLinker
(
LocalLinker
):
class
PerformLinker
(
LocalLinker
):
"""
"""
WRITEME
Basic L{Linker} subclass that calls the perform method on each L{Op} in
Basic L{Linker} subclass that calls the perform method on each L{Op} in
the L{Env} in the order given by L{Env.toposort}.
the L{Env} in the order given by L{Env.toposort}.
"""
"""
...
@@ -262,7 +267,7 @@ class PerformLinker(LocalLinker):
...
@@ -262,7 +267,7 @@ class PerformLinker(LocalLinker):
class
WrapLinker
(
Linker
):
class
WrapLinker
(
Linker
):
"""
"""
WRITEME
This class makes it easier to run several L{LocalLinker}s in parallel, and
This class makes it easier to run several L{LocalLinker}s in parallel, and
offers some control over how each thunk is run.
offers some control over how each thunk is run.
...
@@ -373,6 +378,7 @@ class WrapLinker(Linker):
...
@@ -373,6 +378,7 @@ class WrapLinker(Linker):
import
time
import
time
class
Stats
:
class
Stats
:
"""WRITEME"""
def
__init__
(
self
):
def
__init__
(
self
):
self
.
ncalls
=
0
self
.
ncalls
=
0
self
.
time
=
0
self
.
time
=
0
...
@@ -384,7 +390,7 @@ class Stats:
...
@@ -384,7 +390,7 @@ class Stats:
def
inc_time_failures
(
self
,
v
):
self
.
time_failures
+=
v
def
inc_time_failures
(
self
,
v
):
self
.
time_failures
+=
v
class
Profiler
:
class
Profiler
:
"""
"""
WRITEME
Collects performance statistics on a function on a per-L{Op}
Collects performance statistics on a function on a per-L{Op}
or per-L{Op}-class basis.
or per-L{Op}-class basis.
"""
"""
...
@@ -404,6 +410,7 @@ class Profiler:
...
@@ -404,6 +410,7 @@ class Profiler:
self
.
by_class
=
by_class
self
.
by_class
=
by_class
def
profile_env
(
self
,
f
,
env
):
def
profile_env
(
self
,
f
,
env
):
"""WRITEME"""
stats
=
self
.
stats
.
setdefault
(
'TOTAL'
,
Stats
())
stats
=
self
.
stats
.
setdefault
(
'TOTAL'
,
Stats
())
n
,
t
=
stats
.
inc_ncalls
,
stats
.
inc_time
n
,
t
=
stats
.
inc_ncalls
,
stats
.
inc_time
failed
=
False
failed
=
False
...
@@ -423,6 +430,7 @@ class Profiler:
...
@@ -423,6 +430,7 @@ class Profiler:
raise
ety
,
eva
,
etr
raise
ety
,
eva
,
etr
def
profile_op
(
self
,
f
,
op
):
def
profile_op
(
self
,
f
,
op
):
"""WRITEME"""
if
self
.
by_class
:
if
self
.
by_class
:
entry
=
op
.
__class__
entry
=
op
.
__class__
else
:
else
:
...
@@ -449,6 +457,7 @@ class Profiler:
...
@@ -449,6 +457,7 @@ class Profiler:
def
print_stats
(
self
,
sort_by
=
'time'
):
def
print_stats
(
self
,
sort_by
=
'time'
):
"""WRITEME"""
def
compare_fn
((
op1
,
stat1
),
(
op2
,
stat2
)):
def
compare_fn
((
op1
,
stat1
),
(
op2
,
stat2
)):
x1
=
getattr
(
stat2
,
sort_by
)
x1
=
getattr
(
stat2
,
sort_by
)
...
...
gof/op.py
浏览文件 @
7f3bfb23
...
@@ -11,6 +11,7 @@ class Op(utils.object2):
...
@@ -11,6 +11,7 @@ class Op(utils.object2):
default_output
=
None
default_output
=
None
"""@todo
"""@todo
WRITEME
"""
"""
#############
#############
...
...
gof/opt.py
浏览文件 @
7f3bfb23
...
@@ -15,14 +15,14 @@ from collections import deque
...
@@ -15,14 +15,14 @@ from collections import deque
class
Optimizer
:
class
Optimizer
:
"""
"""
WRITEME
An L{Optimizer} can be applied to an L{Env} to transform it.
An L{Optimizer} can be applied to an L{Env} to transform it.
It can represent an optimization or in general any kind
It can represent an optimization or in general any kind
of transformation you could apply to an L{Env}.
of transformation you could apply to an L{Env}.
"""
"""
def
apply
(
self
,
env
):
def
apply
(
self
,
env
):
"""
"""
WRITEME
Applies the optimization to the provided L{Env}. It may use all
Applies the optimization to the provided L{Env}. It may use all
the methods defined by the L{Env}. If the L{Optimizer} needs
the methods defined by the L{Env}. If the L{Optimizer} needs
to use a certain tool, such as an L{InstanceFinder}, it can do
to use a certain tool, such as an L{InstanceFinder}, it can do
...
@@ -31,7 +31,7 @@ class Optimizer:
...
@@ -31,7 +31,7 @@ class Optimizer:
pass
pass
def
optimize
(
self
,
env
,
*
args
,
**
kwargs
):
def
optimize
(
self
,
env
,
*
args
,
**
kwargs
):
"""
"""
WRITEME
This is meant as a shortcut to::
This is meant as a shortcut to::
opt.add_requirements(env)
opt.add_requirements(env)
opt.apply(env)
opt.apply(env)
...
@@ -40,13 +40,13 @@ class Optimizer:
...
@@ -40,13 +40,13 @@ class Optimizer:
self
.
apply
(
env
,
*
args
,
**
kwargs
)
self
.
apply
(
env
,
*
args
,
**
kwargs
)
def
__call__
(
self
,
env
):
def
__call__
(
self
,
env
):
"""
"""
WRITEME
Same as self.optimize(env)
Same as self.optimize(env)
"""
"""
return
self
.
optimize
(
env
)
return
self
.
optimize
(
env
)
def
add_requirements
(
self
,
env
):
def
add_requirements
(
self
,
env
):
"""
"""
WRITEME
Add features to the env that are required to apply the optimization.
Add features to the env that are required to apply the optimization.
For example:
For example:
env.extend(History())
env.extend(History())
...
@@ -57,29 +57,33 @@ class Optimizer:
...
@@ -57,29 +57,33 @@ class Optimizer:
class
FromFunctionOptimizer
(
Optimizer
):
class
FromFunctionOptimizer
(
Optimizer
):
"""WRITEME"""
def
__init__
(
self
,
fn
):
def
__init__
(
self
,
fn
):
self
.
apply
=
fn
self
.
apply
=
fn
def
add_requirements
(
self
,
env
):
def
add_requirements
(
self
,
env
):
"""WRITEME"""
env
.
extend
(
gof
.
toolbox
.
ReplaceValidate
)
env
.
extend
(
gof
.
toolbox
.
ReplaceValidate
)
def
optimizer
(
f
):
def
optimizer
(
f
):
"""WRITEME"""
return
FromFunctionOptimizer
(
f
)
return
FromFunctionOptimizer
(
f
)
class
SeqOptimizer
(
Optimizer
,
list
):
class
SeqOptimizer
(
Optimizer
,
list
):
"""
"""
WRITEME
Takes a list of L{Optimizer} instances and applies them
Takes a list of L{Optimizer} instances and applies them
sequentially.
sequentially.
"""
"""
def
__init__
(
self
,
*
opts
):
def
__init__
(
self
,
*
opts
):
"""WRITEME"""
if
len
(
opts
)
==
1
and
isinstance
(
opts
[
0
],
(
list
,
tuple
)):
if
len
(
opts
)
==
1
and
isinstance
(
opts
[
0
],
(
list
,
tuple
)):
opts
=
opts
[
0
]
opts
=
opts
[
0
]
self
[:]
=
opts
self
[:]
=
opts
def
apply
(
self
,
env
):
def
apply
(
self
,
env
):
"""
"""
WRITEME
Applies each L{Optimizer} in self in turn.
Applies each L{Optimizer} in self in turn.
"""
"""
for
optimizer
in
self
:
for
optimizer
in
self
:
...
@@ -94,6 +98,7 @@ class SeqOptimizer(Optimizer, list):
...
@@ -94,6 +98,7 @@ class SeqOptimizer(Optimizer, list):
class
_metadict
:
class
_metadict
:
"""WRITEME"""
# dict that accepts unhashable keys
# dict that accepts unhashable keys
# uses an associative list
# uses an associative list
# for internal use only
# for internal use only
...
@@ -130,7 +135,7 @@ class _metadict:
...
@@ -130,7 +135,7 @@ class _metadict:
class
MergeOptimizer
(
Optimizer
):
class
MergeOptimizer
(
Optimizer
):
"""
"""
WRITEME
Merges parts of the graph that are identical, i.e. parts that
Merges parts of the graph that are identical, i.e. parts that
take the same inputs and carry out the asme computations so we
take the same inputs and carry out the asme computations so we
can avoid doing them more than once. Also merges results that
can avoid doing them more than once. Also merges results that
...
@@ -184,7 +189,7 @@ class MergeOptimizer(Optimizer):
...
@@ -184,7 +189,7 @@ class MergeOptimizer(Optimizer):
def
MergeOptMerge
(
opt
):
def
MergeOptMerge
(
opt
):
"""
"""
WRITEME
Returns an Optimizer that merges the graph then applies the
Returns an Optimizer that merges the graph then applies the
optimizer in opt and then merges the graph again in case the
optimizer in opt and then merges the graph again in case the
opt introduced additional similarities.
opt introduced additional similarities.
...
@@ -199,22 +204,26 @@ def MergeOptMerge(opt):
...
@@ -199,22 +204,26 @@ def MergeOptMerge(opt):
########################
########################
class
LocalOptimizer
(
utils
.
object2
):
class
LocalOptimizer
(
utils
.
object2
):
"""WRITEME"""
def
transform
(
self
,
node
):
def
transform
(
self
,
node
):
raise
utils
.
AbstractFunctionError
()
raise
utils
.
AbstractFunctionError
()
class
FromFunctionLocalOptimizer
(
LocalOptimizer
):
class
FromFunctionLocalOptimizer
(
LocalOptimizer
):
"""WRITEME"""
def
__init__
(
self
,
fn
):
def
__init__
(
self
,
fn
):
self
.
transform
=
fn
self
.
transform
=
fn
def
add_requirements
(
self
,
env
):
def
add_requirements
(
self
,
env
):
env
.
extend
(
gof
.
toolbox
.
ReplaceValidate
)
env
.
extend
(
gof
.
toolbox
.
ReplaceValidate
)
def
local_optimizer
(
f
):
def
local_optimizer
(
f
):
"""WRITEME"""
return
FromFunctionLocalOptimizer
(
f
)
return
FromFunctionLocalOptimizer
(
f
)
class
LocalOptGroup
(
LocalOptimizer
):
class
LocalOptGroup
(
LocalOptimizer
):
"""WRITEME"""
def
__init__
(
self
,
*
optimizers
):
def
__init__
(
self
,
*
optimizers
):
self
.
opts
=
optimizers
self
.
opts
=
optimizers
...
@@ -229,6 +238,7 @@ class LocalOptGroup(LocalOptimizer):
...
@@ -229,6 +238,7 @@ class LocalOptGroup(LocalOptimizer):
class
LocalOpKeyOptGroup
(
LocalOptGroup
):
class
LocalOpKeyOptGroup
(
LocalOptGroup
):
"""WRITEME"""
def
__init__
(
self
,
optimizers
):
def
__init__
(
self
,
optimizers
):
if
any
(
not
hasattr
(
opt
,
'op_key'
),
optimizers
):
if
any
(
not
hasattr
(
opt
,
'op_key'
),
optimizers
):
...
@@ -240,7 +250,7 @@ class LocalOpKeyOptGroup(LocalOptGroup):
...
@@ -240,7 +250,7 @@ class LocalOpKeyOptGroup(LocalOptGroup):
class
OpSub
(
LocalOptimizer
):
class
OpSub
(
LocalOptimizer
):
"""
"""
WRITEME
Replaces the application of a certain op by the application of
Replaces the application of a certain op by the application of
another op that take the same inputs as what they are replacing.
another op that take the same inputs as what they are replacing.
...
@@ -277,7 +287,7 @@ class OpSub(LocalOptimizer):
...
@@ -277,7 +287,7 @@ class OpSub(LocalOptimizer):
class
OpRemove
(
LocalOptimizer
):
class
OpRemove
(
LocalOptimizer
):
"""
"""
WRITEME
Removes all applications of an op by transferring each of its
Removes all applications of an op by transferring each of its
outputs to the corresponding input.
outputs to the corresponding input.
"""
"""
...
@@ -304,7 +314,7 @@ class OpRemove(LocalOptimizer):
...
@@ -304,7 +314,7 @@ class OpRemove(LocalOptimizer):
class
PatternSub
(
LocalOptimizer
):
class
PatternSub
(
LocalOptimizer
):
"""
"""
WRITEME
@todo update
@todo update
Replaces all occurrences of the input pattern by the output pattern:
Replaces all occurrences of the input pattern by the output pattern:
...
@@ -448,6 +458,7 @@ class PatternSub(LocalOptimizer):
...
@@ -448,6 +458,7 @@ class PatternSub(LocalOptimizer):
class
NavigatorOptimizer
(
Optimizer
):
class
NavigatorOptimizer
(
Optimizer
):
"""WRITEME"""
def
__init__
(
self
,
local_opt
,
ignore_newtrees
=
'auto'
,
failure_callback
=
None
):
def
__init__
(
self
,
local_opt
,
ignore_newtrees
=
'auto'
,
failure_callback
=
None
):
self
.
local_opt
=
local_opt
self
.
local_opt
=
local_opt
...
@@ -498,6 +509,7 @@ class NavigatorOptimizer(Optimizer):
...
@@ -498,6 +509,7 @@ class NavigatorOptimizer(Optimizer):
class
TopoOptimizer
(
NavigatorOptimizer
):
class
TopoOptimizer
(
NavigatorOptimizer
):
"""WRITEME"""
def
__init__
(
self
,
local_opt
,
order
=
'in_to_out'
,
ignore_newtrees
=
False
,
failure_callback
=
None
):
def
__init__
(
self
,
local_opt
,
order
=
'in_to_out'
,
ignore_newtrees
=
False
,
failure_callback
=
None
):
if
order
not
in
[
'out_to_in'
,
'in_to_out'
]:
if
order
not
in
[
'out_to_in'
,
'in_to_out'
]:
...
@@ -531,6 +543,7 @@ class TopoOptimizer(NavigatorOptimizer):
...
@@ -531,6 +543,7 @@ class TopoOptimizer(NavigatorOptimizer):
class
OpKeyOptimizer
(
NavigatorOptimizer
):
class
OpKeyOptimizer
(
NavigatorOptimizer
):
"""WRITEME"""
def
__init__
(
self
,
local_opt
,
ignore_newtrees
=
False
,
failure_callback
=
None
):
def
__init__
(
self
,
local_opt
,
ignore_newtrees
=
False
,
failure_callback
=
None
):
if
not
hasattr
(
local_opt
,
'op_key'
):
if
not
hasattr
(
local_opt
,
'op_key'
):
...
@@ -570,6 +583,7 @@ class OpKeyOptimizer(NavigatorOptimizer):
...
@@ -570,6 +583,7 @@ class OpKeyOptimizer(NavigatorOptimizer):
def
keep_going
(
exc
,
nav
,
repl_pairs
):
def
keep_going
(
exc
,
nav
,
repl_pairs
):
"""WRITEME"""
pass
pass
...
@@ -578,6 +592,7 @@ def keep_going(exc, nav, repl_pairs):
...
@@ -578,6 +592,7 @@ def keep_going(exc, nav, repl_pairs):
#################
#################
def
_check_chain
(
r
,
chain
):
def
_check_chain
(
r
,
chain
):
"""WRITEME"""
chain
=
list
(
reversed
(
chain
))
chain
=
list
(
reversed
(
chain
))
while
chain
:
while
chain
:
elem
=
chain
.
pop
()
elem
=
chain
.
pop
()
...
@@ -600,6 +615,7 @@ def _check_chain(r, chain):
...
@@ -600,6 +615,7 @@ def _check_chain(r, chain):
return
r
return
r
def
check_chain
(
r
,
*
chain
):
def
check_chain
(
r
,
*
chain
):
"""WRITEME"""
if
isinstance
(
r
,
graph
.
Apply
):
if
isinstance
(
r
,
graph
.
Apply
):
r
=
r
.
outputs
[
0
]
r
=
r
.
outputs
[
0
]
return
_check_chain
(
r
,
reduce
(
list
.
__iadd__
,
([
x
,
0
]
for
x
in
chain
)))
return
_check_chain
(
r
,
reduce
(
list
.
__iadd__
,
([
x
,
0
]
for
x
in
chain
)))
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论