Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
14090d83
提交
14090d83
authored
5月 08, 2008
作者:
Olivier Breuleux
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
documentation and cleanup
上级
abd860d7
隐藏空白字符变更
内嵌
并排
正在显示
15 个修改的文件
包含
336 行增加
和
887 行删除
+336
-887
_test_cc.py
gof/_test_cc.py
+5
-1
_test_ext.py
gof/_test_ext.py
+5
-3
_test_graph.py
gof/_test_graph.py
+4
-0
_test_link.py
gof/_test_link.py
+5
-1
_test_op.py
gof/_test_op.py
+4
-2
_test_opt.py
gof/_test_opt.py
+6
-1
_test_toolbox.py
gof/_test_toolbox.py
+6
-2
cc.py
gof/cc.py
+94
-104
env.py
gof/env.py
+84
-31
ext.py
gof/ext.py
+15
-85
graph.py
gof/graph.py
+2
-224
link.py
gof/link.py
+4
-7
op.py
gof/op.py
+22
-28
opt.py
gof/opt.py
+80
-203
toolbox.py
gof/toolbox.py
+0
-195
没有找到文件。
gof/_test_cc.py
浏览文件 @
14090d83
...
@@ -4,11 +4,15 @@ import unittest
...
@@ -4,11 +4,15 @@ import unittest
from
link
import
PerformLinker
,
Profiler
from
link
import
PerformLinker
,
Profiler
from
cc
import
*
from
cc
import
*
from
type
import
Type
from
type
import
Type
from
graph
import
Result
,
as_result
,
Apply
,
Constant
from
graph
import
Result
,
Apply
,
Constant
from
op
import
Op
from
op
import
Op
import
env
import
env
import
toolbox
import
toolbox
def
as_result
(
x
):
assert
isinstance
(
x
,
Result
)
return
x
class
TDouble
(
Type
):
class
TDouble
(
Type
):
def
filter
(
self
,
data
):
def
filter
(
self
,
data
):
return
float
(
data
)
return
float
(
data
)
...
...
gof/_test_ext.py
浏览文件 @
14090d83
...
@@ -3,18 +3,20 @@ import unittest
...
@@ -3,18 +3,20 @@ import unittest
from
type
import
Type
from
type
import
Type
import
graph
import
graph
from
graph
import
Result
,
as_result
,
Apply
from
graph
import
Result
,
Apply
from
op
import
Op
from
op
import
Op
from
opt
import
PatternOptimizer
,
OpSubOptimizer
from
opt
import
PatternOptimizer
,
OpSubOptimizer
from
ext
import
*
from
ext
import
*
from
env
import
Env
,
InconsistencyError
from
env
import
Env
,
InconsistencyError
#from toolbox import EquivTool
from
toolbox
import
ReplaceValidate
from
toolbox
import
ReplaceValidate
from
copy
import
copy
from
copy
import
copy
#from _test_result import MyResult
def
as_result
(
x
):
assert
isinstance
(
x
,
Result
)
return
x
class
MyType
(
Type
):
class
MyType
(
Type
):
...
...
gof/_test_graph.py
浏览文件 @
14090d83
...
@@ -15,6 +15,10 @@ else:
...
@@ -15,6 +15,10 @@ else:
realtestcase
=
unittest
.
TestCase
realtestcase
=
unittest
.
TestCase
def
as_result
(
x
):
assert
isinstance
(
x
,
Result
)
return
x
class
MyType
(
Type
):
class
MyType
(
Type
):
...
...
gof/_test_link.py
浏览文件 @
14090d83
...
@@ -3,7 +3,7 @@
...
@@ -3,7 +3,7 @@
import
unittest
import
unittest
import
graph
import
graph
from
graph
import
Result
,
as_result
,
Apply
,
Constant
from
graph
import
Result
,
Apply
,
Constant
from
type
import
Type
from
type
import
Type
from
op
import
Op
from
op
import
Op
import
env
import
env
...
@@ -14,6 +14,10 @@ from link import *
...
@@ -14,6 +14,10 @@ from link import *
#from _test_result import Double
#from _test_result import Double
def
as_result
(
x
):
assert
isinstance
(
x
,
Result
)
return
x
class
TDouble
(
Type
):
class
TDouble
(
Type
):
def
filter
(
self
,
data
):
def
filter
(
self
,
data
):
return
float
(
data
)
return
float
(
data
)
...
...
gof/_test_op.py
浏览文件 @
14090d83
...
@@ -3,9 +3,11 @@ import unittest
...
@@ -3,9 +3,11 @@ import unittest
from
copy
import
copy
from
copy
import
copy
from
op
import
*
from
op
import
*
from
type
import
Type
,
Generic
from
type
import
Type
,
Generic
from
graph
import
Apply
,
as_r
esult
from
graph
import
Apply
,
R
esult
#from result import Result
def
as_result
(
x
):
assert
isinstance
(
x
,
Result
)
return
x
class
MyType
(
Type
):
class
MyType
(
Type
):
...
...
gof/_test_opt.py
浏览文件 @
14090d83
import
unittest
import
unittest
from
graph
import
Result
,
as_result
,
Apply
,
Constant
from
type
import
Type
from
graph
import
Result
,
Apply
,
Constant
from
op
import
Op
from
op
import
Op
from
opt
import
*
from
opt
import
*
from
env
import
Env
from
env
import
Env
from
toolbox
import
*
from
toolbox
import
*
def
as_result
(
x
):
assert
isinstance
(
x
,
Result
)
return
x
class
MyType
(
Type
):
class
MyType
(
Type
):
...
...
gof/_test_toolbox.py
浏览文件 @
14090d83
import
unittest
import
unittest
from
graph
import
Result
,
as_result
,
Apply
from
graph
import
Result
,
Apply
from
type
import
Type
from
type
import
Type
from
op
import
Op
from
op
import
Op
#from opt import PatternOptimizer, OpSubOptimizer
from
env
import
Env
,
InconsistencyError
from
env
import
Env
,
InconsistencyError
from
toolbox
import
*
from
toolbox
import
*
def
as_result
(
x
):
assert
isinstance
(
x
,
Result
)
return
x
class
MyType
(
Type
):
class
MyType
(
Type
):
def
__init__
(
self
,
name
):
def
__init__
(
self
,
name
):
...
...
gof/cc.py
浏览文件 @
14090d83
"""
Defines Linkers that deal with C implementations.
"""
import
graph
from
graph
import
Constant
,
Value
# Python imports
from
link
import
Linker
,
LocalLinker
,
raise_with_op
,
Filter
,
map_storage
,
PerformLinker
from
copy
import
copy
from
copy
import
copy
from
utils
import
AbstractFunctionError
from
env
import
Env
import
md5
import
md5
import
sys
import
re
import
os
import
os
,
sys
,
platform
import
platform
# weave import
from
scipy
import
weave
from
scipy
import
weave
# gof imports
import
cutils
import
cutils
from
env
import
Env
import
graph
import
link
import
utils
import
utils
import
traceback
import
re
def
compile_dir
():
def
compile_dir
():
...
@@ -192,7 +196,7 @@ def struct_gen(args, struct_builders, blocks, sub):
...
@@ -192,7 +196,7 @@ def struct_gen(args, struct_builders, blocks, sub):
return
%(failure_var)
s;
return
%(failure_var)
s;
"""
%
sub
"""
%
sub
sub
=
copy
(
sub
)
sub
=
dict
(
sub
)
sub
.
update
(
locals
())
sub
.
update
(
locals
())
# TODO: add some error checking to make sure storage_<x> are 1-element lists
# TODO: add some error checking to make sure storage_<x> are 1-element lists
...
@@ -309,7 +313,7 @@ def struct_result_codeblocks(result, policies, id, symbol_table, sub):
...
@@ -309,7 +313,7 @@ def struct_result_codeblocks(result, policies, id, symbol_table, sub):
name
=
"V
%
i"
%
id
name
=
"V
%
i"
%
id
symbol_table
[
result
]
=
name
symbol_table
[
result
]
=
name
sub
=
copy
(
sub
)
sub
=
dict
(
sub
)
# sub['name'] = name
# sub['name'] = name
sub
[
'id'
]
=
id
sub
[
'id'
]
=
id
sub
[
'fail'
]
=
failure_code
(
sub
)
sub
[
'fail'
]
=
failure_code
(
sub
)
...
@@ -323,13 +327,16 @@ def struct_result_codeblocks(result, policies, id, symbol_table, sub):
...
@@ -323,13 +327,16 @@ def struct_result_codeblocks(result, policies, id, symbol_table, sub):
return
struct_builder
,
block
return
struct_builder
,
block
class
CLinker
(
Linker
):
class
CLinker
(
link
.
Linker
):
"""
"""
Creates C code for an env or an Op instance, compiles it and returns
callables through make_thunk and make_function that make use of the
Creates C code for an env, compiles it and returns callables
compiled code.
through make_thunk and make_function that make use of the compiled
code.
It can take an env or an Op as input.
no_recycling can contain a list of Results that belong to the env.
If a Result is in no_recycling, CLinker will clear the output storage
associated to it during the computation (to avoid reusing it).
"""
"""
def
__init__
(
self
,
env
,
no_recycling
=
[]):
def
__init__
(
self
,
env
,
no_recycling
=
[]):
...
@@ -346,7 +353,7 @@ class CLinker(Linker):
...
@@ -346,7 +353,7 @@ class CLinker(Linker):
self
.
outputs
=
env
.
outputs
self
.
outputs
=
env
.
outputs
self
.
results
=
graph
.
results
(
self
.
inputs
,
self
.
outputs
)
# list(env.results)
self
.
results
=
graph
.
results
(
self
.
inputs
,
self
.
outputs
)
# list(env.results)
# The orphans field is listified to ensure a consistent order.
# The orphans field is listified to ensure a consistent order.
self
.
orphans
=
list
(
r
for
r
in
self
.
results
if
isinstance
(
r
,
Value
)
and
r
not
in
self
.
inputs
)
#list(env.orphans.difference(self.outputs))
self
.
orphans
=
list
(
r
for
r
in
self
.
results
if
isinstance
(
r
,
graph
.
Value
)
and
r
not
in
self
.
inputs
)
#list(env.orphans.difference(self.outputs))
self
.
temps
=
list
(
set
(
self
.
results
)
.
difference
(
self
.
inputs
)
.
difference
(
self
.
outputs
)
.
difference
(
self
.
orphans
))
self
.
temps
=
list
(
set
(
self
.
results
)
.
difference
(
self
.
inputs
)
.
difference
(
self
.
outputs
)
.
difference
(
self
.
orphans
))
self
.
node_order
=
env
.
toposort
()
self
.
node_order
=
env
.
toposort
()
...
@@ -404,15 +411,15 @@ class CLinker(Linker):
...
@@ -404,15 +411,15 @@ class CLinker(Linker):
policy
=
[[
get_nothing
,
get_nothing
,
get_nothing
],
policy
=
[[
get_nothing
,
get_nothing
,
get_nothing
],
[
get_c_declare
,
get_c_extract
,
get_c_cleanup
]]
[
get_c_declare
,
get_c_extract
,
get_c_cleanup
]]
elif
result
in
self
.
orphans
:
elif
result
in
self
.
orphans
:
if
not
isinstance
(
result
,
Value
):
if
not
isinstance
(
result
,
graph
.
Value
):
raise
TypeError
(
"All orphans to CLinker must be Value instances."
,
result
)
raise
TypeError
(
"All orphans to CLinker must be Value instances."
,
result
)
if
isinstance
(
result
,
Constant
):
if
isinstance
(
result
,
graph
.
Constant
):
try
:
try
:
symbol
[
result
]
=
"("
+
result
.
type
.
c_literal
(
result
.
data
)
+
")"
symbol
[
result
]
=
"("
+
result
.
type
.
c_literal
(
result
.
data
)
+
")"
consts
.
append
(
result
)
consts
.
append
(
result
)
self
.
orphans
.
remove
(
result
)
self
.
orphans
.
remove
(
result
)
continue
continue
except
(
AbstractFunctionError
,
NotImplementedError
):
except
(
utils
.
AbstractFunctionError
,
NotImplementedError
):
pass
pass
# orphans are not inputs so we'll just get fetch them when we initialize the struct and assume they stay the same
# orphans are not inputs so we'll just get fetch them when we initialize the struct and assume they stay the same
policy
=
[[
get_c_declare
,
get_c_extract
,
get_c_cleanup
],
policy
=
[[
get_c_declare
,
get_c_extract
,
get_c_cleanup
],
...
@@ -475,11 +482,11 @@ class CLinker(Linker):
...
@@ -475,11 +482,11 @@ class CLinker(Linker):
op
=
node
.
op
op
=
node
.
op
try
:
behavior
=
op
.
c_code
(
node
,
name
,
isyms
,
osyms
,
sub
)
try
:
behavior
=
op
.
c_code
(
node
,
name
,
isyms
,
osyms
,
sub
)
except
AbstractFunctionError
:
except
utils
.
AbstractFunctionError
:
raise
NotImplementedError
(
"
%
s cannot produce C code"
%
op
)
raise
NotImplementedError
(
"
%
s cannot produce C code"
%
op
)
try
:
cleanup
=
op
.
c_code_cleanup
(
node
,
name
,
isyms
,
osyms
,
sub
)
try
:
cleanup
=
op
.
c_code_cleanup
(
node
,
name
,
isyms
,
osyms
,
sub
)
except
AbstractFunctionError
:
except
utils
.
AbstractFunctionError
:
cleanup
=
""
cleanup
=
""
blocks
.
append
(
CodeBlock
(
""
,
behavior
,
cleanup
,
sub
))
blocks
.
append
(
CodeBlock
(
""
,
behavior
,
cleanup
,
sub
))
...
@@ -539,7 +546,7 @@ class CLinker(Linker):
...
@@ -539,7 +546,7 @@ class CLinker(Linker):
ret
=
[]
ret
=
[]
for
x
in
[
y
.
type
for
y
in
self
.
results
]
+
[
y
.
op
for
y
in
self
.
node_order
]:
for
x
in
[
y
.
type
for
y
in
self
.
results
]
+
[
y
.
op
for
y
in
self
.
node_order
]:
try
:
ret
.
append
(
x
.
c_support_code
())
try
:
ret
.
append
(
x
.
c_support_code
())
except
AbstractFunctionError
:
pass
except
utils
.
AbstractFunctionError
:
pass
return
ret
return
ret
def
compile_args
(
self
):
def
compile_args
(
self
):
...
@@ -552,7 +559,7 @@ class CLinker(Linker):
...
@@ -552,7 +559,7 @@ class CLinker(Linker):
ret
=
[]
ret
=
[]
for
x
in
[
y
.
type
for
y
in
self
.
results
]
+
[
y
.
op
for
y
in
self
.
node_order
]:
for
x
in
[
y
.
type
for
y
in
self
.
results
]
+
[
y
.
op
for
y
in
self
.
node_order
]:
try
:
ret
+=
x
.
c_compile_args
()
try
:
ret
+=
x
.
c_compile_args
()
except
AbstractFunctionError
:
pass
except
utils
.
AbstractFunctionError
:
pass
return
ret
return
ret
def
headers
(
self
):
def
headers
(
self
):
...
@@ -565,7 +572,7 @@ class CLinker(Linker):
...
@@ -565,7 +572,7 @@ class CLinker(Linker):
ret
=
[]
ret
=
[]
for
x
in
[
y
.
type
for
y
in
self
.
results
]
+
[
y
.
op
for
y
in
self
.
node_order
]:
for
x
in
[
y
.
type
for
y
in
self
.
results
]
+
[
y
.
op
for
y
in
self
.
node_order
]:
try
:
ret
+=
x
.
c_headers
()
try
:
ret
+=
x
.
c_headers
()
except
AbstractFunctionError
:
pass
except
utils
.
AbstractFunctionError
:
pass
return
ret
return
ret
def
libraries
(
self
):
def
libraries
(
self
):
...
@@ -578,25 +585,23 @@ class CLinker(Linker):
...
@@ -578,25 +585,23 @@ class CLinker(Linker):
ret
=
[]
ret
=
[]
for
x
in
[
y
.
type
for
y
in
self
.
results
]
+
[
y
.
op
for
y
in
self
.
node_order
]:
for
x
in
[
y
.
type
for
y
in
self
.
results
]
+
[
y
.
op
for
y
in
self
.
node_order
]:
try
:
ret
+=
x
.
c_libraries
()
try
:
ret
+=
x
.
c_libraries
()
except
AbstractFunctionError
:
pass
except
utils
.
AbstractFunctionError
:
pass
return
ret
return
ret
def
__compile__
(
self
,
input_storage
=
None
,
output_storage
=
None
):
def
__compile__
(
self
,
input_storage
=
None
,
output_storage
=
None
):
"""
"""
@todo update
Compiles this linker's env.
Compiles this linker's env. If inplace is True, it will use the
@type input_storage: list or None
Results contained in the env, if it is False it will copy the
@param input_storage: list of lists of length 1. In order to use
input and output Results.
the thunk returned by __compile__, the inputs must be put in
that storage. If None, storage will be allocated.
Returns: thunk, in_results, out_results, error_storage
@param output_storage: list of lists of length 1. The thunk returned
by __compile__ will put the results of the computation in these
lists. If None, storage will be allocated.
Returns: thunk, input_storage, output_storage, error_storage
"""
"""
# if inplace:
# in_results = self.inputs
# out_results = self.outputs
# else:
# in_results = [copy(input) for input in self.inputs]
# out_results = [copy(output) for output in self.outputs]
error_storage
=
[
None
,
None
,
None
]
error_storage
=
[
None
,
None
,
None
]
if
input_storage
is
None
:
if
input_storage
is
None
:
input_storage
=
tuple
([
None
]
for
result
in
self
.
inputs
)
input_storage
=
tuple
([
None
]
for
result
in
self
.
inputs
)
...
@@ -612,13 +617,34 @@ class CLinker(Linker):
...
@@ -612,13 +617,34 @@ class CLinker(Linker):
thunk
=
self
.
cthunk_factory
(
error_storage
,
thunk
=
self
.
cthunk_factory
(
error_storage
,
input_storage
,
input_storage
,
output_storage
)
output_storage
)
return
thunk
,
[
Filter
(
input
.
type
,
storage
)
for
input
,
storage
in
zip
(
self
.
env
.
inputs
,
input_storage
)],
\
return
thunk
,
\
[
Filter
(
output
.
type
,
storage
,
True
)
for
output
,
storage
in
zip
(
self
.
env
.
outputs
,
output_storage
)],
\
[
link
.
Filter
(
input
.
type
,
storage
)
for
input
,
storage
in
zip
(
self
.
env
.
inputs
,
input_storage
)],
\
[
link
.
Filter
(
output
.
type
,
storage
,
True
)
for
output
,
storage
in
zip
(
self
.
env
.
outputs
,
output_storage
)],
\
error_storage
error_storage
# return thunk, [Filter(x) for x in input_storage], [Filter(x) for x in output_storage], error_storage
def
make_thunk
(
self
,
input_storage
=
None
,
output_storage
=
None
):
def
make_thunk
(
self
,
input_storage
=
None
,
output_storage
=
None
):
"""
Compiles this linker's env and returns a function to perform the
computations, as well as lists of storage cells for both the
inputs and outputs.
@type input_storage: list or None
@param input_storage: list of lists of length 1. In order to use
the thunk returned by __compile__, the inputs must be put in
that storage. If None, storage will be allocated.
@param output_storage: list of lists of length 1. The thunk returned
by __compile__ will put the results of the computation in these
lists. If None, storage will be allocated.
Returns: thunk, input_storage, output_storage
The return values can be used as follows:
f, istor, ostor = clinker.make_thunk()
istor[0].data = first_input
istor[1].data = second_input
f()
first_output = ostor[0].data
"""
cthunk
,
in_storage
,
out_storage
,
error_storage
=
self
.
__compile__
(
input_storage
,
output_storage
)
cthunk
,
in_storage
,
out_storage
,
error_storage
=
self
.
__compile__
(
input_storage
,
output_storage
)
def
execute
():
def
execute
():
failure
=
cutils
.
run_cthunk
(
cthunk
)
failure
=
cutils
.
run_cthunk
(
cthunk
)
...
@@ -729,7 +755,7 @@ class CLinker(Linker):
...
@@ -729,7 +755,7 @@ class CLinker(Linker):
class
OpWiseCLinker
(
LocalLinker
):
class
OpWiseCLinker
(
link
.
LocalLinker
):
"""
"""
Uses CLinker on the individual Ops that comprise an env and loops
Uses CLinker on the individual Ops that comprise an env and loops
over them in Python. The result is slower than a compiled version of
over them in Python. The result is slower than a compiled version of
...
@@ -739,6 +765,10 @@ class OpWiseCLinker(LocalLinker):
...
@@ -739,6 +765,10 @@ class OpWiseCLinker(LocalLinker):
If fallback_on_perform is True, OpWiseCLinker will use an op's
If fallback_on_perform is True, OpWiseCLinker will use an op's
perform method if no C version can be generated.
perform method if no C version can be generated.
no_recycling can contain a list of Results that belong to the env.
If a Result is in no_recycling, CLinker will clear the output storage
associated to it during the computation (to avoid reusing it).
"""
"""
def
__init__
(
self
,
env
,
fallback_on_perform
=
True
,
no_recycling
=
[]):
def
__init__
(
self
,
env
,
fallback_on_perform
=
True
,
no_recycling
=
[]):
...
@@ -756,7 +786,7 @@ class OpWiseCLinker(LocalLinker):
...
@@ -756,7 +786,7 @@ class OpWiseCLinker(LocalLinker):
order
=
env
.
toposort
()
order
=
env
.
toposort
()
no_recycling
=
self
.
no_recycling
no_recycling
=
self
.
no_recycling
input_storage
,
output_storage
,
storage_map
=
map_storage
(
env
,
order
,
input_storage
,
output_storage
)
input_storage
,
output_storage
,
storage_map
=
link
.
map_storage
(
env
,
order
,
input_storage
,
output_storage
)
thunks
=
[]
thunks
=
[]
for
node
in
order
:
for
node
in
order
:
...
@@ -772,7 +802,7 @@ class OpWiseCLinker(LocalLinker):
...
@@ -772,7 +802,7 @@ class OpWiseCLinker(LocalLinker):
thunk
.
inputs
=
node_input_storage
thunk
.
inputs
=
node_input_storage
thunk
.
outputs
=
node_output_storage
thunk
.
outputs
=
node_output_storage
thunks
.
append
(
thunk
)
thunks
.
append
(
thunk
)
except
(
NotImplementedError
,
AbstractFunctionError
):
except
(
NotImplementedError
,
utils
.
AbstractFunctionError
):
if
self
.
fallback_on_perform
:
if
self
.
fallback_on_perform
:
p
=
node
.
op
.
perform
p
=
node
.
op
.
perform
thunk
=
lambda
p
=
p
,
i
=
node_input_storage
,
o
=
node_output_storage
,
n
=
node
:
p
(
n
,
[
x
[
0
]
for
x
in
i
],
o
)
thunk
=
lambda
p
=
p
,
i
=
node_input_storage
,
o
=
node_output_storage
,
n
=
node
:
p
(
n
,
[
x
[
0
]
for
x
in
i
],
o
)
...
@@ -791,27 +821,8 @@ class OpWiseCLinker(LocalLinker):
...
@@ -791,27 +821,8 @@ class OpWiseCLinker(LocalLinker):
f
=
self
.
streamline
(
env
,
thunks
,
order
,
no_recycling
=
no_recycling
,
profiler
=
profiler
)
f
=
self
.
streamline
(
env
,
thunks
,
order
,
no_recycling
=
no_recycling
,
profiler
=
profiler
)
# if profiler is None:
return
f
,
[
link
.
Filter
(
input
.
type
,
storage
)
for
input
,
storage
in
zip
(
env
.
inputs
,
input_storage
)],
\
# def f():
[
link
.
Filter
(
output
.
type
,
storage
,
True
)
for
output
,
storage
in
zip
(
env
.
outputs
,
output_storage
)],
\
# for x in no_recycling:
# x[0] = None
# try:
# for thunk, node in zip(thunks, order):
# thunk()
# except:
# raise_with_op(node)
# else:
# def f():
# for x in no_recycling:
# x[0] = None
# def g():
# for thunk, node in zip(thunks, order):
# profiler.profile_op(thunk, node)
# profiler.profile_env(g, env)
# f.profiler = profiler
return
f
,
[
Filter
(
input
.
type
,
storage
)
for
input
,
storage
in
zip
(
env
.
inputs
,
input_storage
)],
\
[
Filter
(
output
.
type
,
storage
,
True
)
for
output
,
storage
in
zip
(
env
.
outputs
,
output_storage
)],
\
thunks
,
order
thunks
,
order
...
@@ -825,7 +836,7 @@ def _default_checker(x, y):
...
@@ -825,7 +836,7 @@ def _default_checker(x, y):
if
x
[
0
]
!=
y
[
0
]:
if
x
[
0
]
!=
y
[
0
]:
raise
Exception
(
"Output mismatch."
,
{
'performlinker'
:
x
[
0
],
'clinker'
:
y
[
0
]})
raise
Exception
(
"Output mismatch."
,
{
'performlinker'
:
x
[
0
],
'clinker'
:
y
[
0
]})
class
DualLinker
(
Linker
):
class
DualLinker
(
link
.
Linker
):
"""
"""
Runs the env in parallel using PerformLinker and CLinker.
Runs the env in parallel using PerformLinker and CLinker.
...
@@ -841,13 +852,13 @@ class DualLinker(Linker):
...
@@ -841,13 +852,13 @@ class DualLinker(Linker):
"""
"""
Initialize a DualLinker.
Initialize a DualLinker.
The checker argument must be a function that takes two
Result
The checker argument must be a function that takes two
lists
instances. The first one passed will be the output computed by
of length 1. The first one passed will contain the output
PerformLinker and the second one the output computed by
computed by PerformLinker and the second one the output
OpWiseCLinker. The checker should compare the data fields of
computed by OpWiseCLinker. The checker should compare the data
the two results to see if they match. By default, DualLinker
fields of the two results to see if they match. By default,
uses ==. A custom checker can be provided to compare up to a
DualLinker uses ==. A custom checker can be provided to
certain error tolerance.
c
ompare up to a c
ertain error tolerance.
If a mismatch occurs, the checker should raise an exception to
If a mismatch occurs, the checker should raise an exception to
halt the computation. If it does not, the computation will
halt the computation. If it does not, the computation will
...
@@ -855,35 +866,22 @@ class DualLinker(Linker):
...
@@ -855,35 +866,22 @@ class DualLinker(Linker):
the problem by fiddling with the data, but it should be
the problem by fiddling with the data, but it should be
careful not to share data between the two outputs (or inplace
careful not to share data between the two outputs (or inplace
operations that use them will interfere).
operations that use them will interfere).
no_recycling can contain a list of Results that belong to the env.
If a Result is in no_recycling, CLinker will clear the output storage
associated to it during the computation (to avoid reusing it).
"""
"""
self
.
env
=
env
self
.
env
=
env
self
.
checker
=
checker
self
.
checker
=
checker
self
.
no_recycling
=
no_recycling
self
.
no_recycling
=
no_recycling
def
make_thunk
(
self
,
**
kwargs
):
def
make_thunk
(
self
,
**
kwargs
):
# if inplace:
# env1 = self.env
# else:
# env1 = self.env.clone(True)
# env2, equiv = env1.clone_get_equiv(True)
# op_order_1 = env1.toposort()
# op_order_2 = [equiv[op.outputs[0]].owner for op in op_order_1] # we need to have the exact same order so we can compare each step
# def c_make_thunk(op):
# try:
# return CLinker(op).make_thunk(True)[0]
# except AbstractFunctionError:
# return op.perform
# thunks1 = [op.perform for op in op_order_1]
# thunks2 = [c_make_thunk(op) for op in op_order_2]
env
=
self
.
env
env
=
self
.
env
no_recycling
=
self
.
no_recycling
no_recycling
=
self
.
no_recycling
_f
,
i1
,
o1
,
thunks1
,
order1
=
PerformLinker
(
env
,
no_recycling
=
no_recycling
)
.
make_all
(
**
kwargs
)
_f
,
i1
,
o1
,
thunks1
,
order1
=
link
.
PerformLinker
(
env
,
no_recycling
=
no_recycling
)
.
make_all
(
**
kwargs
)
_f
,
i2
,
o2
,
thunks2
,
order2
=
OpWiseCLinker
(
env
,
no_recycling
=
no_recycling
)
.
make_all
(
**
kwargs
)
_f
,
i2
,
o2
,
thunks2
,
order2
=
OpWiseCLinker
(
env
,
no_recycling
=
no_recycling
)
.
make_all
(
**
kwargs
)
def
f
():
def
f
():
for
input1
,
input2
in
zip
(
i1
,
i2
):
for
input1
,
input2
in
zip
(
i1
,
i2
):
...
@@ -903,15 +901,7 @@ class DualLinker(Linker):
...
@@ -903,15 +901,7 @@ class DualLinker(Linker):
for
output1
,
output2
in
zip
(
thunk1
.
outputs
,
thunk2
.
outputs
):
for
output1
,
output2
in
zip
(
thunk1
.
outputs
,
thunk2
.
outputs
):
self
.
checker
(
output1
,
output2
)
self
.
checker
(
output1
,
output2
)
except
:
except
:
raise_with_op
(
node1
)
link
.
raise_with_op
(
node1
)
# exc_type, exc_value, exc_trace = sys.exc_info()
# try:
# trace = op1.trace
# except AttributeError:
# trace = ()
# exc_value.__thunk_trace__ = trace
# exc_value.args = exc_value.args + (op1, )
# raise exc_type, exc_value, exc_trace
return
f
,
i1
,
o1
return
f
,
i1
,
o1
...
...
gof/env.py
浏览文件 @
14090d83
from
copy
import
copy
from
copy
import
copy
import
graph
import
graph
##from features import Listener, Orderings, Constraint, Tool, uniq_features
import
utils
import
utils
from
utils
import
AbstractFunctionError
class
InconsistencyError
(
Exception
):
class
InconsistencyError
(
Exception
):
"""
"""
This exception
is raised by Env whenever one of the listeners marks
This exception
should be thrown by listeners to Env when the
the graph as inconsistent
.
graph's state is invalid
.
"""
"""
pass
pass
class
Env
(
object
):
#(graph.Graph
):
class
Env
(
utils
.
object2
):
"""
"""
An Env represents a subgraph bound by a set of input results and a
An Env represents a subgraph bound by a set of input results and a
set of output results. An op is in the subgraph iff it depends on
set of output results. The inputs list should contain all the inputs
the value of some of the Env's inputs _and_ some of the Env's
on which the outputs depend. Results of type Value or Constant are
outputs depend on it. A result is in the subgraph iff it is an
not counted as inputs.
input or an output of an op that is in the subgraph.
The Env supports the replace operation which allows to replace a
The Env supports the replace operation which allows to replace a
result in the subgraph by another, e.g. replace (x + x).out by (2
result in the subgraph by another, e.g. replace (x + x).out by (2
* x).out. This is the basis for optimization in theano.
* x).out. This is the basis for optimization in theano.
It can also be "extended" using env.extend(some_object). See the
toolbox and ext modules for common extensions.
"""
"""
### Special ###
### Special ###
...
@@ -65,12 +64,14 @@ class Env(object): #(graph.Graph):
...
@@ -65,12 +64,14 @@ class Env(object): #(graph.Graph):
### Setup a Result ###
### Setup a Result ###
def
__setup_r__
(
self
,
r
):
def
__setup_r__
(
self
,
r
):
# sets up r so it belongs to this env
if
hasattr
(
r
,
'env'
)
and
r
.
env
is
not
None
and
r
.
env
is
not
self
:
if
hasattr
(
r
,
'env'
)
and
r
.
env
is
not
None
and
r
.
env
is
not
self
:
raise
Exception
(
"
%
s is already owned by another env"
%
r
)
raise
Exception
(
"
%
s is already owned by another env"
%
r
)
r
.
env
=
self
r
.
env
=
self
r
.
clients
=
[]
r
.
clients
=
[]
def
__setup_node__
(
self
,
node
):
def
__setup_node__
(
self
,
node
):
# sets up node so it belongs to this env
if
hasattr
(
node
,
'env'
)
and
node
.
env
is
not
self
:
if
hasattr
(
node
,
'env'
)
and
node
.
env
is
not
self
:
raise
Exception
(
"
%
s is already owned by another env"
%
node
)
raise
Exception
(
"
%
s is already owned by another env"
%
node
)
node
.
env
=
self
node
.
env
=
self
...
@@ -80,28 +81,27 @@ class Env(object): #(graph.Graph):
...
@@ -80,28 +81,27 @@ class Env(object): #(graph.Graph):
### clients ###
### clients ###
def
clients
(
self
,
r
):
def
clients
(
self
,
r
):
"Set of all the (
op, i) pairs such that op
.inputs[i] is r."
"Set of all the (
node, i) pairs such that node
.inputs[i] is r."
return
r
.
clients
return
r
.
clients
def
__add_clients__
(
self
,
r
,
all
):
def
__add_clients__
(
self
,
r
,
new_clients
):
"""
"""
r -> result
r -> result
all -> list of (op, i) pairs representing who r is an input of
.
new_clients -> list of (node, i) pairs such that node.inputs[i] is r
.
Updates the list of clients of r with
all
.
Updates the list of clients of r with
new_clients
.
"""
"""
r
.
clients
+=
all
r
.
clients
+=
new_clients
def
__remove_clients__
(
self
,
r
,
all
,
prune
=
True
):
def
__remove_clients__
(
self
,
r
,
clients_to_remove
,
prune
=
True
):
"""
"""
r -> result
r -> result
all -> list of (op, i) pairs representing who r is an input of
.
clients_to_remove -> list of (op, i) pairs such that node.inputs[i] is not r anymore
.
Removes all from the clients list of r.
Removes all from the clients list of r.
"""
"""
for
entry
in
all
:
for
entry
in
clients_to_remove
:
r
.
clients
.
remove
(
entry
)
r
.
clients
.
remove
(
entry
)
# remove from orphans?
if
not
r
.
clients
:
if
not
r
.
clients
:
if
prune
:
if
prune
:
self
.
__prune_r__
([
r
])
self
.
__prune_r__
([
r
])
...
@@ -188,6 +188,15 @@ class Env(object): #(graph.Graph):
...
@@ -188,6 +188,15 @@ class Env(object): #(graph.Graph):
### change input ###
### change input ###
def
change_input
(
self
,
node
,
i
,
new_r
):
def
change_input
(
self
,
node
,
i
,
new_r
):
"""
Changes node.inputs[i] to new_r.
new_r.type == old_r.type must be True, where old_r is the
current value of node.inputs[i] which we want to replace.
For each feature that has a 'on_change_input' method, calls:
feature.on_change_input(env, node, i, old_r, new_r)
"""
if
node
==
'output'
:
if
node
==
'output'
:
r
=
self
.
outputs
[
i
]
r
=
self
.
outputs
[
i
]
if
not
r
.
type
==
new_r
.
type
:
if
not
r
.
type
==
new_r
.
type
:
...
@@ -214,10 +223,7 @@ class Env(object): #(graph.Graph):
...
@@ -214,10 +223,7 @@ class Env(object): #(graph.Graph):
def
replace
(
self
,
r
,
new_r
):
def
replace
(
self
,
r
,
new_r
):
"""
"""
This is the main interface to manipulate the subgraph in Env.
This is the main interface to manipulate the subgraph in Env.
For every op that uses r as input, makes it use new_r instead.
For every node that uses r as input, makes it use new_r instead.
This may raise an error if the new result violates type
constraints for one of the target nodes. In that case, no
changes are made.
"""
"""
if
r
.
env
is
not
self
:
if
r
.
env
is
not
self
:
raise
Exception
(
"Cannot replace
%
s because it does not belong to this Env"
%
r
)
raise
Exception
(
"Cannot replace
%
s because it does not belong to this Env"
%
r
)
...
@@ -238,11 +244,32 @@ class Env(object): #(graph.Graph):
...
@@ -238,11 +244,32 @@ class Env(object): #(graph.Graph):
def
extend
(
self
,
feature
):
def
extend
(
self
,
feature
):
"""
"""
@todo out of date
Adds a feature to this env. The feature may define one
Adds an instance of the feature_class to this env's supported
or more of the following methods:
features. If do_import is True and feature_class is a subclass
of Listener, its on_import method will be called on all the Nodes
- feature.on_attach(env)
already in the env.
Called by extend. The feature has great freedom in what
it can do with the env: it may, for example, add methods
to it dynicamically.
- feature.on_detach(env)
Called by remove_feature(feature).
- feature.on_import(env, node)*
Called whenever a node is imported into env, which is
just before the node is actually connected to the graph.
- feature.on_prune(env, node)*
Called whenever a node is pruned (removed) from the env,
after it is disconnected from the graph.
- feature.on_change_input(env, node, i, r, new_r)*
Called whenever node.inputs[i] is changed from r to new_r.
At the moment the callback is done, the change has already
taken place.
- feature.orderings(env)
Called by toposort. It should return a dictionary of
{node: predecessors} where predecessors is a list of
nodes that should be computed before the key node.
* If you raise an exception in the functions marked with an
asterisk, the state of the graph might be inconsistent.
"""
"""
if
feature
in
self
.
_features
:
if
feature
in
self
.
_features
:
return
# the feature is already present
return
# the feature is already present
...
@@ -256,6 +283,11 @@ class Env(object): #(graph.Graph):
...
@@ -256,6 +283,11 @@ class Env(object): #(graph.Graph):
raise
raise
def
remove_feature
(
self
,
feature
):
def
remove_feature
(
self
,
feature
):
"""
Removes the feature from the graph.
Calls feature.on_detach(env) if an on_detach method is defined.
"""
try
:
try
:
self
.
_features
.
remove
(
feature
)
self
.
_features
.
remove
(
feature
)
except
:
except
:
...
@@ -268,6 +300,11 @@ class Env(object): #(graph.Graph):
...
@@ -268,6 +300,11 @@ class Env(object): #(graph.Graph):
### callback utils ###
### callback utils ###
def
execute_callbacks
(
self
,
name
,
*
args
):
def
execute_callbacks
(
self
,
name
,
*
args
):
"""
Calls
getattr(feature, name)(*args)
for each feature which has a method called after name.
"""
for
feature
in
self
.
_features
:
for
feature
in
self
.
_features
:
try
:
try
:
fn
=
getattr
(
feature
,
name
)
fn
=
getattr
(
feature
,
name
)
...
@@ -276,6 +313,11 @@ class Env(object): #(graph.Graph):
...
@@ -276,6 +313,11 @@ class Env(object): #(graph.Graph):
fn
(
self
,
*
args
)
fn
(
self
,
*
args
)
def
collect_callbacks
(
self
,
name
,
*
args
):
def
collect_callbacks
(
self
,
name
,
*
args
):
"""
Returns a dictionary d such that:
d[feature] == getattr(feature, name)(*args)
For each feature which has a method called after name.
"""
d
=
{}
d
=
{}
for
feature
in
self
.
_features
:
for
feature
in
self
.
_features
:
try
:
try
:
...
@@ -289,6 +331,17 @@ class Env(object): #(graph.Graph):
...
@@ -289,6 +331,17 @@ class Env(object): #(graph.Graph):
### misc ###
### misc ###
def
toposort
(
self
):
def
toposort
(
self
):
"""
Returns an ordering of the graph's Apply nodes such that:
- All the nodes of the inputs of a node are before that node.
- Satisfies the orderings provided by each feature that has
an 'orderings' method.
If a feature has an 'orderings' method, it will be called with
this env as sole argument. It should return a dictionary of
{node: predecessors} where predecessors is a list of nodes
that should be computed before the key node.
"""
env
=
self
env
=
self
ords
=
{}
ords
=
{}
for
feature
in
env
.
_features
:
for
feature
in
env
.
_features
:
...
@@ -314,10 +367,10 @@ class Env(object): #(graph.Graph):
...
@@ -314,10 +367,10 @@ class Env(object): #(graph.Graph):
raise
Exception
(
"what the fuck"
)
raise
Exception
(
"what the fuck"
)
return
node
.
inputs
return
node
.
inputs
def
has_node
(
self
,
node
):
return
node
in
self
.
nodes
def
check_integrity
(
self
):
def
check_integrity
(
self
):
"""
Call this for a diagnosis if things go awry.
"""
nodes
=
graph
.
ops
(
self
.
inputs
,
self
.
outputs
)
nodes
=
graph
.
ops
(
self
.
inputs
,
self
.
outputs
)
if
self
.
nodes
!=
nodes
:
if
self
.
nodes
!=
nodes
:
missing
=
nodes
.
difference
(
self
.
nodes
)
missing
=
nodes
.
difference
(
self
.
nodes
)
...
...
gof/ext.py
浏览文件 @
14090d83
#from features import Listener, Constraint, Orderings, Tool
from
collections
import
defaultdict
import
graph
import
graph
import
utils
import
utils
from
utils
import
AbstractFunctionError
import
toolbox
from
copy
import
copy
from
utils
import
AbstractFunctionError
from
env
import
InconsistencyError
from
env
import
InconsistencyError
from
toolbox
import
Bookkeeper
from
collections
import
defaultdict
class
DestroyHandler
(
toolbox
.
Bookkeeper
):
class
DestroyHandler
(
Bookkeeper
):
#(Listener, Constraint, Orderings, Tool):
"""
"""
This feature ensures that an env represents a consistent data flow
This feature ensures that an env represents a consistent data flow
when some Ops overwrite their inputs and/or provide "views" over
when some Ops overwrite their inputs and/or provide "views" over
...
@@ -29,14 +21,16 @@ class DestroyHandler(Bookkeeper): #(Listener, Constraint, Orderings, Tool):
...
@@ -29,14 +21,16 @@ class DestroyHandler(Bookkeeper): #(Listener, Constraint, Orderings, Tool):
Examples:
Examples:
- (x += 1) + (x += 1) -> fails because the first += makes the second
- (x += 1) + (x += 1) -> fails because the first += makes the second
invalid
invalid
- x += transpose_view(x) -> fails because the input that is destroyed
depends on an input that shares the same data
- (a += b) + (c += a) -> succeeds but we have to do c += a first
- (a += b) + (c += a) -> succeeds but we have to do c += a first
- (a += b) + (b += c) + (c += a) -> fails because there's a cyclical
- (a += b) + (b += c) + (c += a) -> fails because there's a cyclical
dependency (no possible ordering)
dependency (no possible ordering)
This feature allows some optimizations (eg sub += for +) to be applied
This feature allows some optimizations (eg sub += for +) to be applied
safely.
safely.
@todo
- x += transpose_view(x) -> fails because the input that is destroyed
depends on an input that shares the same data
"""
"""
def
__init__
(
self
):
def
__init__
(
self
):
...
@@ -88,11 +82,7 @@ class DestroyHandler(Bookkeeper): #(Listener, Constraint, Orderings, Tool):
...
@@ -88,11 +82,7 @@ class DestroyHandler(Bookkeeper): #(Listener, Constraint, Orderings, Tool):
self
.
seen
=
set
()
self
.
seen
=
set
()
Bookkeeper
.
on_attach
(
self
,
env
)
toolbox
.
Bookkeeper
.
on_attach
(
self
,
env
)
# # Initialize the children if the inputs and orphans.
# for input in env.inputs: # env.orphans.union(env.inputs):
# self.children[input] = set()
def
on_detach
(
self
,
env
):
def
on_detach
(
self
,
env
):
del
self
.
parent
del
self
.
parent
...
@@ -105,19 +95,6 @@ class DestroyHandler(Bookkeeper): #(Listener, Constraint, Orderings, Tool):
...
@@ -105,19 +95,6 @@ class DestroyHandler(Bookkeeper): #(Listener, Constraint, Orderings, Tool):
del
self
.
seen
del
self
.
seen
self
.
env
=
None
self
.
env
=
None
# def publish(self):
# """
# Publishes the following on the env:
# - destroyers(r) -> returns all L{Op}s that destroy the result r
# - destroy_handler -> self
# """
# def __destroyers(r):
# ret = self.destroyers.get(r, {})
# ret = ret.keys()
# return ret
# self.env.destroyers = __destroyers
# self.env.destroy_handler = self
def
__path__
(
self
,
r
):
def
__path__
(
self
,
r
):
"""
"""
Returns a path from r to the result that it is ultimately
Returns a path from r to the result that it is ultimately
...
@@ -171,7 +148,7 @@ class DestroyHandler(Bookkeeper): #(Listener, Constraint, Orderings, Tool):
...
@@ -171,7 +148,7 @@ class DestroyHandler(Bookkeeper): #(Listener, Constraint, Orderings, Tool):
def
__pre__
(
self
,
op
):
def
__pre__
(
self
,
op
):
"""
"""
Returns all results that must be computed prior to computing
Returns all results that must be computed prior to computing
this
op
.
this
node
.
"""
"""
rval
=
set
()
rval
=
set
()
if
op
is
None
:
if
op
is
None
:
...
@@ -222,7 +199,7 @@ class DestroyHandler(Bookkeeper): #(Listener, Constraint, Orderings, Tool):
...
@@ -222,7 +199,7 @@ class DestroyHandler(Bookkeeper): #(Listener, Constraint, Orderings, Tool):
users
=
set
(
self
.
__users__
(
start
))
users
=
set
(
self
.
__users__
(
start
))
users
.
add
(
start
)
users
.
add
(
start
)
for
user
in
users
:
for
user
in
users
:
for
cycle
in
copy
(
self
.
cycles
):
for
cycle
in
set
(
self
.
cycles
):
if
user
in
cycle
:
if
user
in
cycle
:
self
.
cycles
.
remove
(
cycle
)
self
.
cycles
.
remove
(
cycle
)
if
just_remove
:
if
just_remove
:
...
@@ -234,7 +211,7 @@ class DestroyHandler(Bookkeeper): #(Listener, Constraint, Orderings, Tool):
...
@@ -234,7 +211,7 @@ class DestroyHandler(Bookkeeper): #(Listener, Constraint, Orderings, Tool):
"""
"""
@return: (vmap, dmap) where:
@return: (vmap, dmap) where:
- vmap -> {output : [inputs output is a view of]}
- vmap -> {output : [inputs output is a view of]}
- dmap -> {output : [inputs that are destroyed by the
Op
- dmap -> {output : [inputs that are destroyed by the
node
(and presumably returned as that output)]}
(and presumably returned as that output)]}
"""
"""
try
:
_vmap
=
node
.
op
.
view_map
try
:
_vmap
=
node
.
op
.
view_map
...
@@ -260,7 +237,7 @@ class DestroyHandler(Bookkeeper): #(Listener, Constraint, Orderings, Tool):
...
@@ -260,7 +237,7 @@ class DestroyHandler(Bookkeeper): #(Listener, Constraint, Orderings, Tool):
def
on_import
(
self
,
env
,
op
):
def
on_import
(
self
,
env
,
op
):
"""
"""
Recomputes the dependencies and search for inconsistencies given
Recomputes the dependencies and search for inconsistencies given
that we just added an
op
to the env.
that we just added an
node
to the env.
"""
"""
self
.
seen
.
add
(
op
)
self
.
seen
.
add
(
op
)
...
@@ -303,7 +280,7 @@ class DestroyHandler(Bookkeeper): #(Listener, Constraint, Orderings, Tool):
...
@@ -303,7 +280,7 @@ class DestroyHandler(Bookkeeper): #(Listener, Constraint, Orderings, Tool):
def
on_prune
(
self
,
env
,
op
):
def
on_prune
(
self
,
env
,
op
):
"""
"""
Recomputes the dependencies and searches for inconsistencies to remove
Recomputes the dependencies and searches for inconsistencies to remove
given that we just removed a
n op
to the env.
given that we just removed a
node
to the env.
"""
"""
view_map
,
destroy_map
=
self
.
get_maps
(
op
)
view_map
,
destroy_map
=
self
.
get_maps
(
op
)
...
@@ -400,7 +377,7 @@ class DestroyHandler(Bookkeeper): #(Listener, Constraint, Orderings, Tool):
...
@@ -400,7 +377,7 @@ class DestroyHandler(Bookkeeper): #(Listener, Constraint, Orderings, Tool):
"""
"""
Recomputes the dependencies and searches for inconsistencies to remove
Recomputes the dependencies and searches for inconsistencies to remove
given that all the clients are moved from r_1 to r_2, clients being
given that all the clients are moved from r_1 to r_2, clients being
a list of (
op, i) pairs such that op
.inputs[i] used to be r_1 and is
a list of (
node, i) pairs such that node
.inputs[i] used to be r_1 and is
now r_2.
now r_2.
"""
"""
path_1
=
self
.
__path__
(
r_1
)
path_1
=
self
.
__path__
(
r_1
)
...
@@ -485,53 +462,6 @@ class DestroyHandler(Bookkeeper): #(Listener, Constraint, Orderings, Tool):
...
@@ -485,53 +462,6 @@ class DestroyHandler(Bookkeeper): #(Listener, Constraint, Orderings, Tool):
# class Destroyer:
# """
# Base class for Ops that destroy one or more of their inputs in an
# inplace operation, use them as temporary storage, puts garbage in
# them or anything else that invalidates the contents for use by other
# Ops.
# Usage of this class in an env requires DestroyHandler.
# """
# def destroyed_inputs(self):
# raise AbstractFunctionError()
# def destroy_map(self):
# """
# Returns the map {output: [list of destroyed inputs]}
# While it typically means that the storage of the output is
# shared with each of the destroyed inputs, it does necessarily
# have to be the case.
# """
# # compatibility
# return {self.out: self.destroyed_inputs()}
# __env_require__ = DestroyHandler
# class Viewer:
# """
# Base class for Ops that return one or more views over one or more inputs,
# which means that the inputs and outputs share their storage. Unless it also
# extends Destroyer, this Op does not modify the storage in any way and thus
# the input is safe for use by other Ops even after executing this one.
# """
# def view_map(self):
# """
# Returns the map {output: [list of viewed inputs]}
# It means that the output shares storage with each of the inputs
# in the list.
# Note: support for more than one viewed input is minimal, but
# this might improve in the future.
# """
# raise AbstractFunctionError()
def
view_roots
(
r
):
def
view_roots
(
r
):
"""
"""
Utility function that returns the leaves of a search through
Utility function that returns the leaves of a search through
...
...
gof/graph.py
浏览文件 @
14090d83
...
@@ -3,20 +3,9 @@ from copy import copy
...
@@ -3,20 +3,9 @@ from copy import copy
from
collections
import
deque
from
collections
import
deque
import
utils
import
utils
from
utils
import
object2
def
deprecated
(
f
):
printme
=
[
True
]
def
g
(
*
args
,
**
kwargs
):
if
printme
[
0
]:
print
'gof.graph.
%
s deprecated: April 29'
%
f
.
__name__
printme
[
0
]
=
False
return
f
(
*
args
,
**
kwargs
)
return
g
class
Apply
(
utils
.
object2
):
class
Apply
(
object2
):
"""
"""
Note: it is illegal for an output element to have an owner != self
Note: it is illegal for an output element to have an owner != self
"""
"""
...
@@ -74,18 +63,13 @@ class Apply(object2):
...
@@ -74,18 +63,13 @@ class Apply(object2):
raise
TypeError
(
"Cannot change the type of this input."
,
curr
,
new
)
raise
TypeError
(
"Cannot change the type of this input."
,
curr
,
new
)
new_node
=
self
.
clone
()
new_node
=
self
.
clone
()
new_node
.
inputs
=
inputs
new_node
.
inputs
=
inputs
# new_node.outputs = []
# for output in self.outputs:
# new_output = copy(output)
# new_output.owner = new_node
# new_node.outputs.append(new_output)
return
new_node
return
new_node
nin
=
property
(
lambda
self
:
len
(
self
.
inputs
))
nin
=
property
(
lambda
self
:
len
(
self
.
inputs
))
nout
=
property
(
lambda
self
:
len
(
self
.
outputs
))
nout
=
property
(
lambda
self
:
len
(
self
.
outputs
))
class
Result
(
object2
):
class
Result
(
utils
.
object2
):
#__slots__ = ['type', 'owner', 'index', 'name']
#__slots__ = ['type', 'owner', 'index', 'name']
def
__init__
(
self
,
type
,
owner
=
None
,
index
=
None
,
name
=
None
):
def
__init__
(
self
,
type
,
owner
=
None
,
index
=
None
,
name
=
None
):
self
.
type
=
type
self
.
type
=
type
...
@@ -111,9 +95,6 @@ class Result(object2):
...
@@ -111,9 +95,6 @@ class Result(object2):
return
"<?>::"
+
str
(
self
.
type
)
return
"<?>::"
+
str
(
self
.
type
)
def
__repr__
(
self
):
def
__repr__
(
self
):
return
str
(
self
)
return
str
(
self
)
@deprecated
def
__asresult__
(
self
):
return
self
def
clone
(
self
):
def
clone
(
self
):
return
self
.
__class__
(
self
.
type
,
None
,
None
,
self
.
name
)
return
self
.
__class__
(
self
.
type
,
None
,
None
,
self
.
name
)
...
@@ -137,7 +118,6 @@ class Constant(Value):
...
@@ -137,7 +118,6 @@ class Constant(Value):
#__slots__ = ['data']
#__slots__ = ['data']
def
__init__
(
self
,
type
,
data
,
name
=
None
):
def
__init__
(
self
,
type
,
data
,
name
=
None
):
Value
.
__init__
(
self
,
type
,
data
,
name
)
Value
.
__init__
(
self
,
type
,
data
,
name
)
### self.indestructible = True
def
equals
(
self
,
other
):
def
equals
(
self
,
other
):
# this does what __eq__ should do, but Result and Apply should always be hashable by id
# this does what __eq__ should do, but Result and Apply should always be hashable by id
return
type
(
other
)
==
type
(
self
)
and
self
.
signature
()
==
other
.
signature
()
return
type
(
other
)
==
type
(
self
)
and
self
.
signature
()
==
other
.
signature
()
...
@@ -148,32 +128,6 @@ class Constant(Value):
...
@@ -148,32 +128,6 @@ class Constant(Value):
return
self
.
name
return
self
.
name
return
str
(
self
.
data
)
#+ "::" + str(self.type)
return
str
(
self
.
data
)
#+ "::" + str(self.type)
@deprecated
def
as_result
(
x
):
if
isinstance
(
x
,
Result
):
return
x
# elif isinstance(x, Type):
# return Result(x, None, None)
elif
hasattr
(
x
,
'__asresult__'
):
r
=
x
.
__asresult__
()
if
not
isinstance
(
r
,
Result
):
raise
TypeError
(
"
%
s.__asresult__ must return a Result instance"
%
x
,
(
x
,
r
))
return
r
else
:
raise
TypeError
(
"Cannot wrap
%
s in a Result"
%
x
)
@deprecated
def
as_apply
(
x
):
if
isinstance
(
x
,
Apply
):
return
x
elif
hasattr
(
x
,
'__asapply__'
):
node
=
x
.
__asapply__
()
if
not
isinstance
(
node
,
Apply
):
raise
TypeError
(
"
%
s.__asapply__ must return an Apply instance"
%
x
,
(
x
,
node
))
return
node
else
:
raise
TypeError
(
"Cannot map
%
s to Apply"
%
x
)
def
stack_search
(
start
,
expand
,
mode
=
'bfs'
,
build_inv
=
False
):
def
stack_search
(
start
,
expand
,
mode
=
'bfs'
,
build_inv
=
False
):
"""Search through L{Result}s, either breadth- or depth-first
"""Search through L{Result}s, either breadth- or depth-first
@type start: deque
@type start: deque
...
@@ -234,45 +188,6 @@ def inputs(result_list):
...
@@ -234,45 +188,6 @@ def inputs(result_list):
return
rval
return
rval
# def results_and_orphans(r_in, r_out, except_unreachable_input=False):
# r_in_set = set(r_in)
# class Dummy(object): pass
# dummy = Dummy()
# dummy.inputs = r_out
# def expand_inputs(io):
# if io in r_in_set:
# return None
# try:
# return [io.owner] if io.owner != None else None
# except AttributeError:
# return io.inputs
# ops_and_results, dfsinv = stack_search(
# deque([dummy]),
# expand_inputs, 'dfs', True)
# if except_unreachable_input:
# for r in r_in:
# if r not in dfsinv:
# raise Exception(results_and_orphans.E_unreached)
# clients = stack_search(
# deque(r_in),
# lambda io: dfsinv.get(io,None), 'dfs')
# ops_to_compute = [o for o in clients if is_op(o) and o is not dummy]
# results = []
# for o in ops_to_compute:
# results.extend(o.inputs)
# results.extend(r_out)
# op_set = set(ops_to_compute)
# assert len(ops_to_compute) == len(op_set)
# orphans = [r for r in results \
# if (r.owner not in op_set) and (r not in r_in_set)]
# return results, orphans
# results_and_orphans.E_unreached = 'there were unreachable inputs'
def
results_and_orphans
(
i
,
o
):
def
results_and_orphans
(
i
,
o
):
"""
"""
"""
"""
...
@@ -286,24 +201,6 @@ def results_and_orphans(i, o):
...
@@ -286,24 +201,6 @@ def results_and_orphans(i, o):
return
results
,
orphans
return
results
,
orphans
#def results_and_orphans(i, o):
# results = set()
# orphans = set()
# def helper(r):
# if r in results:
# return
# results.add(r)
# if r.owner is None:
# if r not in i:
# orphans.add(r)
# else:
# for r2 in r.owner.inputs:
# helper(r2)
# for output in o:
# helper(output)
# return results, orphans
def
ops
(
i
,
o
):
def
ops
(
i
,
o
):
"""
"""
@type i: list
@type i: list
...
@@ -569,122 +466,3 @@ def as_string(i, o,
...
@@ -569,122 +466,3 @@ def as_string(i, o,
return
[
describe
(
output
)
for
output
in
o
]
return
[
describe
(
output
)
for
output
in
o
]
# class Graph:
# """
# Object-oriented wrapper for all the functions in this module.
# """
# def __init__(self, inputs, outputs):
# self.inputs = inputs
# self.outputs = outputs
# def ops(self):
# return ops(self.inputs, self.outputs)
# def values(self):
# return values(self.inputs, self.outputs)
# def orphans(self):
# return orphans(self.inputs, self.outputs)
# def io_toposort(self):
# return io_toposort(self.inputs, self.outputs)
# def toposort(self):
# return self.io_toposort()
# def clone(self):
# o = clone(self.inputs, self.outputs)
# return Graph(self.inputs, o)
# def __str__(self):
# return as_string(self.inputs, self.outputs)
if
0
:
#these were the old implementations
# they were replaced out of a desire that graph search routines would not
# depend on the hash or id of any node, so that it would be deterministic
# and consistent between program executions.
@utils.deprecated
(
'gof.graph'
,
'preserving only for review'
)
def
_results_and_orphans
(
i
,
o
,
except_unreachable_input
=
False
):
"""
@type i: list
@param i: input L{Result}s
@type o: list
@param o: output L{Result}s
Returns the pair (results, orphans). The former is the set of
L{Result}s that are involved in the subgraph that lies between i and
o. This includes i, o, orphans(i, o) and all results of all
intermediary steps from i to o. The second element of the returned
pair is orphans(i, o).
"""
results
=
set
()
i
=
set
(
i
)
results
.
update
(
i
)
incomplete_paths
=
[]
reached
=
set
()
def
helper
(
r
,
path
):
if
r
in
i
:
reached
.
add
(
r
)
results
.
update
(
path
)
elif
r
.
owner
is
None
:
incomplete_paths
.
append
(
path
)
else
:
op
=
r
.
owner
for
r2
in
op
.
inputs
:
helper
(
r2
,
path
+
[
r2
])
for
output
in
o
:
helper
(
output
,
[
output
])
orphans
=
set
()
for
path
in
incomplete_paths
:
for
r
in
path
:
if
r
not
in
results
:
orphans
.
add
(
r
)
break
if
except_unreachable_input
and
len
(
i
)
!=
len
(
reached
):
raise
Exception
(
results_and_orphans
.
E_unreached
)
results
.
update
(
orphans
)
return
results
,
orphans
def
_io_toposort
(
i
,
o
,
orderings
=
{}):
"""
@type i: list
@param i: input L{Result}s
@type o: list
@param o: output L{Result}s
@param orderings: {op: [requirements for op]} (defaults to {})
@rtype: ordered list
@return: L{Op}s that belong in the subgraph between i and o which
respects the following constraints:
- all inputs in i are assumed to be already computed
- the L{Op}s that compute an L{Op}'s inputs must be computed before it
- the orderings specified in the optional orderings parameter must be satisfied
Note that this function does not take into account ordering information
related to destructive operations or other special behavior.
"""
prereqs_d
=
copy
(
orderings
)
all
=
ops
(
i
,
o
)
for
op
in
all
:
asdf
=
set
([
input
.
owner
for
input
in
op
.
inputs
if
input
.
owner
and
input
.
owner
in
all
])
prereqs_d
.
setdefault
(
op
,
set
())
.
update
(
asdf
)
return
utils
.
toposort
(
prereqs_d
)
gof/link.py
浏览文件 @
14090d83
from
utils
import
AbstractFunctionError
import
utils
import
utils
import
graph
from
graph
import
Value
import
sys
,
traceback
import
sys
import
traceback
__excepthook
=
sys
.
excepthook
__excepthook
=
sys
.
excepthook
...
@@ -67,7 +64,7 @@ class Linker:
...
@@ -67,7 +64,7 @@ class Linker:
print new_e.data # 3.0
print new_e.data # 3.0
print e.data # 3.0 iff inplace == True (else unknown)
print e.data # 3.0 iff inplace == True (else unknown)
"""
"""
raise
AbstractFunctionError
()
raise
utils
.
AbstractFunctionError
()
def
make_function
(
self
,
unpack_single
=
True
,
**
kwargs
):
def
make_function
(
self
,
unpack_single
=
True
,
**
kwargs
):
"""
"""
...
@@ -151,7 +148,7 @@ def map_storage(env, order, input_storage, output_storage):
...
@@ -151,7 +148,7 @@ def map_storage(env, order, input_storage, output_storage):
for
node
in
order
:
for
node
in
order
:
for
r
in
node
.
inputs
:
for
r
in
node
.
inputs
:
if
r
not
in
storage_map
:
if
r
not
in
storage_map
:
assert
isinstance
(
r
,
Value
)
assert
isinstance
(
r
,
graph
.
Value
)
storage_map
[
r
]
=
[
r
.
data
]
storage_map
[
r
]
=
[
r
.
data
]
for
r
in
node
.
outputs
:
for
r
in
node
.
outputs
:
storage_map
.
setdefault
(
r
,
[
None
])
storage_map
.
setdefault
(
r
,
[
None
])
...
...
gof/op.py
浏览文件 @
14090d83
"""
"""
Contains the L{Op} class, which is the base interface for all operations
Contains the L{Op} class, which is the base interface for all operations
compatible with gof's graph manipulation routines.
compatible with gof's graph manipulation routines.
"""
"""
import
utils
import
utils
from
utils
import
AbstractFunctionError
,
object2
from
copy
import
copy
class
Op
(
object2
):
class
Op
(
utils
.
object2
):
default_output
=
None
default_output
=
None
"""@todo
"""@todo
...
@@ -22,9 +17,19 @@ class Op(object2):
...
@@ -22,9 +17,19 @@ class Op(object2):
#############
#############
def
make_node
(
self
,
*
inputs
):
def
make_node
(
self
,
*
inputs
):
raise
AbstractFunctionError
()
"""
This function should return an Apply instance representing the
application of this Op on the provided inputs.
"""
raise
utils
.
AbstractFunctionError
()
def
__call__
(
self
,
*
inputs
):
def
__call__
(
self
,
*
inputs
):
"""
Shortcut for:
self.make_node(*inputs).outputs[self.default_output] (if default_output is defined)
self.make_node(*inputs).outputs[0] (if only one output)
self.make_node(*inputs).outputs (if more than one output)
"""
node
=
self
.
make_node
(
*
inputs
)
node
=
self
.
make_node
(
*
inputs
)
if
self
.
default_output
is
not
None
:
if
self
.
default_output
is
not
None
:
return
node
.
outputs
[
self
.
default_output
]
return
node
.
outputs
[
self
.
default_output
]
...
@@ -44,6 +49,7 @@ class Op(object2):
...
@@ -44,6 +49,7 @@ class Op(object2):
Calculate the function on the inputs and put the results in the
Calculate the function on the inputs and put the results in the
output storage.
output storage.
- node: Apply instance that contains the symbolic inputs and outputs
- inputs: sequence of inputs (immutable)
- inputs: sequence of inputs (immutable)
- output_storage: list of mutable 1-element lists (do not change
- output_storage: list of mutable 1-element lists (do not change
the length of these lists)
the length of these lists)
...
@@ -53,7 +59,7 @@ class Op(object2):
...
@@ -53,7 +59,7 @@ class Op(object2):
by a previous call to impl and impl is free to reuse it as it
by a previous call to impl and impl is free to reuse it as it
sees fit.
sees fit.
"""
"""
raise
AbstractFunctionError
()
raise
utils
.
AbstractFunctionError
()
#####################
#####################
# C code generation #
# C code generation #
...
@@ -62,9 +68,8 @@ class Op(object2):
...
@@ -62,9 +68,8 @@ class Op(object2):
def
c_code
(
self
,
node
,
name
,
inputs
,
outputs
,
sub
):
def
c_code
(
self
,
node
,
name
,
inputs
,
outputs
,
sub
):
"""Return the C implementation of an Op.
"""Return the C implementation of an Op.
Returns templated C code that does the computation associated
Returns C code that does the computation associated to this L{Op},
to this L{Op}. You may assume that input validation and output
given names for the inputs and outputs.
allocation have already been done.
@param inputs: list of strings. There is a string for each input
@param inputs: list of strings. There is a string for each input
of the function, and the string is the name of a C
of the function, and the string is the name of a C
...
@@ -80,7 +85,7 @@ class Op(object2):
...
@@ -80,7 +85,7 @@ class Op(object2):
'fail').
'fail').
"""
"""
raise
AbstractFunctionError
(
'
%
s.c_code
'
\
raise
utils
.
AbstractFunctionError
(
'
%
s.c_code is not defined
'
\
%
self
.
__class__
.
__name__
)
%
self
.
__class__
.
__name__
)
def
c_code_cleanup
(
self
,
node
,
name
,
inputs
,
outputs
,
sub
):
def
c_code_cleanup
(
self
,
node
,
name
,
inputs
,
outputs
,
sub
):
...
@@ -89,44 +94,33 @@ class Op(object2):
...
@@ -89,44 +94,33 @@ class Op(object2):
This is a convenient place to clean up things allocated by c_code().
This is a convenient place to clean up things allocated by c_code().
"""
"""
raise
AbstractFunctionError
()
raise
utils
.
AbstractFunctionError
()
def
c_compile_args
(
self
):
def
c_compile_args
(
self
):
"""
"""
Return a list of compile args recommended to manipulate this L{Op}.
Return a list of compile args recommended to manipulate this L{Op}.
"""
"""
raise
AbstractFunctionError
()
raise
utils
.
AbstractFunctionError
()
def
c_headers
(
self
):
def
c_headers
(
self
):
"""
"""
Return a list of header files that must be included from C to manipulate
Return a list of header files that must be included from C to manipulate
this L{Op}.
this L{Op}.
"""
"""
raise
AbstractFunctionError
()
raise
utils
.
AbstractFunctionError
()
def
c_libraries
(
self
):
def
c_libraries
(
self
):
"""
"""
Return a list of libraries to link against to manipulate this L{Op}.
Return a list of libraries to link against to manipulate this L{Op}.
"""
"""
raise
AbstractFunctionError
()
raise
utils
.
AbstractFunctionError
()
def
c_support_code
(
self
):
def
c_support_code
(
self
):
"""
"""
Return utility code for use by this L{Op}. It may refer to support code
Return utility code for use by this L{Op}. It may refer to support code
defined for its input L{Result}s.
defined for its input L{Result}s.
"""
"""
raise
AbstractFunctionError
()
raise
utils
.
AbstractFunctionError
()
class
PropertiedOp
(
Op
):
def
__eq__
(
self
,
other
):
return
type
(
self
)
==
type
(
other
)
and
self
.
__dict__
==
other
.
__dict__
def
__str__
(
self
):
if
hasattr
(
self
,
'name'
)
and
self
.
name
:
return
self
.
name
else
:
return
"
%
s{
%
s}"
%
(
self
.
__class__
.
__name__
,
", "
.
join
(
"
%
s=
%
s"
%
(
k
,
v
)
for
k
,
v
in
self
.
__dict__
.
items
()
if
k
!=
"name"
))
gof/opt.py
浏览文件 @
14090d83
"""
Defines the base class for optimizations as well as a certain
amount of useful generic optimization tools.
"""
from
op
import
Op
from
graph
import
Constant
import
graph
from
type
import
Type
from
env
import
InconsistencyError
from
env
import
InconsistencyError
import
utils
import
utils
import
unify
import
unify
import
toolbox
import
toolbox
import
ext
class
Optimizer
:
class
Optimizer
:
...
@@ -20,9 +22,8 @@ class Optimizer:
...
@@ -20,9 +22,8 @@ class Optimizer:
"""
"""
Applies the optimization to the provided L{Env}. It may use all
Applies the optimization to the provided L{Env}. It may use all
the methods defined by the L{Env}. If the L{Optimizer} needs
the methods defined by the L{Env}. If the L{Optimizer} needs
to use a certain tool, such as an L{InstanceFinder}, it should
to use a certain tool, such as an L{InstanceFinder}, it can do
set the L{__env_require__} field to a list of what needs to be
so in its L{add_requirements} method.
registered with the L{Env}.
"""
"""
pass
pass
...
@@ -36,9 +37,19 @@ class Optimizer:
...
@@ -36,9 +37,19 @@ class Optimizer:
self
.
apply
(
env
)
self
.
apply
(
env
)
def
__call__
(
self
,
env
):
def
__call__
(
self
,
env
):
"""
Same as self.optimize(env)
"""
return
self
.
optimize
(
env
)
return
self
.
optimize
(
env
)
def
add_requirements
(
self
,
env
):
def
add_requirements
(
self
,
env
):
"""
Add features to the env that are required to apply the optimization.
For example:
env.extend(History())
env.extend(MyFeature())
etc.
"""
pass
pass
...
@@ -79,7 +90,7 @@ class LocalOptimizer(Optimizer):
...
@@ -79,7 +90,7 @@ class LocalOptimizer(Optimizer):
following two methods:
following two methods:
- candidates(env) -> returns a set of ops that can be
- candidates(env) -> returns a set of ops that can be
optimized
optimized
- apply_on_
op(env, op) -> for each op
in candidates,
- apply_on_
node(env, node) -> for each node
in candidates,
this function will be called to perform the actual
this function will be called to perform the actual
optimization.
optimization.
"""
"""
...
@@ -102,7 +113,7 @@ class LocalOptimizer(Optimizer):
...
@@ -102,7 +113,7 @@ class LocalOptimizer(Optimizer):
Calls self.apply_on_op(env, op) for each op in self.candidates(env).
Calls self.apply_on_op(env, op) for each op in self.candidates(env).
"""
"""
for
node
in
self
.
candidates
(
env
):
for
node
in
self
.
candidates
(
env
):
if
env
.
has_node
(
node
)
:
if
node
in
env
.
nodes
:
self
.
apply_on_node
(
env
,
node
)
self
.
apply_on_node
(
env
,
node
)
...
@@ -122,7 +133,7 @@ class OpSpecificOptimizer(LocalOptimizer):
...
@@ -122,7 +133,7 @@ class OpSpecificOptimizer(LocalOptimizer):
def
candidates
(
self
,
env
):
def
candidates
(
self
,
env
):
"""
"""
Returns all
instances of L{self.op}
.
Returns all
nodes that have L{self.op} in their op field
.
"""
"""
return
env
.
get_nodes
(
self
.
op
)
return
env
.
get_nodes
(
self
.
op
)
...
@@ -131,13 +142,22 @@ class OpSpecificOptimizer(LocalOptimizer):
...
@@ -131,13 +142,22 @@ class OpSpecificOptimizer(LocalOptimizer):
class
OpSubOptimizer
(
Optimizer
):
class
OpSubOptimizer
(
Optimizer
):
"""
"""
Replaces all
L{Op}s of a certain type by L{Op}s of another type that
Replaces all
applications of a certain op by the application of
take the same inputs as what they are replacing.
another op that
take the same inputs as what they are replacing.
e.g. OpSubOptimizer(add, sub) ==> add(div(x, y), add(y, x)) -> sub(div(x, y), sub(y, x))
e.g. OpSubOptimizer(add, sub) ==> add(div(x, y), add(y, x)) -> sub(div(x, y), sub(y, x))
OpSubOptimizer requires the following features:
- NodeFinder
- ReplaceValidate
"""
"""
def
add_requirements
(
self
,
env
):
def
add_requirements
(
self
,
env
):
"""
Requires the following features:
- NodeFinder
- ReplaceValidate
"""
try
:
try
:
env
.
extend
(
toolbox
.
NodeFinder
())
env
.
extend
(
toolbox
.
NodeFinder
())
env
.
extend
(
toolbox
.
ReplaceValidate
())
env
.
extend
(
toolbox
.
ReplaceValidate
())
...
@@ -145,9 +165,12 @@ class OpSubOptimizer(Optimizer):
...
@@ -145,9 +165,12 @@ class OpSubOptimizer(Optimizer):
def
__init__
(
self
,
op1
,
op2
,
failure_callback
=
None
):
def
__init__
(
self
,
op1
,
op2
,
failure_callback
=
None
):
"""
"""
op1 and op2 must both be Op subclasses, they must both take
op1.make_node and op2.make_node must take the same number of
the same number of inputs and they must both have the same
inputs and have the same number of outputs.
number of outputs.
If failure_callback is not None, it will be called whenever
the Optimizer fails to do a replacement in the graph. The
arguments to the callback are: (node, replacement, exception)
"""
"""
self
.
op1
=
op1
self
.
op1
=
op1
self
.
op2
=
op2
self
.
op2
=
op2
...
@@ -155,12 +178,8 @@ class OpSubOptimizer(Optimizer):
...
@@ -155,12 +178,8 @@ class OpSubOptimizer(Optimizer):
def
apply
(
self
,
env
):
def
apply
(
self
,
env
):
"""
"""
Replaces all
occurrences of self.op1 by instance
s of self.op2
Replaces all
applications of self.op1 by application
s of self.op2
with the same inputs.
with the same inputs.
If failure_callback is not None, it will be called whenever
the Optimizer fails to do a replacement in the graph. The
arguments to the callback are: (op1_instance, replacement, exception)
"""
"""
candidates
=
env
.
get_nodes
(
self
.
op1
)
candidates
=
env
.
get_nodes
(
self
.
op1
)
...
@@ -173,7 +192,6 @@ class OpSubOptimizer(Optimizer):
...
@@ -173,7 +192,6 @@ class OpSubOptimizer(Optimizer):
except
Exception
,
e
:
except
Exception
,
e
:
if
self
.
failure_callback
is
not
None
:
if
self
.
failure_callback
is
not
None
:
self
.
failure_callback
(
node
,
repl
,
e
)
self
.
failure_callback
(
node
,
repl
,
e
)
pass
def
str
(
self
):
def
str
(
self
):
return
"
%
s ->
%
s"
%
(
self
.
op1
,
self
.
op2
)
return
"
%
s ->
%
s"
%
(
self
.
op1
,
self
.
op2
)
...
@@ -183,7 +201,7 @@ class OpSubOptimizer(Optimizer):
...
@@ -183,7 +201,7 @@ class OpSubOptimizer(Optimizer):
class
OpRemover
(
Optimizer
):
class
OpRemover
(
Optimizer
):
"""
"""
@todo untested
@todo untested
Removes all
ops of a certain type
by transferring each of its
Removes all
applications of an op
by transferring each of its
outputs to the corresponding input.
outputs to the corresponding input.
"""
"""
...
@@ -195,21 +213,19 @@ class OpRemover(Optimizer):
...
@@ -195,21 +213,19 @@ class OpRemover(Optimizer):
def
__init__
(
self
,
op
,
failure_callback
=
None
):
def
__init__
(
self
,
op
,
failure_callback
=
None
):
"""
"""
opclass is the class of the ops to remove. It must take as
Applications of the op must have as many inputs as outputs.
many inputs as outputs.
If failure_callback is not None, it will be called whenever
the Optimizer fails to remove an operation in the graph. The
arguments to the callback are: (node, exception)
"""
"""
self
.
op
=
op
self
.
op
=
op
self
.
failure_callback
=
failure_callback
self
.
failure_callback
=
failure_callback
def
apply
(
self
,
env
):
def
apply
(
self
,
env
):
"""
"""
Removes all occurrences of self.opclass.
Removes all applications of self.op.
If self.failure_callback is not None, it will be called whenever
the Optimizer fails to remove an operation in the graph. The
arguments to the callback are: (opclass_instance, exception)
"""
"""
candidates
=
env
.
get_nodes
(
self
.
op
)
candidates
=
env
.
get_nodes
(
self
.
op
)
for
node
in
candidates
:
for
node
in
candidates
:
...
@@ -231,17 +247,17 @@ class PatternOptimizer(OpSpecificOptimizer):
...
@@ -231,17 +247,17 @@ class PatternOptimizer(OpSpecificOptimizer):
"""
"""
@todo update
@todo update
Replaces all occurrences of the input pattern by the output pattern:
:
Replaces all occurrences of the input pattern by the output pattern:
input_pattern ::= (
OpClass
, <sub_pattern1>, <sub_pattern2>, ...)
input_pattern ::= (
op
, <sub_pattern1>, <sub_pattern2>, ...)
input_pattern ::= dict(pattern = <input_pattern>,
input_pattern ::= dict(pattern = <input_pattern>,
constraint = <constraint>)
constraint = <constraint>)
sub_pattern ::= input_pattern
sub_pattern ::= input_pattern
sub_pattern ::= string
sub_pattern ::= string
sub_pattern ::= a
Result r such that r.constant is Tru
e
sub_pattern ::= a
Constant instanc
e
constraint ::= lambda env, expr: additional matching condition
constraint ::= lambda env, expr: additional matching condition
output_pattern ::= (
OpClass
, <output_pattern1>, <output_pattern2>, ...)
output_pattern ::= (
op
, <output_pattern1>, <output_pattern2>, ...)
output_pattern ::= string
output_pattern ::= string
Each string in the input pattern is a variable that will be set to
Each string in the input pattern is a variable that will be set to
...
@@ -253,8 +269,8 @@ class PatternOptimizer(OpSpecificOptimizer):
...
@@ -253,8 +269,8 @@ class PatternOptimizer(OpSpecificOptimizer):
pattern can.
pattern can.
If you put a constant result in the input pattern, there will be a
If you put a constant result in the input pattern, there will be a
match iff a constant result with the same value
is found in its
match iff a constant result with the same value
and the same type
place.
is found in its
place.
You can add a constraint to the match by using the dict(...) form
You can add a constraint to the match by using the dict(...) form
described above with a 'constraint' key. The constraint must be a
described above with a 'constraint' key. The constraint must be a
...
@@ -263,16 +279,27 @@ class PatternOptimizer(OpSpecificOptimizer):
...
@@ -263,16 +279,27 @@ class PatternOptimizer(OpSpecificOptimizer):
arbitrary criterion.
arbitrary criterion.
Examples:
Examples:
PatternOptimizer((
Add, 'x', 'y'), (A
dd, 'y', 'x'))
PatternOptimizer((
add, 'x', 'y'), (a
dd, 'y', 'x'))
PatternOptimizer((
Multiply, 'x', 'x'), (S
quare, 'x'))
PatternOptimizer((
multiply, 'x', 'x'), (s
quare, 'x'))
PatternOptimizer((
Subtract, (A
dd, 'x', 'y'), 'y'), 'x')
PatternOptimizer((
subtract, (a
dd, 'x', 'y'), 'y'), 'x')
PatternOptimizer((
Power, 'x', Double(2.0, constant = True)), (S
quare, 'x'))
PatternOptimizer((
power, 'x', Constant(double, 2.0)), (s
quare, 'x'))
PatternOptimizer((
B
oggle, {'pattern': 'x',
PatternOptimizer((
b
oggle, {'pattern': 'x',
'constraint': lambda env, expr: expr.
owner.scrabble == Tru
e}),
'constraint': lambda env, expr: expr.
type == scrabbl
e}),
(
S
crabble, 'x'))
(
s
crabble, 'x'))
"""
"""
def
__init__
(
self
,
in_pattern
,
out_pattern
,
allow_multiple_clients
=
False
,
failure_callback
=
None
):
def
__init__
(
self
,
in_pattern
,
out_pattern
,
allow_multiple_clients
=
False
,
failure_callback
=
None
):
"""
Creates a PatternOptimizer that replaces occurrences of
in_pattern by occurrences of out_pattern.
If failure_callback is not None, if there is a match but a
replacement fails to occur, the callback will be called with
arguments (result_to_replace, replacement, exception).
If allow_multiple_clients is False, he pattern matching will
fail if one of the subpatterns has more than one client.
"""
self
.
in_pattern
=
in_pattern
self
.
in_pattern
=
in_pattern
self
.
out_pattern
=
out_pattern
self
.
out_pattern
=
out_pattern
if
isinstance
(
in_pattern
,
(
list
,
tuple
)):
if
isinstance
(
in_pattern
,
(
list
,
tuple
)):
...
@@ -287,15 +314,8 @@ class PatternOptimizer(OpSpecificOptimizer):
...
@@ -287,15 +314,8 @@ class PatternOptimizer(OpSpecificOptimizer):
def
apply_on_node
(
self
,
env
,
node
):
def
apply_on_node
(
self
,
env
,
node
):
"""
"""
Checks if the graph from
op
corresponds to in_pattern. If it does,
Checks if the graph from
node
corresponds to in_pattern. If it does,
constructs out_pattern and performs the replacement.
constructs out_pattern and performs the replacement.
If self.failure_callback is not None, if there is a match but a
replacement fails to occur, the callback will be called with
arguments (results_to_replace, replacement, exception).
If self.allow_multiple_clients is False, he pattern matching will fail
if one of the subpatterns has more than one client.
"""
"""
def
match
(
pattern
,
expr
,
u
,
first
=
False
):
def
match
(
pattern
,
expr
,
u
,
first
=
False
):
if
isinstance
(
pattern
,
(
list
,
tuple
)):
if
isinstance
(
pattern
,
(
list
,
tuple
)):
...
@@ -323,7 +343,7 @@ class PatternOptimizer(OpSpecificOptimizer):
...
@@ -323,7 +343,7 @@ class PatternOptimizer(OpSpecificOptimizer):
return
False
return
False
else
:
else
:
u
=
u
.
merge
(
expr
,
v
)
u
=
u
.
merge
(
expr
,
v
)
elif
isinstance
(
pattern
,
Constant
)
and
isinstance
(
expr
,
Constant
)
and
pattern
.
equals
(
expr
):
elif
isinstance
(
pattern
,
graph
.
Constant
)
and
isinstance
(
expr
,
graph
.
Constant
)
and
pattern
.
equals
(
expr
):
return
u
return
u
else
:
else
:
return
False
return
False
...
@@ -363,28 +383,6 @@ class PatternOptimizer(OpSpecificOptimizer):
...
@@ -363,28 +383,6 @@ class PatternOptimizer(OpSpecificOptimizer):
# class ConstantFinder(Optimizer):
# """
# Sets as constant every orphan that is not destroyed.
# """
# def apply(self, env):
# if env.has_feature(ext.DestroyHandler(env)):
# for r in env.orphans():
# if not env.destroyers(r):
# r.indestructible = True
# r.constant = True
# # for r in env.inputs:
# # if not env.destroyers(r):
# # r.indestructible = True
# else:
# for r in env.orphans():
# r.indestructible = True
# r.constant = True
# # for r in env.inputs:
# # r.indestructible = True
import
graph
class
_metadict
:
class
_metadict
:
# dict that accepts unhashable keys
# dict that accepts unhashable keys
...
@@ -438,15 +436,14 @@ class MergeOptimizer(Optimizer):
...
@@ -438,15 +436,14 @@ class MergeOptimizer(Optimizer):
def
apply
(
self
,
env
):
def
apply
(
self
,
env
):
cid
=
_metadict
()
#result -> result.desc() (for constants)
cid
=
_metadict
()
#result -> result.desc() (for constants)
inv_cid
=
_metadict
()
#desc -> result (for constants)
inv_cid
=
_metadict
()
#desc -> result (for constants)
for
i
,
r
in
enumerate
([
r
for
r
in
env
.
results
if
isinstance
(
r
,
Constant
)]):
#env.orphans.union(env.inputs)):
for
i
,
r
in
enumerate
([
r
for
r
in
env
.
results
if
isinstance
(
r
,
graph
.
Constant
)]):
#if isinstance(r, Constant):
sig
=
r
.
signature
()
sig
=
r
.
signature
()
other_r
=
inv_cid
.
get
(
sig
,
None
)
other_r
=
inv_cid
.
get
(
sig
,
None
)
if
other_r
is
not
None
:
if
other_r
is
not
None
:
env
.
replace
(
r
,
other_r
)
env
.
replace
(
r
,
other_r
)
else
:
else
:
cid
[
r
]
=
sig
cid
[
r
]
=
sig
inv_cid
[
sig
]
=
r
inv_cid
[
sig
]
=
r
# we clear the dicts because the Constants signatures are not necessarily hashable
# we clear the dicts because the Constants signatures are not necessarily hashable
# and it's more efficient to give them an integer cid like the other Results
# and it's more efficient to give them an integer cid like the other Results
cid
.
clear
()
cid
.
clear
()
...
@@ -483,123 +480,3 @@ def MergeOptMerge(opt):
...
@@ -483,123 +480,3 @@ def MergeOptMerge(opt):
merger
=
MergeOptimizer
()
merger
=
MergeOptimizer
()
return
SeqOptimizer
([
merger
,
opt
,
merger
])
return
SeqOptimizer
([
merger
,
opt
,
merger
])
### THE FOLLOWING OPTIMIZERS ARE NEITHER USED NOR TESTED BUT PROBABLY WORK AND COULD BE USEFUL ###
# class MultiOptimizer(Optimizer):
# def __init__(self, **opts):
# self._opts = []
# self.ord = {}
# self.name_to_opt = {}
# self.up_to_date = True
# for name, opt in opts:
# self.register(name, opt, after = [], before = [])
# def register(self, name, opt, **relative):
# self.name_to_opt[name] = opt
# after = relative.get('after', [])
# if not isinstance(after, (list, tuple)):
# after = [after]
# before = relative.get('before', [])
# if not isinstance(before, (list, tuple)):
# before = [before]
# self.up_to_date = False
# if name in self.ord:
# raise Exception("Cannot redefine optimization: '%s'" % name)
# self.ord[name] = set(after)
# for postreq in before:
# self.ord.setdefault(postreq, set()).add(name)
# def get_opts(self):
# if not self.up_to_date:
# self.refresh()
# return self._opts
# def refresh(self):
# self._opts = [self.name_to_opt[name] for name in utils.toposort(self.ord)]
# self.up_to_date = True
# def apply(self, env):
# for opt in self.opts:
# opt.apply(env)
# opts = property(get_opts)
# class TaggedMultiOptimizer(MultiOptimizer):
# def __init__(self, **opts):
# self.tags = {}
# MultiOptimizer.__init__(self, **opts)
# def register(self, name, opt, tags = [], **relative):
# tags = set(tags)
# tags.add(name)
# self.tags[opt] = tags
# MultiOptimizer.register(self, name, opt, **relative)
# def filter(self, whitelist, blacklist):
# return [opt for opt in self.opts
# if self.tags[opt].intersection(whitelist)
# and not self.tags[opt].intersection(blacklist)]
# def whitelist(self, *tags):
# return [opt for opt in self.opts if self.tags[opt].intersection(tags)]
# def blacklist(self, *tags):
# return [opt for opt in self.opts if not self.tags[opt].intersection(tags)]
# class TagFilterMultiOptimizer(Optimizer):
# def __init__(self, all, whitelist = None, blacklist = None):
# self.all = all
# if whitelist is not None:
# self.whitelist = set(whitelist)
# else:
# self.whitelist = None
# if blacklist is not None:
# self.blacklist = set(blacklist)
# else:
# self.blacklist = set()
# def use_whitelist(self, use = True):
# if self.whitelist is None and use:
# self.whitelist = set()
# def allow(self, *tags):
# if self.whitelist is not None:
# self.whitelist.update(tags)
# self.blacklist.difference_update(tags)
# def deny(self, *tags):
# if self.whitelist is not None:
# self.whitelist.difference_update(tags)
# self.blacklist.update(tags)
# def dont_care(self, *tags):
# if self.whitelist is not None:
# self.whitelist.difference_update(tags)
# self.blacklist.difference_update(tags)
# def opts(self):
# if self.whitelist is not None:
# return self.all.filter(self.whitelist, self.blacklist)
# else:
# return self.all.blacklist(*[tag for tag in self.blacklist])
# def apply(self, env):
# for opt in self.opts():
# opt.apply(env)
gof/toolbox.py
浏览文件 @
14090d83
from
random
import
shuffle
import
utils
from
functools
import
partial
from
functools
import
partial
import
graph
import
graph
...
@@ -14,51 +12,6 @@ class Bookkeeper:
...
@@ -14,51 +12,6 @@ class Bookkeeper:
def
on_detach
(
self
,
env
):
def
on_detach
(
self
,
env
):
for
node
in
graph
.
io_toposort
(
env
.
inputs
,
env
.
outputs
):
for
node
in
graph
.
io_toposort
(
env
.
inputs
,
env
.
outputs
):
self
.
on_prune
(
env
,
node
)
self
.
on_prune
(
env
,
node
)
# class Toposorter:
# def on_attach(self, env):
# if hasattr(env, 'toposort'):
# raise Exception("Toposorter feature is already present or in conflict with another plugin.")
# env.toposort = partial(self.toposort, env)
# def on_detach(self, env):
# del env.toposort
# def toposort(self, env):
# ords = {}
# for feature in env._features:
# if hasattr(feature, 'orderings'):
# for op, prereqs in feature.orderings(env).items():
# ords.setdefault(op, set()).update(prereqs)
# order = graph.io_toposort(env.inputs, env.outputs, ords)
# return order
# def supplemental_orderings(self):
# """
# Returns a dictionary of {op: set(prerequisites)} that must
# be satisfied in addition to the order defined by the structure
# of the graph (returns orderings that not related to input/output
# relationships).
# """
# ords = {}
# for feature in self._features:
# if hasattr(feature, 'orderings'):
# for op, prereqs in feature.orderings().items():
# ords.setdefault(op, set()).update(prereqs)
# return ords
# def toposort(self):
# """
# Returns a list of nodes in the order that they must be executed
# in order to preserve the semantics of the graph and respect
# the constraints put forward by the listeners.
# """
# ords = self.supplemental_orderings()
# order = graph.io_toposort(self.inputs, self.outputs, ords)
# return order
class
History
:
class
History
:
...
@@ -223,151 +176,3 @@ class PrintListener(object):
...
@@ -223,151 +176,3 @@ class PrintListener(object):
# class EquivTool(dict):
# def __init__(self, env):
# self.env = env
# def on_rewire(self, clients, r, new_r):
# repl = self(new_r)
# if repl is r:
# self.ungroup(r, new_r)
# elif repl is not new_r:
# raise Exception("Improper use of EquivTool!")
# else:
# self.group(new_r, r)
# def publish(self):
# self.env.equiv = self
# self.env.set_equiv = self.set_equiv
# def unpublish(self):
# del self.env.equiv
# del self.env.set_equiv
# def set_equiv(self, d):
# self.update(d)
# def group(self, main, *keys):
# "Marks all the keys as having been replaced by the Result main."
# keys = [key for key in keys if key is not main]
# if self.has_key(main):
# raise Exception("Only group results that have not been grouped before.")
# for key in keys:
# if self.has_key(key):
# raise Exception("Only group results that have not been grouped before.")
# if key is main:
# continue
# self.setdefault(key, main)
# def ungroup(self, main, *keys):
# "Undoes group(main, *keys)"
# keys = [key for key in keys if key is not main]
# for key in keys:
# if self[key] is main:
# del self[key]
# def __call__(self, key):
# "Returns the currently active replacement for the given key."
# next = self.get(key, None)
# while next:
# key = next
# next = self.get(next, None)
# return key
# class InstanceFinder(Listener, Tool, dict):
# def __init__(self, env):
# self.env = env
# def all_bases(self, cls):
# return utils.all_bases(cls, lambda cls: cls is not object)
# def on_import(self, op):
# for base in self.all_bases(op.__class__):
# self.setdefault(base, set()).add(op)
# def on_prune(self, op):
# for base in self.all_bases(op.__class__):
# self[base].remove(op)
# if not self[base]:
# del self[base]
# def __query__(self, cls):
# all = [x for x in self.get(cls, [])]
# shuffle(all) # this helps a lot for debugging because the order of the replacements will vary
# while all:
# next = all.pop()
# if next in self.env.ops():
# yield next
# def query(self, cls):
# return self.__query__(cls)
# def publish(self):
# self.env.get_instances_of = self.query
# class DescFinder(Listener, Tool, dict):
# def __init__(self, env):
# self.env = env
# def on_import(self, op):
# self.setdefault(op.desc(), set()).add(op)
# def on_prune(self, op):
# desc = op.desc()
# self[desc].remove(op)
# if not self[desc]:
# del self[desc]
# def __query__(self, desc):
# all = [x for x in self.get(desc, [])]
# shuffle(all) # this helps for debugging because the order of the replacements will vary
# while all:
# next = all.pop()
# if next in self.env.ops():
# yield next
# def query(self, desc):
# return self.__query__(desc)
# def publish(self):
# self.env.get_from_desc = self.query
### UNUSED AND UNTESTED ###
# class ChangeListener(Listener):
# def __init__(self, env):
# self.change = False
# def on_import(self, op):
# self.change = True
# def on_prune(self, op):
# self.change = True
# def on_rewire(self, clients, r, new_r):
# self.change = True
# def __call__(self, value = "get"):
# if value == "get":
# return self.change
# else:
# self.change = value
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论