Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
a5129f65
提交
a5129f65
authored
2月 29, 2008
作者:
Olivier Breuleux
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
redone result, op, tests for result, op, graph
上级
4a6845b1
显示空白字符变更
内嵌
并排
正在显示
10 个修改的文件
包含
1163 行增加
和
724 行删除
+1163
-724
_test_graph.py
gof/_test_graph.py
+124
-0
_test_op.py
gof/_test_op.py
+146
-0
_test_result.py
gof/_test_result.py
+26
-0
env.py
gof/env.py
+17
-87
ext.py
gof/ext.py
+8
-68
features.py
gof/features.py
+188
-206
graph.py
gof/graph.py
+76
-70
op.py
gof/op.py
+422
-196
result.py
gof/result.py
+147
-82
utils.py
gof/utils.py
+9
-15
没有找到文件。
gof/_test_graph.py
0 → 100644
浏览文件 @
a5129f65
import
unittest
from
graph
import
*
from
op
import
Op
from
result
import
ResultBase
,
BrokenLinkError
class
MyResult
(
ResultBase
):
def
__init__
(
self
,
thingy
):
self
.
thingy
=
thingy
ResultBase
.
__init__
(
self
,
role
=
None
,
data
=
[
self
.
thingy
],
constant
=
False
)
def
__eq__
(
self
,
other
):
return
isinstance
(
other
,
MyResult
)
and
other
.
thingy
==
self
.
thingy
def
__str__
(
self
):
return
str
(
self
.
thingy
)
def
__repr__
(
self
):
return
str
(
self
.
thingy
)
class
MyOp
(
Op
):
def
validate_update
(
self
):
for
input
in
self
.
inputs
:
if
not
isinstance
(
input
,
MyResult
):
raise
Exception
(
"Error 1"
)
self
.
outputs
=
[
MyResult
(
sum
([
input
.
thingy
for
input
in
self
.
inputs
]))]
class
_test_inputs
(
unittest
.
TestCase
):
def
test_0
(
self
):
r1
,
r2
=
MyResult
(
1
),
MyResult
(
2
)
op
=
MyOp
(
r1
,
r2
)
assert
inputs
(
op
.
outputs
)
==
set
([
r1
,
r2
])
def
test_1
(
self
):
r1
,
r2
,
r5
=
MyResult
(
1
),
MyResult
(
2
),
MyResult
(
5
)
op
=
MyOp
(
r1
,
r2
)
op2
=
MyOp
(
op
.
outputs
[
0
],
r5
)
assert
inputs
(
op2
.
outputs
)
==
set
([
r1
,
r2
,
r5
])
class
_test_orphans
(
unittest
.
TestCase
):
def
test_0
(
self
):
r1
,
r2
,
r5
=
MyResult
(
1
),
MyResult
(
2
),
MyResult
(
5
)
op
=
MyOp
(
r1
,
r2
)
op2
=
MyOp
(
op
.
outputs
[
0
],
r5
)
assert
orphans
([
r1
,
r2
],
op2
.
outputs
)
==
set
([
r5
])
class
_test_as_string
(
unittest
.
TestCase
):
leaf_formatter
=
str
node_formatter
=
lambda
op
,
argstrings
:
"
%
s(
%
s)"
%
(
op
.
__class__
.
__name__
,
", "
.
join
(
argstrings
))
def
test_0
(
self
):
r1
,
r2
=
MyResult
(
1
),
MyResult
(
2
)
op
=
MyOp
(
r1
,
r2
)
assert
as_string
([
r1
,
r2
],
op
.
outputs
)
==
[
"MyOp(1, 2)"
]
def
test_1
(
self
):
r1
,
r2
,
r5
=
MyResult
(
1
),
MyResult
(
2
),
MyResult
(
5
)
op
=
MyOp
(
r1
,
r2
)
op2
=
MyOp
(
op
.
outputs
[
0
],
r5
)
assert
as_string
([
r1
,
r2
,
r5
],
op2
.
outputs
)
==
[
"MyOp(MyOp(1, 2), 5)"
]
def
test_2
(
self
):
r1
,
r2
,
r5
=
MyResult
(
1
),
MyResult
(
2
),
MyResult
(
5
)
op
=
MyOp
(
r1
,
r2
)
op2
=
MyOp
(
op
.
outputs
[
0
],
op
.
outputs
[
0
])
assert
as_string
([
r1
,
r2
,
r5
],
op2
.
outputs
)
==
[
"MyOp(*1 -> MyOp(1, 2), *1)"
]
def
test_3
(
self
):
r1
,
r2
,
r5
=
MyResult
(
1
),
MyResult
(
2
),
MyResult
(
5
)
op
=
MyOp
(
r1
,
r2
)
op2
=
MyOp
(
op
.
outputs
[
0
],
op
.
outputs
[
0
])
assert
as_string
(
op
.
outputs
,
op2
.
outputs
)
==
[
"MyOp(3, 3)"
]
assert
as_string
(
op2
.
inputs
,
op2
.
outputs
)
==
[
"MyOp(3, 3)"
]
class
_test_clone
(
unittest
.
TestCase
):
def
test_0
(
self
):
r1
,
r2
=
MyResult
(
1
),
MyResult
(
2
)
op
=
MyOp
(
r1
,
r2
)
new
=
clone
([
r1
,
r2
],
op
.
outputs
)
assert
as_string
([
r1
,
r2
],
new
)
==
[
"MyOp(1, 2)"
]
def
test_1
(
self
):
r1
,
r2
,
r5
=
MyResult
(
1
),
MyResult
(
2
),
MyResult
(
5
)
op
=
MyOp
(
r1
,
r2
)
op2
=
MyOp
(
op
.
outputs
[
0
],
r5
)
new
=
clone
([
r1
,
r2
,
r5
],
op2
.
outputs
)
assert
op2
.
outputs
[
0
]
==
new
[
0
]
and
op2
.
outputs
[
0
]
is
not
new
[
0
]
assert
op2
is
not
new
[
0
]
.
owner
assert
new
[
0
]
.
owner
.
inputs
[
1
]
is
r5
assert
new
[
0
]
.
owner
.
inputs
[
0
]
==
op
.
outputs
[
0
]
and
new
[
0
]
.
owner
.
inputs
[
0
]
is
not
op
.
outputs
[
0
]
def
test_2
(
self
):
"Checks that manipulating a cloned graph leaves the original unchanged."
r1
,
r2
,
r5
=
MyResult
(
1
),
MyResult
(
2
),
MyResult
(
5
)
op
=
MyOp
(
MyOp
(
r1
,
r2
)
.
outputs
[
0
],
r5
)
new
=
clone
([
r1
,
r2
,
r5
],
op
.
outputs
)
new_op
=
new
[
0
]
.
owner
new_op
.
inputs
=
MyResult
(
7
),
MyResult
(
8
)
assert
as_string
(
inputs
(
new_op
.
outputs
),
new_op
.
outputs
)
==
[
"MyOp(7, 8)"
]
assert
as_string
(
inputs
(
op
.
outputs
),
op
.
outputs
)
==
[
"MyOp(MyOp(1, 2), 5)"
]
if
__name__
==
'__main__'
:
unittest
.
main
()
gof/_test_op.py
0 → 100644
浏览文件 @
a5129f65
import
unittest
from
copy
import
copy
from
op
import
*
from
result
import
ResultBase
,
BrokenLinkError
class
MyResult
(
ResultBase
):
def
__init__
(
self
,
thingy
):
self
.
thingy
=
thingy
ResultBase
.
__init__
(
self
,
role
=
None
,
data
=
[
self
.
thingy
],
constant
=
False
)
def
__eq__
(
self
,
other
):
return
isinstance
(
other
,
MyResult
)
and
other
.
thingy
==
self
.
thingy
def
__str__
(
self
):
return
str
(
self
.
thingy
)
def
__repr__
(
self
):
return
str
(
self
.
thingy
)
class
MyOp
(
Op
):
def
validate_update
(
self
):
for
input
in
self
.
inputs
:
if
not
isinstance
(
input
,
MyResult
):
raise
Exception
(
"Error 1"
)
self
.
outputs
=
[
MyResult
(
sum
([
input
.
thingy
for
input
in
self
.
inputs
]))]
class
_test_Op
(
unittest
.
TestCase
):
# Sanity tests
def
test_sanity_0
(
self
):
r1
,
r2
=
MyResult
(
1
),
MyResult
(
2
)
op
=
MyOp
(
r1
,
r2
)
assert
op
.
inputs
==
[
r1
,
r2
]
# Are the inputs what I provided?
assert
op
.
outputs
==
[
MyResult
(
3
)]
# Are the outputs what I expect?
# validate_update
def
test_validate_update
(
self
):
try
:
MyOp
(
ResultBase
(),
MyResult
(
1
))
# MyOp requires MyResult instances
except
Exception
,
e
:
assert
str
(
e
)
==
"Error 1"
else
:
raise
Exception
(
"Expected an exception"
)
# Setting inputs and outputs
def
test_set_inputs
(
self
):
r1
,
r2
=
MyResult
(
1
),
MyResult
(
2
)
op
=
MyOp
(
r1
,
r2
)
r3
=
op
.
outputs
[
0
]
op
.
inputs
=
MyResult
(
4
),
MyResult
(
5
)
assert
op
.
outputs
==
[
MyResult
(
9
)]
# check if the output changed to what I expect
assert
r3
.
data
is
op
.
outputs
[
0
]
.
data
# check if the data was properly transferred by set_output
def
test_set_bad_inputs
(
self
):
op
=
MyOp
(
MyResult
(
1
),
MyResult
(
2
))
try
:
op
.
inputs
=
MyResult
(
4
),
ResultBase
()
except
Exception
,
e
:
assert
str
(
e
)
==
"Error 1"
else
:
raise
Exception
(
"Expected an exception"
)
def
test_set_outputs
(
self
):
r1
,
r2
=
MyResult
(
1
),
MyResult
(
2
)
op
=
MyOp
(
r1
,
r2
)
# here we only make one output
try
:
op
.
outputs
=
MyResult
(
10
),
MyResult
(
11
)
# setting two outputs should fail
except
TypeError
,
e
:
assert
str
(
e
)
==
"The new outputs must be exactly as many as the previous outputs."
else
:
raise
Exception
(
"Expected an exception"
)
# Tests about broken links
def
test_create_broken_link
(
self
):
r1
,
r2
=
MyResult
(
1
),
MyResult
(
2
)
op
=
MyOp
(
r1
,
r2
)
r3
=
op
.
out
op
.
inputs
=
MyResult
(
3
),
MyResult
(
4
)
assert
r3
not
in
op
.
outputs
assert
r3
.
replaced
def
test_cannot_copy_when_input_is_broken_link
(
self
):
r1
,
r2
=
MyResult
(
1
),
MyResult
(
2
)
op
=
MyOp
(
r1
,
r2
)
r3
=
op
.
out
op2
=
MyOp
(
r3
)
op
.
inputs
=
MyResult
(
3
),
MyResult
(
4
)
try
:
copy
(
op2
)
except
BrokenLinkError
:
pass
else
:
raise
Exception
(
"Expected an exception"
)
def
test_get_input_broken_link
(
self
):
r1
,
r2
=
MyResult
(
1
),
MyResult
(
2
)
op
=
MyOp
(
r1
,
r2
)
r3
=
op
.
out
op2
=
MyOp
(
r3
)
op
.
inputs
=
MyResult
(
3
),
MyResult
(
4
)
try
:
op2
.
get_input
(
0
)
except
BrokenLinkError
:
pass
else
:
raise
Exception
(
"Expected an exception"
)
def
test_get_inputs_broken_link
(
self
):
r1
,
r2
=
MyResult
(
1
),
MyResult
(
2
)
op
=
MyOp
(
r1
,
r2
)
r3
=
op
.
out
op2
=
MyOp
(
r3
)
op
.
inputs
=
MyResult
(
3
),
MyResult
(
4
)
try
:
op2
.
get_inputs
()
except
BrokenLinkError
:
pass
else
:
raise
Exception
(
"Expected an exception"
)
def
test_repair_broken_link
(
self
):
r1
,
r2
=
MyResult
(
1
),
MyResult
(
2
)
op
=
MyOp
(
r1
,
r2
)
r3
=
op
.
out
op2
=
MyOp
(
r3
,
MyResult
(
10
))
op
.
inputs
=
MyResult
(
3
),
MyResult
(
4
)
op2
.
repair
()
assert
op2
.
outputs
==
[
MyResult
(
17
)]
# Tests about string representation
def
test_create_broken_link
(
self
):
r1
,
r2
=
MyResult
(
1
),
MyResult
(
2
)
op
=
MyOp
(
r1
,
r2
)
assert
str
(
op
)
==
"MyOp(1, 2)"
if
__name__
==
'__main__'
:
unittest
.
main
()
gof/_test_result.py
0 → 100644
浏览文件 @
a5129f65
import
unittest
from
result
import
*
class
_test_ResultBase
(
unittest
.
TestCase
):
def
test_0
(
self
):
r
=
ResultBase
()
def
test_1
(
self
):
r
=
ResultBase
()
assert
r
.
state
is
Empty
r
.
data
=
0
assert
r
.
data
==
0
assert
r
.
state
is
Computed
r
.
data
=
1
assert
r
.
data
==
1
assert
r
.
state
is
Computed
r
.
data
=
None
assert
r
.
data
==
None
assert
r
.
state
is
Empty
if
__name__
==
'__main__'
:
unittest
.
main
()
gof/env.py
浏览文件 @
a5129f65
...
@@ -6,53 +6,21 @@ from utils import ClsInit
...
@@ -6,53 +6,21 @@ from utils import ClsInit
from
err
import
GofError
,
GofTypeError
,
PropagationError
from
err
import
GofError
,
GofTypeError
,
PropagationError
from
op
import
Op
from
op
import
Op
from
result
import
is_result
from
result
import
is_result
from
features
import
Listener
,
Orderings
,
Constraint
,
Tool
from
features
import
Listener
,
Orderings
,
Constraint
,
Tool
,
uniq_features
import
utils
import
utils
__all__
=
[
'InconsistencyError'
,
__all__
=
[
'InconsistencyError'
,
'Env'
]
'Env'
]
# class AliasDict(dict):
class
InconsistencyError
(
Exception
):
# "Utility class to keep track of what result has been replaced with what result."
# def group(self, main, *keys):
# "Marks all the keys as having been replaced by the result main."
# keys = [key for key in keys if key is not main]
# if self.has_key(main):
# raise Exception("Only group results that have not been grouped before.")
# for key in keys:
# if self.has_key(key):
# raise Exception("Only group results that have not been grouped before.")
# if key is main:
# continue
# self.setdefault(key, main)
# def ungroup(self, main, *keys):
# "Undoes group(main, *keys)"
# keys = [key for key in keys if key is not main]
# for key in keys:
# if self[key] is main:
# del self[key]
# def __call__(self, key):
# "Returns the currently active replacement for the given key."
# next = self.get(key, None)
# while next:
# key = next
# next = self.get(next, None)
# return key
#TODO: why is this not in err.py? -James
class
InconsistencyError
(
GofError
):
"""
"""
This exception is raised by Env whenever one of the listeners marks
This exception is raised by Env whenever one of the listeners marks
the graph as inconsistent.
the graph as inconsistent.
"""
"""
pass
pass
def
require_set
(
cls
):
def
require_set
(
cls
):
"""Return the set of objects named in a __env_require__ field in a base class"""
"""Return the set of objects named in a __env_require__ field in a base class"""
r
=
set
()
r
=
set
()
...
@@ -70,22 +38,22 @@ def require_set(cls):
...
@@ -70,22 +38,22 @@ def require_set(cls):
r
.
add
(
req
)
r
.
add
(
req
)
return
r
return
r
class
Env
(
graph
.
Graph
):
class
Env
(
graph
.
Graph
):
"""
"""
An Env represents a subgraph bound by a set of input results and a set of output
An Env represents a subgraph bound by a set of input results and a
results. An op is in the subgraph iff it depends on the value of some of the Env's
set of output results. An op is in the subgraph iff it depends on
inputs _and_ some of the Env's outputs depend on it. A result is in the subgraph
the value of some of the Env's inputs _and_ some of the Env's
iff it is an input or an output of an op that is in the subgraph.
outputs depend on it. A result is in the subgraph iff it is an
input or an output of an op that is in the subgraph.
The Env supports the replace operation which allows to replace a
result in the
The Env supports the replace operation which allows to replace a
subgraph by another, e.g. replace (x + x).out by (2 * x).out. This is the basis
result in the subgraph by another, e.g. replace (x + x).out by (2
for optimization in omega.
* x).out. This is the basis
for optimization in omega.
An Env can have listeners, which are instances of EnvListener. Each listener is
An Env's functionality can be extended with features, which must
informed of any op entering or leaving the subgraph (which happens at construction
be subclasses of L{Listener}, L{Constraint}, L{Orderings} or
time and whenever there is a replacement). In addition to that, each listener can
L{Tool}.
implement the 'consistent' and 'ordering' methods (see EnvListener) in order to
restrict how ops in the subgraph can be related.
Regarding inputs and orphans:
Regarding inputs and orphans:
In the context of a computation graph, the inputs and orphans are both
In the context of a computation graph, the inputs and orphans are both
...
@@ -93,7 +61,6 @@ class Env(graph.Graph):
...
@@ -93,7 +61,6 @@ class Env(graph.Graph):
named as inputs will be assumed to contain fresh. In other words, the
named as inputs will be assumed to contain fresh. In other words, the
backward search from outputs will stop at any node that has been explicitly
backward search from outputs will stop at any node that has been explicitly
named as an input.
named as an input.
"""
"""
### Special ###
### Special ###
...
@@ -102,11 +69,6 @@ class Env(graph.Graph):
...
@@ -102,11 +69,6 @@ class Env(graph.Graph):
"""
"""
Create an Env which operates on the subgraph bound by the inputs and outputs
Create an Env which operates on the subgraph bound by the inputs and outputs
sets. If consistency_check is False, an illegal graph will be tolerated.
sets. If consistency_check is False, an illegal graph will be tolerated.
Features are class types derived from things in the tools file. These
can be listeners, constraints, orderings, etc. Features add much
(most?) functionality to an Env.
"""
"""
self
.
_features
=
{}
self
.
_features
=
{}
...
@@ -115,30 +77,12 @@ class Env(graph.Graph):
...
@@ -115,30 +77,12 @@ class Env(graph.Graph):
self
.
_orderings
=
{}
self
.
_orderings
=
{}
self
.
_tools
=
{}
self
.
_tools
=
{}
# self._preprocessors = set()
# for feature in features:
# if issubclass(feature, tools.Preprocessor):
# preprocessor = feature()
# self._preprocessors.add(preprocessor)
# inputs, outputs = preprocessor.transform(inputs, outputs)
# The inputs and outputs set bound the subgraph this Env operates on.
# The inputs and outputs set bound the subgraph this Env operates on.
self
.
inputs
=
set
(
inputs
)
self
.
inputs
=
set
(
inputs
)
self
.
outputs
=
set
(
outputs
)
self
.
outputs
=
set
(
outputs
)
for
feature_class
in
u
tils
.
u
niq_features
(
features
):
for
feature_class
in
uniq_features
(
features
):
self
.
add_feature
(
feature_class
,
False
)
self
.
add_feature
(
feature_class
,
False
)
# feature = feature_class(self)
# if isinstance(feature, tools.Listener):
# self._listeners.add(feature)
# if isinstance(feature, tools.Constraint):
# self._constraints.add(feature)
# if isinstance(feature, tools.Orderings):
# self._orderings.add(feature)
# if isinstance(feature, tools.Tool):
# self._tools.add(feature)
# feature.publish()
# All ops in the subgraph defined by inputs and outputs are cached in _ops
# All ops in the subgraph defined by inputs and outputs are cached in _ops
self
.
_ops
=
set
()
self
.
_ops
=
set
()
...
@@ -234,20 +178,6 @@ class Env(graph.Graph):
...
@@ -234,20 +178,6 @@ class Env(graph.Graph):
except
KeyError
:
except
KeyError
:
pass
pass
# for i, feature in enumerate(self._features):
# if isinstance(feature, feature_class): # exact class or subclass, nothing to do
# return
# elif issubclass(feature_class, feature.__class__): # superclass, we replace it
# new_feature = feature_class(self)
# self._features[i] = new_feature
# break
# else:
# new_feature = feature_class(self)
# self._features.append(new_feature)
# if isinstance(new_feature, tools.Listener):
# for op in self.io_toposort():
# new_feature.on_import(op)
def
get_feature
(
self
,
feature_class
):
def
get_feature
(
self
,
feature_class
):
try
:
try
:
return
self
.
_features
[
feature_class
]
return
self
.
_features
[
feature_class
]
...
@@ -325,7 +255,7 @@ class Env(graph.Graph):
...
@@ -325,7 +255,7 @@ class Env(graph.Graph):
self
.
outputs
.
add
(
new_r
)
self
.
outputs
.
add
(
new_r
)
# The actual replacement operation occurs here. This might raise
# The actual replacement operation occurs here. This might raise
# a
GofTypeError
# a
n error.
self
.
__move_clients__
(
clients
,
r
,
new_r
)
self
.
__move_clients__
(
clients
,
r
,
new_r
)
# This function undoes the replacement.
# This function undoes the replacement.
...
...
gof/ext.py
浏览文件 @
a5129f65
from
copy
import
copy
#
from copy import copy
from
op
import
Op
#
from op import Op
from
lib
import
DummyOp
#
from lib import DummyOp
from
features
import
Listener
,
Constraint
,
Orderings
#
from features import Listener, Constraint, Orderings
from
env
import
InconsistencyError
#
from env import InconsistencyError
from
utils
import
ClsInit
#
from utils import ClsInit
import
graph
#
import graph
#TODO: move mark_outputs_as_destroyed to the place that uses this function
__all__
=
[
'Destroyer'
,
'Viewer'
]
#TODO: move Return to where it is used.
__all__
=
[
'IONames'
,
'mark_outputs_as_destroyed'
]
class
IONames
:
"""
Requires assigning a name to each of this Op's inputs and outputs.
"""
__metaclass__
=
ClsInit
input_names
=
()
output_names
=
()
@staticmethod
def
__clsinit__
(
cls
,
name
,
bases
,
dct
):
for
names
in
[
'input_names'
,
'output_names'
]:
if
names
in
dct
:
x
=
getattr
(
cls
,
names
)
if
isinstance
(
x
,
str
):
x
=
[
x
,]
setattr
(
cls
,
names
,
x
)
if
isinstance
(
x
,
(
list
,
tuple
)):
x
=
[
a
for
a
in
x
]
setattr
(
cls
,
names
,
x
)
for
i
,
varname
in
enumerate
(
x
):
if
not
isinstance
(
varname
,
str
)
or
hasattr
(
cls
,
varname
)
or
varname
in
[
'inputs'
,
'outputs'
]:
raise
TypeError
(
"In
%
s: '
%
s' is not a valid input or output name"
%
(
cls
.
__name__
,
varname
))
# Set an attribute for the variable so we can do op.x to return the input or output named "x".
setattr
(
cls
,
varname
,
property
(
lambda
op
,
type
=
names
.
replace
(
'_name'
,
''
),
index
=
i
:
getattr
(
op
,
type
)[
index
]))
else
:
print
'ERROR: Class variable
%
s::
%
s is neither list, tuple, or string'
%
(
name
,
names
)
raise
TypeError
,
str
(
names
)
else
:
setattr
(
cls
,
names
,
())
# def __init__(self, inputs, outputs, use_self_setters = False):
# assert len(inputs) == len(self.input_names)
# assert len(outputs) == len(self.output_names)
# Op.__init__(self, inputs, outputs, use_self_setters)
def
__validate__
(
self
):
assert
len
(
self
.
inputs
)
==
len
(
self
.
input_names
)
assert
len
(
self
.
outputs
)
==
len
(
self
.
output_names
)
@classmethod
def
n_inputs
(
cls
):
return
len
(
cls
.
input_names
)
@classmethod
def
n_outputs
(
cls
):
return
len
(
cls
.
output_names
)
def
get_by_name
(
self
,
name
):
"""
Returns the input or output which corresponds to the given name.
"""
if
name
in
self
.
input_names
:
return
self
.
input_names
[
self
.
input_names
.
index
(
name
)]
elif
name
in
self
.
output_names
:
return
self
.
output_names
[
self
.
output_names
.
index
(
name
)]
else
:
raise
AttributeError
(
"No such input or output name for
%
s:
%
s"
%
(
self
.
__class__
.
__name__
,
name
))
...
...
gof/features.py
浏览文件 @
a5129f65
from
copy
import
copy
# from copy import copy
from
op
import
Op
# from op import Op
import
result
# import result
import
graph
# import graph
import
utils
import
utils
from
random
import
shuffle
#
from random import shuffle
__all__
=
[
'Feature'
,
__all__
=
[
'Feature'
,
...
@@ -13,10 +14,11 @@ __all__ = ['Feature',
...
@@ -13,10 +14,11 @@ __all__ = ['Feature',
'Constraint'
,
'Constraint'
,
'Orderings'
,
'Orderings'
,
'Tool'
,
'Tool'
,
'EquivTool'
,
'uniq_features'
,
'InstanceFinder'
,
# 'EquivTool',
'PrintListener'
,
# 'InstanceFinder',
'ChangeListener'
,
# 'PrintListener',
# 'ChangeListener',
]
]
...
@@ -24,256 +26,236 @@ __all__ = ['Feature',
...
@@ -24,256 +26,236 @@ __all__ = ['Feature',
class
Feature
(
object
):
class
Feature
(
object
):
def
__init__
(
self
,
env
):
def
__init__
(
self
,
env
):
"""
Initializes the Feature's env field to the parameter
provided.
"""
self
.
env
=
env
self
.
env
=
env
class
Listener
(
Feature
):
class
Listener
(
Feature
):
"""
When registered by an env, each listener is informed of any op
entering or leaving the subgraph (which happens at construction
time and whenever there is a replacement).
"""
def
on_import
(
self
,
op
):
def
on_import
(
self
,
op
):
pass
"""
This method is called by the env whenever a new op is
added to the graph.
"""
raise
utils
.
AbstractFunctionError
()
def
on_prune
(
self
,
op
):
def
on_prune
(
self
,
op
):
pass
"""
This method is called by the env whenever an op is
removed from the graph.
"""
raise
utils
.
AbstractFunctionError
()
def
on_rewire
(
self
,
clients
,
r
,
new_r
):
def
on_rewire
(
self
,
clients
,
r
,
new_r
):
pass
"""
clients -> (op, i) pairs such that op.inputs[i] is new_r
but used to be r
r -> the old result that was used by the ops in clients
new_r -> the new result that is now used by the ops in clients
Note that the change from r to new_r is done before this
method is called.
"""
raise
utils
.
AbstractFunctionError
()
class
Constraint
(
Feature
):
class
Constraint
(
Feature
):
"""
When registered by an env, a Constraint can restrict the ops that
can be in the subgraph or restrict the ways ops interact with each
other.
"""
def
validate
(
self
):
def
validate
(
self
):
return
True
"""
Raises an L{InconsistencyError} if the env is currently
invalid from the perspective of this object.
"""
raise
utils
.
AbstractFunctionError
()
class
Orderings
(
Feature
):
class
Orderings
(
Feature
):
"""
When registered by an env, an Orderings object can provide supplemental
ordering constraints to the subgraph's topological sort.
"""
def
orderings
(
self
):
def
orderings
(
self
):
return
{}
"""
Returns {op: set(ops that must be evaluated before this op), ...}
This is called by env.orderings() and used in env.toposort() but
not in env.io_toposort().
"""
raise
utils
.
AbstractFunctionError
()
class
Tool
(
Feature
):
class
Tool
(
Feature
):
"""
A Tool can extend the functionality of an env so that, for example,
optimizations can have access to efficient ways to search the graph.
"""
def
publish
(
self
):
def
publish
(
self
):
pass
"""
This is only called once by the env, when the Tool is added.
Adds methods to env.
# class Preprocessor(Feature):
"""
raise
utils
.
AbstractFunctionError
()
# def transform(self, inputs, outputs):
# return inputs, outputs
# def __call__(self, inputs, outputs):
def
uniq_features
(
_features
,
*
_rest
):
# return self.transform(inputs, outputs)
"""Return a list such that no element is a subclass of another"""
# used in Env.__init__ to
features
=
[
x
for
x
in
_features
]
# class Optimization(object):
for
other
in
_rest
:
features
+=
[
x
for
x
in
other
]
# def require(self):
res
=
[]
# return []
while
features
:
feature
=
features
.
pop
()
# def apply(self, env):
for
feature2
in
features
:
# pass
if
issubclass
(
feature2
,
feature
):
break
# def __call__(self, env):
# return self.apply(env)
# Optimization
# * require <tool_class>*
# * apply
# Prog
# * __init__
# * inputs
# * outputs
# * listeners, constraints, orderings
# * dispatched by isinstance Listener, etc.
# * {tool_class: preferred_implementation, ...}
class
EquivTool
(
Listener
,
Tool
,
dict
):
def
on_rewire
(
self
,
clients
,
r
,
new_r
):
repl
=
self
(
new_r
)
if
repl
is
r
:
self
.
ungroup
(
r
,
new_r
)
elif
repl
is
not
new_r
:
raise
Exception
(
"Improper use of EquivTool!"
)
else
:
self
.
group
(
new_r
,
r
)
def
publish
(
self
):
self
.
env
.
equiv
=
self
def
group
(
self
,
main
,
*
keys
):
"Marks all the keys as having been replaced by the Result main."
keys
=
[
key
for
key
in
keys
if
key
is
not
main
]
if
self
.
has_key
(
main
):
raise
Exception
(
"Only group results that have not been grouped before."
)
for
key
in
keys
:
if
self
.
has_key
(
key
):
raise
Exception
(
"Only group results that have not been grouped before."
)
if
key
is
main
:
continue
self
.
setdefault
(
key
,
main
)
def
ungroup
(
self
,
main
,
*
keys
):
"Undoes group(main, *keys)"
keys
=
[
key
for
key
in
keys
if
key
is
not
main
]
for
key
in
keys
:
if
self
[
key
]
is
main
:
del
self
[
key
]
def
__call__
(
self
,
key
):
"Returns the currently active replacement for the given key."
next
=
self
.
get
(
key
,
None
)
while
next
:
key
=
next
next
=
self
.
get
(
next
,
None
)
return
key
class
InstanceFinder
(
Listener
,
Tool
,
dict
):
def
__init__
(
self
,
env
):
self
.
env
=
env
def
all_bases
(
self
,
cls
):
return
utils
.
all_bases
(
cls
,
lambda
cls
:
issubclass
(
cls
,
Op
))
# return [cls for cls in utils.all_bases(cls) if issubclass(cls, Op)]
def
on_import
(
self
,
op
):
for
base
in
self
.
all_bases
(
op
.
__class__
):
self
.
setdefault
(
base
,
set
())
.
add
(
op
)
def
on_prune
(
self
,
op
):
for
base
in
self
.
all_bases
(
op
.
__class__
):
self
[
base
]
.
remove
(
op
)
if
not
self
[
base
]:
del
self
[
base
]
def
__query__
(
self
,
cls
):
all
=
[
x
for
x
in
self
.
get
(
cls
,
[])]
shuffle
(
all
)
# this helps a lot for debugging because the order of the replacements will vary
while
all
:
next
=
all
.
pop
()
if
next
in
self
.
env
.
ops
():
yield
next
def
query
(
self
,
cls
):
return
self
.
__query__
(
cls
)
def
publish
(
self
):
self
.
env
.
get_instances_of
=
self
.
query
class
PrintListener
(
Listener
):
def
__init__
(
self
,
env
,
active
=
True
):
self
.
env
=
env
self
.
active
=
active
if
active
:
print
"-- initializing"
def
on_import
(
self
,
op
):
if
self
.
active
:
print
"-- importing:
%
s"
%
graph
.
as_string
(
self
.
env
.
inputs
,
op
.
outputs
)
def
on_prune
(
self
,
op
):
if
self
.
active
:
print
"-- pruning:
%
s"
%
graph
.
as_string
(
self
.
env
.
inputs
,
op
.
outputs
)
def
on_rewire
(
self
,
clients
,
r
,
new_r
):
if
self
.
active
:
if
r
.
owner
is
None
:
rg
=
id
(
r
)
#r.name
else
:
else
:
rg
=
graph
.
as_string
(
self
.
env
.
inputs
,
r
.
owner
.
outputs
)
res
.
append
(
feature
)
if
new_r
.
owner
is
None
:
return
res
new_rg
=
id
(
new_r
)
#new_r.name
else
:
new_rg
=
graph
.
as_string
(
self
.
env
.
inputs
,
new_r
.
owner
.
outputs
)
print
"-- moving from
%
s to
%
s"
%
(
rg
,
new_rg
)
class
ChangeListener
(
Listener
):
def
__init__
(
self
,
env
):
# MOVE TO LIB
self
.
change
=
False
def
on_import
(
self
,
op
):
# class EquivTool(Listener, Tool, dict):
self
.
change
=
True
def
on_prune
(
self
,
op
):
# def on_rewire(self, clients, r, new_r):
self
.
change
=
True
# repl = self(new_r)
# if repl is r:
# self.ungroup(r, new_r)
# elif repl is not new_r:
# raise Exception("Improper use of EquivTool!")
# else:
# self.group(new_r, r)
def
on_rewire
(
self
,
clients
,
r
,
new_r
):
# def publish(self):
self
.
change
=
True
# self.env.equiv = self
# def group(self, main, *keys):
# "Marks all the keys as having been replaced by the Result main."
# keys = [key for key in keys if key is not main]
# if self.has_key(main):
# raise Exception("Only group results that have not been grouped before.")
# for key in keys:
# if self.has_key(key):
# raise Exception("Only group results that have not been grouped before.")
# if key is main:
# continue
# self.setdefault(key, main)
# def ungroup(self, main, *keys):
# "Undoes group(main, *keys)"
# keys = [key for key in keys if key is not main]
# for key in keys:
# if self[key] is main:
# del self[key]
# def __call__(self, key):
# "Returns the currently active replacement for the given key."
# next = self.get(key, None)
# while next:
# key = next
# next = self.get(next, None)
# return key
# class InstanceFinder(Listener, Tool, dict):
# def __init__(self, env):
# self.env = env
def
__call__
(
self
,
value
=
"get"
):
# def all_bases(self, cls):
if
value
==
"get"
:
# return utils.all_bases(cls, lambda cls: issubclass(cls, Op))
return
self
.
change
# # return [cls for cls in utils.all_bases(cls) if issubclass(cls, Op)]
else
:
self
.
change
=
value
# def on_import(self, op):
# for base in self.all_bases(op.__class__):
# self.setdefault(base, set()).add(op)
# def on_prune(self, op):
# for base in self.all_bases(op.__class__):
# self[base].remove(op)
# if not self[base]:
# del self[base]
# def __query__(self, cls):
# all = [x for x in self.get(cls, [])]
# shuffle(all) # this helps a lot for debugging because the order of the replacements will vary
# while all:
# next = all.pop()
# if next in self.env.ops():
# yield next
# def query(self, cls):
# return self.__query__(cls)
# def publish(self):
# self.env.get_instances_of = self.query
# class
SuperFinder(Listener, Tool, dict
):
# class
PrintListener(Listener
):
# def __init__(self, env,
props
):
# def __init__(self, env,
active = True
):
# self.env = env
# self.env = env
# self.props = props
# self.active = active
# if active:
# print "-- initializing"
# def on_import(self, op):
# def on_import(self, op):
#
for prop, value in self.props(op).items()
:
#
if self.active
:
#
self.setdefault(prop, {}).setdefault(value, set()).add(op
)
#
print "-- importing: %s" % graph.as_string(self.env.inputs, op.outputs
)
# def on_prune(self, op):
# def on_prune(self, op):
# for prop, value in self.props(op).items():
# if self.active:
# self[prop][value].remove(op)
# print "-- pruning: %s" % graph.as_string(self.env.inputs, op.outputs)
# if len(self[prop][value]) == 0:
# del self[prop][value]
# if len(self[prop]) == 0:
# del self[prop]
# def __query__(self, order, template):
# all = []
# for prop, value in template.items():
# all += [x for x in self.get(prop, {}).get(value, set())]
# # If not None, the order option requires the order listener to be included in the env under the name 'order'
# if order == 'o->i':
# all.sort(lambda op1, op2: self.env.order[op1].__cmp__(self.env.order[op2]))
# elif order == 'i->o':
# all.sort(lambda op1, op2: self.env.order[op2].__cmp__(self.env.order[op1]))
# while all:
# next = all.pop()
# if next in self.env.ops():
# yield next
# def query(self, **template):
# def on_rewire(self, clients, r, new_r):
# return self.__query__(None, template)
# if self.active:
# if r.owner is None:
# rg = id(r) #r.name
# else:
# rg = graph.as_string(self.env.inputs, r.owner.outputs)
# if new_r.owner is None:
# new_rg = id(new_r) #new_r.name
# else:
# new_rg = graph.as_string(self.env.inputs, new_r.owner.outputs)
# print "-- moving from %s to %s" % (rg, new_rg)
# def query_downstream(self, **template):
# return self.__query__('i->o', template)
# def query_upstream(self, **template):
# return self.__query__('o->i', template)
# def publish(self):
# class ChangeListener(Listener):
# self.env.query = self.query
# def __init__(self, env):
# self.change = False
# def on_import(self, op):
# self.change = True
# def on_prune(self, op):
# self.change = True
# def on_rewire(self, clients, r, new_r):
# self.change = True
# def __call__(self, value = "get"):
# if value == "get":
# return self.change
# else:
# self.change = value
gof/graph.py
浏览文件 @
a5129f65
from
copy
import
copy
from
copy
import
copy
from
result
import
BrokenLink
,
BrokenLinkError
from
op
import
Op
import
utils
import
utils
...
@@ -11,11 +9,17 @@ __all__ = ['inputs',
...
@@ -11,11 +9,17 @@ __all__ = ['inputs',
'ops'
,
'ops'
,
'clone'
,
'clone_get_equiv'
,
'clone'
,
'clone_get_equiv'
,
'io_toposort'
,
'io_toposort'
,
'default_leaf_formatter'
,
'default_node_formatter'
,
'op_as_string'
,
'as_string'
,
'as_string'
,
'Graph'
]
'Graph'
]
def
inputs
(
o
,
repair
=
False
):
is_result
=
utils
.
attr_checker
(
'owner'
,
'index'
)
is_op
=
utils
.
attr_checker
(
'inputs'
,
'outputs'
)
def
inputs
(
o
):
"""
"""
o -> list of output Results
o -> list of output Results
...
@@ -24,21 +28,12 @@ def inputs(o, repair = False):
...
@@ -24,21 +28,12 @@ def inputs(o, repair = False):
"""
"""
results
=
set
()
results
=
set
()
def
seek
(
r
):
def
seek
(
r
):
if
isinstance
(
r
,
BrokenLink
):
raise
BrokenLinkError
op
=
r
.
owner
op
=
r
.
owner
if
op
is
None
:
if
op
is
None
:
results
.
add
(
r
)
results
.
add
(
r
)
else
:
else
:
for
i
in
range
(
len
(
op
.
inputs
)):
for
input
in
op
.
inputs
:
try
:
seek
(
input
)
seek
(
op
.
inputs
[
i
])
except
BrokenLinkError
:
if
repair
:
op
.
refresh
()
seek
(
op
.
inputs
[
i
])
else
:
raise
for
output
in
o
:
for
output
in
o
:
seek
(
output
)
seek
(
output
)
return
results
return
results
...
@@ -60,8 +55,6 @@ def results_and_orphans(i, o):
...
@@ -60,8 +55,6 @@ def results_and_orphans(i, o):
incomplete_paths
=
[]
incomplete_paths
=
[]
def
helper
(
r
,
path
):
def
helper
(
r
,
path
):
if
isinstance
(
r
,
BrokenLink
):
raise
BrokenLinkError
if
r
in
i
:
if
r
in
i
:
results
.
update
(
path
)
results
.
update
(
path
)
elif
r
.
owner
is
None
:
elif
r
.
owner
is
None
:
...
@@ -128,45 +121,56 @@ def orphans(i, o):
...
@@ -128,45 +121,56 @@ def orphans(i, o):
return
results_and_orphans
(
i
,
o
)[
1
]
return
results_and_orphans
(
i
,
o
)[
1
]
def
clone
(
i
,
o
):
def
clone
(
i
,
o
,
copy_inputs
=
False
):
"""
"""
i -> list of input Results
i -> list of input Results
o -> list of output Results
o -> list of output Results
copy_inputs -> if True, the inputs will be copied (defaults to False)
Copies the subgraph contained between i and o and returns the
Copies the subgraph contained between i and o and returns the
outputs of that copy (corresponding to o). The input Results in
outputs of that copy (corresponding to o).
the list are _not_ copied and the new graph refers to the
originals.
"""
"""
new_o
,
equiv
=
clone_get_equiv
(
i
,
o
)
equiv
=
clone_get_equiv
(
i
,
o
)
return
new_o
return
[
equiv
[
output
]
for
output
in
o
]
def
clone_get_equiv
(
i
,
o
,
copy_inputs
=
False
):
def
clone_get_equiv
(
i
,
o
,
copy_inputs
=
False
):
"""
"""
i -> list of input Results
i -> list of input Results
o -> list of output Results
o -> list of output Results
copy_inputs -> if True, the inputs will be replaced in the cloned
graph by copies available in the equiv dictionary
returned by the function (copy_inputs defaults to False)
Returns (new_o, equiv) where new_o are the outputs of a copy of
Returns equiv a dictionary mapping each result and op in the
the whole subgraph bounded by i and o and equiv is a dictionary
graph delimited by i and o to a copy (akin to deepcopy's memo).
that maps the original ops and results found in the subgraph to
their copy (akin to deepcopy's memo). See clone for more details.
"""
"""
d
=
{}
d
=
{}
for
op
in
ops
(
i
,
o
):
for
input
in
i
:
d
[
op
]
=
copy
(
op
)
if
copy_inputs
:
d
[
input
]
=
copy
(
input
)
else
:
d
[
input
]
=
input
def
clone_helper
(
result
):
if
result
in
d
:
return
d
[
result
]
op
=
result
.
owner
if
not
op
:
return
result
else
:
new_op
=
op
.
__class__
(
*
[
clone_helper
(
input
)
for
input
in
op
.
inputs
])
d
[
op
]
=
new_op
for
output
,
new_output
in
zip
(
op
.
outputs
,
new_op
.
outputs
):
d
[
output
]
=
new_output
return
d
[
result
]
for
old_op
,
op
in
d
.
items
():
for
output
in
o
:
for
old_output
,
output
in
zip
(
old_op
.
outputs
,
op
.
outputs
):
clone_helper
(
output
)
d
[
old_output
]
=
output
for
i
,
input
in
enumerate
(
op
.
inputs
):
owner
=
input
.
owner
if
owner
in
d
:
op
.
_inputs
[
i
]
=
d
[
owner
]
.
outputs
[
input
.
_index
]
return
[[
d
[
output
]
for
output
in
o
],
d
]
return
d
def
io_toposort
(
i
,
o
,
orderings
=
{}):
def
io_toposort
(
i
,
o
,
orderings
=
{}):
...
@@ -188,17 +192,32 @@ def io_toposort(i, o, orderings = {}):
...
@@ -188,17 +192,32 @@ def io_toposort(i, o, orderings = {}):
all
=
ops
(
i
,
o
)
all
=
ops
(
i
,
o
)
for
op
in
all
:
for
op
in
all
:
prereqs_d
.
setdefault
(
op
,
set
())
.
update
(
set
([
input
.
owner
for
input
in
op
.
inputs
if
input
.
owner
and
input
.
owner
in
all
]))
prereqs_d
.
setdefault
(
op
,
set
())
.
update
(
set
([
input
.
owner
for
input
in
op
.
inputs
if
input
.
owner
and
input
.
owner
in
all
]))
# prereqs_d[op] = set([input.owner for input in op.inputs if input.owner and input.owner in all])
return
utils
.
toposort
(
prereqs_d
)
return
utils
.
toposort
(
prereqs_d
)
def
as_string
(
i
,
o
):
default_leaf_formatter
=
str
default_node_formatter
=
lambda
op
,
argstrings
:
"
%
s(
%
s)"
%
(
op
.
__class__
.
__name__
,
", "
.
join
(
argstrings
))
def
op_as_string
(
i
,
op
,
leaf_formatter
=
default_leaf_formatter
,
node_formatter
=
default_node_formatter
):
strs
=
as_string
(
i
,
op
.
inputs
,
leaf_formatter
,
node_formatter
)
return
node_formatter
(
op
,
strs
)
def
as_string
(
i
,
o
,
leaf_formatter
=
default_leaf_formatter
,
node_formatter
=
default_node_formatter
):
"""
"""
i -> list of input Results
i -> list of input Results
o -> list of output Results
o -> list of output Results
leaf_formatter -> function that takes a result and returns a string to describe it
node_formatter -> function that takes an op and the list of strings corresponding
to its arguments and returns a string to describe it
Returns a string representation of the subgraph between i and o. If the same
Returns a string representation of the subgraph between i and o. If the same
O
p is used by several other ops, the first occurrence will be marked as
o
p is used by several other ops, the first occurrence will be marked as
'*n -> description' and all subsequent occurrences will be marked as '*n',
'*n -> description' and all subsequent occurrences will be marked as '*n',
where n is an id number (ids are attributed in an unspecified order and only
where n is an id number (ids are attributed in an unspecified order and only
exist for viewing convenience).
exist for viewing convenience).
...
@@ -219,50 +238,37 @@ def as_string(i, o):
...
@@ -219,50 +238,37 @@ def as_string(i, o):
done
=
set
()
done
=
set
()
def
multi_index
(
x
):
def
multi_index
(
x
):
try
:
return
multi
.
index
(
x
)
+
1
return
multi
.
index
(
x
)
+
1
except
:
return
999
def
describe
(
x
,
first
=
False
):
if
isinstance
(
x
,
Result
):
done
.
add
(
x
)
if
x
.
owner
is
not
None
and
x
not
in
i
:
op
=
x
.
owner
idx
=
op
.
outputs
.
index
(
x
)
if
idx
:
s
=
describe
(
op
,
first
)
+
"."
+
str
(
idx
)
else
:
s
=
describe
(
op
,
first
)
return
s
else
:
return
str
(
id
(
x
))
elif
isinstance
(
x
,
Op
):
def
describe
(
r
):
if
x
in
done
:
if
r
.
owner
is
not
None
and
r
not
in
i
:
return
"*
%
i"
%
multi_index
(
x
)
op
=
r
.
owner
idx
=
op
.
outputs
.
index
(
r
)
if
idx
==
op
.
_default_output_idx
:
idxs
=
""
else
:
else
:
done
.
add
(
x
)
idxs
=
"::
%
i"
%
idx
if
not
first
and
hasattr
(
x
,
'name'
)
and
x
.
name
is
not
None
:
if
op
in
done
:
return
x
.
name
return
"*
%
i
%
s"
%
(
multi_index
(
x
),
idxs
)
s
=
x
.
__class__
.
__name__
+
"("
+
", "
.
join
([
describe
(
v
)
for
v
in
x
.
inputs
])
+
")"
else
:
if
x
in
multi
:
done
.
add
(
op
)
s
=
node_formatter
(
op
,
[
describe
(
input
)
for
input
in
op
.
inputs
])
if
op
in
multi
:
return
"*
%
i ->
%
s"
%
(
multi_index
(
x
),
s
)
return
"*
%
i ->
%
s"
%
(
multi_index
(
x
),
s
)
else
:
else
:
return
s
return
s
else
:
else
:
raise
TypeError
(
"Cannot print type:
%
s"
%
x
.
__class__
)
return
leaf_formatter
(
r
)
return
"["
+
", "
.
join
([
describe
(
x
,
True
)
for
x
in
o
])
+
"]"
return
[
describe
(
output
)
for
output
in
o
]
# Op.__str__ = lambda self: as_string(inputs(self.outputs), self.outputs)[1:-1]
# Result.__str__ = lambda self: as_string(inputs([self]), [self])[1:-1]
class
Graph
:
class
Graph
:
"""
Object-oriented wrapper for all the functions in this module.
"""
def
__init__
(
self
,
inputs
,
outputs
):
def
__init__
(
self
,
inputs
,
outputs
):
self
.
inputs
=
inputs
self
.
inputs
=
inputs
...
...
gof/op.py
浏览文件 @
a5129f65
...
@@ -4,8 +4,9 @@ Contains the Op class, which is the base interface for all operations
...
@@ -4,8 +4,9 @@ Contains the Op class, which is the base interface for all operations
compatible with gof's graph manipulation routines.
compatible with gof's graph manipulation routines.
"""
"""
from
result
import
BrokenLink
from
result
import
BrokenLinkError
from
utils
import
ClsInit
,
all_bases
,
all_bases_collect
from
utils
import
ClsInit
,
all_bases
,
all_bases_collect
,
AbstractFunctionError
import
graph
from
copy
import
copy
from
copy
import
copy
...
@@ -29,10 +30,6 @@ class Op(object):
...
@@ -29,10 +30,6 @@ class Op(object):
__slots__
=
[
'_inputs'
,
'_outputs'
]
__slots__
=
[
'_inputs'
,
'_outputs'
]
#create inputs and outputs as read-only attributes
inputs
=
property
(
lambda
self
:
self
.
_inputs
,
doc
=
"The list of this Op's input Results."
)
outputs
=
property
(
lambda
self
:
self
.
_outputs
,
doc
=
"The list of this Op's output Results."
)
_default_output_idx
=
0
_default_output_idx
=
0
def
default_output
(
self
):
def
default_output
(
self
):
...
@@ -45,236 +42,465 @@ class Op(object):
...
@@ -45,236 +42,465 @@ class Op(object):
out
=
property
(
default_output
,
out
=
property
(
default_output
,
doc
=
"Same as self.outputs[0] if this Op's has_default_output field is True."
)
doc
=
"Same as self.outputs[0] if this Op's has_default_output field is True."
)
def
__init__
(
self
,
inputs
,
outputs
,
use_self_setters
=
False
):
"""
Initializes the '_inputs' and '_outputs' slots and sets the
owner of all outputs to self.
If use_self_setters is False, Op::set_input and Op::set_output
are used, which do the minimum checks and manipulations. Else,
the user defined set_input and set_output functions are
called (in any case, all inputs and outputs are initialized
to None).
"""
self
.
_inputs
=
[
None
]
*
len
(
inputs
)
self
.
_outputs
=
[
None
]
*
len
(
outputs
)
if
use_self_setters
:
def
__init__
(
self
,
*
inputs
):
for
i
,
input
in
enumerate
(
inputs
):
self
.
_inputs
=
None
self
.
set_input
(
i
,
input
,
validate
=
False
)
self
.
_outputs
=
None
for
i
,
output
in
enumerate
(
outputs
):
self
.
__set_inputs
(
inputs
)
self
.
set_output
(
i
,
output
,
validate
=
False
)
self
.
validate_update
()
self
.
validate
()
else
:
for
i
,
input
in
enumerate
(
inputs
):
Op
.
set_input
(
self
,
i
,
input
,
validate
=
False
)
for
i
,
output
in
enumerate
(
outputs
):
Op
.
set_output
(
self
,
i
,
output
,
validate
=
False
)
self
.
validate
()
self
.
validate
()
def
__get_input
(
self
,
i
):
input
=
self
.
_inputs
[
i
]
if
input
.
replaced
:
raise
BrokenLinkError
()
return
input
def
__set_input
(
self
,
i
,
new_input
):
self
.
_inputs
[
i
]
=
new_input
def
__get_inputs
(
self
):
for
input
in
self
.
_inputs
:
if
input
.
replaced
:
raise
BrokenLinkError
()
return
self
.
_inputs
def
__set_inputs
(
self
,
new_inputs
):
self
.
_inputs
=
list
(
new_inputs
)
def
set_input
(
self
,
i
,
input
,
allow_changes
=
False
,
validate
=
True
):
"""
def
__get_output
(
self
,
i
):
Sets the ith input of self.inputs to input. i must be an
return
self
.
_outputs
[
i
]
integer in the range from 0 to len(self.inputs) - 1 and input
def
__set_output
(
self
,
i
,
new_output
):
must be a Result instance. The method may raise a GofTypeError
old_output
=
self
.
_outputs
[
i
]
or a GofValueError accordingly to the semantics of the Op, if
if
old_output
!=
new_output
:
the new input is of the wrong type or has the wrong
old_output
.
replaced
=
True
properties.
If i > len(self.inputs), an IndexError must be raised. If i ==
len(self.inputs), it is allowed for the Op to extend the list
of inputs if it is a vararg Op, else an IndexError should be
raised.
For a vararg Op, it is also allowed to have the input
parameter set to None for 0 <= i < len(self.inputs), in which
case the rest of the inputs will be shifted left. In any other
situation, a ValueError should be raised.
In some cases, set_input may change some outputs: for example,
a change of an input from float to double might require the
output's type to also change from float to double. If
allow_changes is True, set_input is allowed to perform those
changes and must return a list of pairs, each pair containing
the old output and the output it was replaced with (they
_must_ be different Result instances). See Op::set_output for
important information about replacing outputs. If
allow_changes is False and some change in the outputs is
required for the change in input to be correct, a
PropagationError must be raised.
This default implementation sets the ith input to input and
changes no outputs. It returns None.
"""
previous
=
self
.
inputs
[
i
]
self
.
inputs
[
i
]
=
input
if
validate
:
try
:
try
:
self
.
validate
()
# We try to reuse the old storage, if there is one
new_output
.
data
=
old_output
.
data
except
:
except
:
# this call gives a subclass the chance to undo the set_outputs
pass
# that it may have triggered...
new_output
.
role
=
(
self
,
i
)
# TODO: test this functionality!
self
.
_outputs
[
i
]
=
new_output
self
.
set_input
(
i
,
previous
,
True
,
False
)
def
set_output
(
self
,
i
,
output
,
validate
=
True
):
def
__get_outputs
(
self
):
"""
return
self
.
_outputs
Sets the ith output to output. The previous output, which is
def
__set_outputs
(
self
,
new_outputs
):
being replaced, must be invalidated using Result::invalidate.
if
self
.
_outputs
is
None
:
The new output must not already have an owner, or its owner must
for
i
,
output
in
enumerate
(
new_outputs
):
be self. It cannot be a broken link, unless it used to be at this
output
.
role
=
(
self
,
i
)
spot, in which case it can be reinstated.
self
.
_outputs
=
new_outputs
return
For Ops that have vararg output lists, see the regulations in
Op::set_input.
if
len
(
self
.
_outputs
)
!=
len
(
new_outputs
):
"""
raise
TypeError
(
"The new outputs must be exactly as many as the previous outputs."
)
if
isinstance
(
output
.
owner
,
BrokenLink
)
\
and
output
.
owner
.
owner
is
self
\
for
i
,
new_output
in
enumerate
(
new_outputs
):
and
output
.
owner
.
index
==
i
:
self
.
__set_output
(
i
,
new_output
)
output
.
revalidate
()
else
:
output
.
set_owner
(
self
,
i
)
# this checks for an already existing owner
def
get_input
(
self
,
i
):
previous
=
self
.
outputs
[
i
]
return
self
.
__get_input
(
i
)
if
previous
:
def
set_input
(
self
,
i
,
new_input
):
previous
.
invalidate
()
old_input
=
self
.
__get_input
(
i
)
self
.
outputs
[
i
]
=
output
if
validate
:
try
:
try
:
self
.
validate
()
self
.
__set_input
(
i
,
new_input
)
self
.
validate_update
()
except
:
except
:
self
.
set_output
(
i
,
previous
,
False
)
self
.
__set_input
(
i
,
old_input
)
self
.
validate_update
()
raise
def
get_inputs
(
self
):
return
self
.
__get_inputs
()
def
set_inputs
(
self
,
new_inputs
):
old_inputs
=
self
.
__get_inputs
()
try
:
self
.
__set_inputs
(
new_inputs
)
self
.
validate_update
()
except
:
self
.
_inputs
=
old_inputs
raise
def
_dontuse_repair
(
self
,
allow_changes
=
False
):
def
get_output
(
self
,
i
):
return
self
.
__get_output
(
i
)
def
get_outputs
(
self
):
return
self
.
__get_outputs
()
#create inputs and outputs as read-only attributes
inputs
=
property
(
get_inputs
,
set_inputs
,
doc
=
"The list of this Op's input Results."
)
outputs
=
property
(
get_outputs
,
__set_outputs
,
doc
=
"The list of this Op's output Results."
)
def
validate_update
(
self
):
"""
"""
This function attempts to repair all inputs that are broken
(Abstract) This function must do two things:
links by calling set_input on the new Result that replaced
* validate: check all the inputs in self.inputs to ensure
them. Note that if a set_input operation invalidates one or
that they have the right type for this Op, etc.
more outputs, new broken links might appear in the other ops
If the validation fails, raise an exception.
that use this op's outputs.
* update: create output Results and set the Op's outputs
The existing outputs must not be changed in place.
It is possible that the new inputs are inconsistent with this
The return value of validate_update is not used.
op, in which case an exception will be raised and the previous
inputs (and outputs) will be restored.
refresh returns a list of (old_output, new_output) pairs
detailing the changes, if any.
"""
"""
backtrack
=
[]
raise
AbstractFunctionError
()
def
repair
(
self
):
"""
Repairs all the inputs that are broken links to use what
they were replaced with. Then, calls self.validate_update()
to validate the new inputs and make new outputs.
"""
changed
=
False
repaired_inputs
=
[]
old_inputs
=
self
.
_inputs
for
input
in
self
.
_inputs
:
if
input
.
replaced
:
changed
=
True
role
=
input
.
role
.
old_role
input
=
role
[
0
]
.
outputs
[
role
[
1
]]
repaired_inputs
.
append
(
input
)
if
changed
:
try
:
try
:
for
i
,
input
in
enumerate
(
self
.
inputs
):
self
.
__set_inputs
(
repaired_inputs
)
link
=
input
.
owner
self
.
validate_update
()
if
isinstance
(
link
,
BrokenLink
):
current
=
link
.
owner
.
outputs
[
link
.
index
]
dirt
=
self
.
set_input
(
i
,
current
,
allow_changes
)
backtrack
.
append
((
i
,
input
,
dirt
))
except
:
except
:
# Restore the inputs and outputs that were successfully changed.
self
.
_inputs
=
old_inputs
for
i
,
input
,
dirt
in
backtrack
:
self
.
inputs
[
i
]
=
input
if
dirt
:
for
old
,
new
in
dirt
:
new
.
invalidate
()
old
.
revalidate
()
self
.
outputs
[
self
.
outputs
.
index
(
new
)]
=
old
raise
raise
all_dirt
=
[]
return
changed
for
i
,
input
,
dirt
in
backtrack
:
if
dirt
:
all_dirt
+=
dirt
return
all_dirt
def
perform
(
self
):
#
# copy
#
def
__copy__
(
self
):
"""
"""
Performs the computation on the inputs and stores the results
Shallow copy of this Op. The inputs are the exact same, but
in the outputs. This function should check for the validity of
the outputs are recreated because of the one-owner-per-result
the inputs and raise appropriate errors for debugging (for
policy.
executing without checks, override _perform).
An Op may define additional ways to perform the computation
that are more efficient (e.g. a piece of C code or a C struct
with direct references to the inputs and outputs), but
perform() should always be available in order to have a
consistent interface to execute graphs.
"""
"""
raise
NotImplementedError
return
self
.
__class__
(
*
self
.
inputs
)
def
_perform
(
self
):
#
# String representation
#
def
__str__
(
self
):
return
graph
.
op_as_string
(
self
.
inputs
,
self
)
def
__repr__
(
self
):
return
str
(
self
)
#
# perform
#
def
perform
(
self
):
"""
"""
Performs the computation on the inputs and stores the results
(abstract) Performs the computation associated to this Op,
in the outputs, like perform(), but is not required to check
places the result(s) in the output Results and gives them
the
existence or the validity of the input
s.
the
Computed statu
s.
"""
"""
r
eturn
self
.
perform
()
r
aise
AbstractFunctionError
()
@classmethod
def
require
(
cls
):
#
# C code generators
#
def
c_var_names
(
self
):
"""
"""
Returns a set of Feature subclasses that must be used by any
Returns ([list of input names], [list of output names]) for
Env manipulating this kind of op. For instance, a Destroyer
use as C variables.
requires ext.DestroyHandler to guarantee that various
destructive operations don't interfere.
By default, this collates the __require__ field of this class
and the __require__ fields of all classes that are directly or
indirectly superclasses to this class into a set.
"""
"""
r
=
set
()
return
[[
"i
%
i"
%
i
for
i
in
xrange
(
len
(
self
.
inputs
))],
[
"o
%
i"
%
i
for
i
in
xrange
(
len
(
self
.
outputs
))]]
bases
=
all_bases
(
cls
,
lambda
cls
:
hasattr
(
cls
,
'__env_require__'
))
def
c_validate
(
self
):
"""
Returns C code that checks that the inputs to this function
can be worked on. If a failure occurs, set an Exception
and insert "
%(fail)
s".
for
base
in
bases
:
You may use the variable names defined by c_var_names()
req
=
base
.
__env_require__
"""
if
isinstance
(
req
,
(
list
,
tuple
)):
raise
AbstractFunctionError
()
r
.
update
(
req
)
else
:
r
.
add
(
req
)
return
r
def
c_validate_cleanup
(
self
):
"""
Clean up things allocated by c_validate().
"""
raise
AbstractFunctionError
()
def
vali
date
(
self
):
def
c_up
date
(
self
):
"""
"""
This class's __validate__ function will be called, as well as
Returns C code that allocates and/or updates the outputs
the __validate__ functions of all base classes down the class
(eg resizing, etc.) so they can be manipulated safely
tree. If you do not want to execute __validate__ from the base
by c_code.
classes, set the class variable __validate_override__ to True.
You may use the variable names defined by c_var_names()
"""
"""
vfns
=
all_bases_collect
(
self
.
__class__
,
'validate'
)
raise
AbstractFunctionError
()
for
vfn
in
vfns
:
vfn
(
self
)
def
c_update_cleanup
(
self
):
"""
Clean up things allocated by c_update().
"""
raise
AbstractFunctionError
()
def
__copy__
(
self
):
def
c_code
(
self
):
"""
"""
Copies the inputs list shallowly and copies all the outputs
Returns C code that does the computation associated to this
because of the one owner per output restriction.
Op. You may assume that input validation and output allocation
have already been done.
You may use the variable names defined by c_var_names()
"""
"""
new_inputs
=
copy
(
self
.
inputs
)
raise
AbstractFunctionError
()
# We copy the outputs because they are tied to a single Op.
new_outputs
=
[
copy
(
output
)
for
output
in
self
.
outputs
]
def
c_code_cleanup
(
self
):
op
=
self
.
__class__
(
new_inputs
,
new_outputs
)
op
.
_inputs
=
new_inputs
op
.
_outputs
=
new_outputs
for
i
,
output
in
enumerate
(
op
.
outputs
):
# We adjust _owner and _index manually since the copies
# point to the previous op (self).
output
.
_owner
=
op
output
.
_index
=
i
return
op
def
__deepcopy__
(
self
,
memo
):
"""
"""
Not implemented. Use gof.graph.clone(inputs, outputs) to copy
Clean up things allocated by c_code().
a subgraph.
"""
"""
raise
NotImplementedError
(
"Use gof.graph.clone(inputs, outputs) to copy a subgraph."
)
raise
AbstractFunctionError
()
# def __init__(self, inputs, outputs, use_self_setters = False):
# """
# Initializes the '_inputs' and '_outputs' slots and sets the
# owner of all outputs to self.
# If use_self_setters is False, Op::set_input and Op::set_output
# are used, which do the minimum checks and manipulations. Else,
# the user defined set_input and set_output functions are
# called (in any case, all inputs and outputs are initialized
# to None).
# """
# self._inputs = [None] * len(inputs)
# self._outputs = [None] * len(outputs)
# if use_self_setters:
# for i, input in enumerate(inputs):
# self.set_input(i, input, validate = False)
# for i, output in enumerate(outputs):
# self.set_output(i, output, validate = False)
# self.validate()
# else:
# for i, input in enumerate(inputs):
# Op.set_input(self, i, input, validate = False)
# for i, output in enumerate(outputs):
# Op.set_output(self, i, output, validate = False)
# self.validate()
# self.validate()
# def set_input(self, i, input, allow_changes = False, validate = True):
# """
# Sets the ith input of self.inputs to input. i must be an
# integer in the range from 0 to len(self.inputs) - 1 and input
# must be a Result instance. The method may raise a GofTypeError
# or a GofValueError accordingly to the semantics of the Op, if
# the new input is of the wrong type or has the wrong
# properties.
# If i > len(self.inputs), an IndexError must be raised. If i ==
# len(self.inputs), it is allowed for the Op to extend the list
# of inputs if it is a vararg Op, else an IndexError should be
# raised.
# For a vararg Op, it is also allowed to have the input
# parameter set to None for 0 <= i < len(self.inputs), in which
# case the rest of the inputs will be shifted left. In any other
# situation, a ValueError should be raised.
# In some cases, set_input may change some outputs: for example,
# a change of an input from float to double might require the
# output's type to also change from float to double. If
# allow_changes is True, set_input is allowed to perform those
# changes and must return a list of pairs, each pair containing
# the old output and the output it was replaced with (they
# _must_ be different Result instances). See Op::set_output for
# important information about replacing outputs. If
# allow_changes is False and some change in the outputs is
# required for the change in input to be correct, a
# PropagationError must be raised.
# This default implementation sets the ith input to input and
# changes no outputs. It returns None.
# """
# previous = self.inputs[i]
# self.inputs[i] = input
# if validate:
# try:
# self.validate()
# except:
# # this call gives a subclass the chance to undo the set_outputs
# # that it may have triggered...
# # TODO: test this functionality!
# self.set_input(i, previous, True, False)
# def set_output(self, i, output, validate = True):
# """
# Sets the ith output to output. The previous output, which is
# being replaced, must be invalidated using Result::invalidate.
# The new output must not already have an owner, or its owner must
# be self. It cannot be a broken link, unless it used to be at this
# spot, in which case it can be reinstated.
# For Ops that have vararg output lists, see the regulations in
# Op::set_input.
# """
# if isinstance(output.owner, BrokenLink) \
# and output.owner.owner is self \
# and output.owner.index == i:
# output.revalidate()
# else:
# output.set_owner(self, i) # this checks for an already existing owner
# previous = self.outputs[i]
# if previous:
# previous.invalidate()
# self.outputs[i] = output
# if validate:
# try:
# self.validate()
# except:
# self.set_output(i, previous, False)
# def _dontuse_repair(self, allow_changes = False):
# """
# This function attempts to repair all inputs that are broken
# links by calling set_input on the new Result that replaced
# them. Note that if a set_input operation invalidates one or
# more outputs, new broken links might appear in the other ops
# that use this op's outputs.
# It is possible that the new inputs are inconsistent with this
# op, in which case an exception will be raised and the previous
# inputs (and outputs) will be restored.
# refresh returns a list of (old_output, new_output) pairs
# detailing the changes, if any.
# """
# backtrack = []
# try:
# for i, input in enumerate(self.inputs):
# link = input.owner
# if isinstance(link, BrokenLink):
# current = link.owner.outputs[link.index]
# dirt = self.set_input(i, current, allow_changes)
# backtrack.append((i, input, dirt))
# except:
# # Restore the inputs and outputs that were successfully changed.
# for i, input, dirt in backtrack:
# self.inputs[i] = input
# if dirt:
# for old, new in dirt:
# new.invalidate()
# old.revalidate()
# self.outputs[self.outputs.index(new)] = old
# raise
# all_dirt = []
# for i, input, dirt in backtrack:
# if dirt:
# all_dirt += dirt
# return all_dirt
# def perform(self):
# """
# Performs the computation on the inputs and stores the results
# in the outputs. This function should check for the validity of
# the inputs and raise appropriate errors for debugging (for
# executing without checks, override _perform).
# An Op may define additional ways to perform the computation
# that are more efficient (e.g. a piece of C code or a C struct
# with direct references to the inputs and outputs), but
# perform() should always be available in order to have a
# consistent interface to execute graphs.
# """
# raise NotImplementedError
# def _perform(self):
# """
# Performs the computation on the inputs and stores the results
# in the outputs, like perform(), but is not required to check
# the existence or the validity of the inputs.
# """
# return self.perform()
# @classmethod
# def require(cls):
# """
# Returns a set of Feature subclasses that must be used by any
# Env manipulating this kind of op. For instance, a Destroyer
# requires ext.DestroyHandler to guarantee that various
# destructive operations don't interfere.
# By default, this collates the __require__ field of this class
# and the __require__ fields of all classes that are directly or
# indirectly superclasses to this class into a set.
# """
# r = set()
# bases = all_bases(cls, lambda cls: hasattr(cls, '__env_require__'))
# for base in bases:
# req = base.__env_require__
# if isinstance(req, (list, tuple)):
# r.update(req)
# else:
# r.add(req)
# return r
# def validate(self):
# """
# This class's __validate__ function will be called, as well as
# the __validate__ functions of all base classes down the class
# tree. If you do not want to execute __validate__ from the base
# classes, set the class variable __validate_override__ to True.
# """
# vfns = all_bases_collect(self.__class__, 'validate')
# for vfn in vfns:
# vfn(self)
# def __copy__(self):
# """
# Copies the inputs list shallowly and copies all the outputs
# because of the one owner per output restriction.
# """
# new_inputs = copy(self.inputs)
# # We copy the outputs because they are tied to a single Op.
# new_outputs = [copy(output) for output in self.outputs]
# op = self.__class__(new_inputs, new_outputs)
# op._inputs = new_inputs
# op._outputs = new_outputs
# for i, output in enumerate(op.outputs):
# # We adjust _owner and _index manually since the copies
# # point to the previous op (self).
# output._owner = op
# output._index = i
# return op
# def __deepcopy__(self, memo):
# """
# Not implemented. Use gof.graph.clone(inputs, outputs) to copy
# a subgraph.
# """
# raise NotImplementedError("Use gof.graph.clone(inputs, outputs) to copy a subgraph.")
gof/result.py
浏览文件 @
a5129f65
...
@@ -5,37 +5,32 @@ value that is the input or the output of an Op.
...
@@ -5,37 +5,32 @@ value that is the input or the output of an Op.
"""
"""
import
unittest
from
err
import
GofError
from
utils
import
AbstractFunctionError
from
utils
import
AbstractFunctionError
from
python25
import
all
from
python25
import
all
__all__
=
[
'is_result'
,
'ResultBase'
,
'BrokenLink'
,
'BrokenLinkError'
]
__all__
=
[
'is_result'
,
'ResultBase'
,
'BrokenLink'
,
'BrokenLinkError'
,
'StateError'
,
'Empty'
,
'Allocated'
,
'Computed'
,
]
class
BrokenLink
:
class
BrokenLink
:
"""
"""The owner of a Result that was replaced by another Result"""
This is placed as the owner of a Result that was replaced by
__slots__
=
[
'old_role'
]
another Result.
def
__init__
(
self
,
role
):
self
.
old_role
=
role
"""
def
__nonzero__
(
self
):
return
False
__slots__
=
[
'owner'
,
'index'
]
def
__init__
(
self
,
owner
,
index
):
self
.
owner
=
owner
self
.
index
=
index
def
__nonzero__
(
self
):
return
False
class
BrokenLinkError
(
Exception
):
"""The owner is a BrokenLink"""
class
BrokenLinkError
(
GofError
):
class
StateError
(
Exception
):
"""
"""The state of the Result is a problem"""
"""
pass
# ResultBase state keywords
# ResultBase state keywords
...
@@ -61,6 +56,7 @@ class ResultBase(object):
...
@@ -61,6 +56,7 @@ class ResultBase(object):
_data - anything
_data - anything
constant - Boolean
constant - Boolean
state - one of (Empty, Allocated, Computed)
state - one of (Empty, Allocated, Computed)
name - string
Properties:
Properties:
role - (rw)
role - (rw)
...
@@ -89,33 +85,17 @@ class ResultBase(object):
...
@@ -89,33 +85,17 @@ class ResultBase(object):
expressions).
expressions).
"""
"""
class
BrokenLink
:
"""The owner of a Result that was replaced by another Result"""
__slots__
=
[
'old_role'
]
def
__init__
(
self
,
role
):
self
.
old_role
=
role
def
__nonzero__
(
self
):
return
False
class
BrokenLinkError
(
Exception
):
__slots__
=
[
'_role'
,
'constant'
,
'_data'
,
'state'
,
'_name'
]
"""The owner is a BrokenLink"""
class
StateError
(
Exception
):
def
__init__
(
self
,
role
=
None
,
data
=
None
,
constant
=
False
,
name
=
None
):
"""The state of the Result is a problem"""
__slots__
=
[
'_role'
,
'constant'
,
'_data'
,
'state'
]
def
__init__
(
self
,
role
=
None
,
data
=
None
,
constant
=
False
):
self
.
_role
=
role
self
.
_role
=
role
self
.
constant
=
constant
self
.
_data
=
[
None
]
self
.
_data
=
[
None
]
if
data
is
None
:
#None is not filtered
self
.
_data
[
0
]
=
None
self
.
state
=
Empty
self
.
state
=
Empty
else
:
self
.
constant
=
False
try
:
self
.
__set_data
(
data
)
self
.
_data
[
0
]
=
self
.
data_filter
(
data
)
self
.
constant
=
constant
# can only lock data after setting it
except
AbstractFunctionError
:
self
.
name
=
name
self
.
_data
[
0
]
=
data
self
.
state
=
Computed
#
#
# role
# role
...
@@ -144,7 +124,7 @@ class ResultBase(object):
...
@@ -144,7 +124,7 @@ class ResultBase(object):
def
__get_owner
(
self
):
def
__get_owner
(
self
):
if
self
.
_role
is
None
:
return
None
if
self
.
_role
is
None
:
return
None
if
self
.
replaced
:
raise
ResultBase
.
BrokenLinkError
()
if
self
.
replaced
:
raise
BrokenLinkError
()
return
self
.
_role
[
0
]
return
self
.
_role
[
0
]
owner
=
property
(
__get_owner
,
owner
=
property
(
__get_owner
,
...
@@ -156,7 +136,7 @@ class ResultBase(object):
...
@@ -156,7 +136,7 @@ class ResultBase(object):
def
__get_index
(
self
):
def
__get_index
(
self
):
if
self
.
_role
is
None
:
return
None
if
self
.
_role
is
None
:
return
None
if
self
.
replaced
:
raise
ResultBase
.
BrokenLinkError
()
if
self
.
replaced
:
raise
BrokenLinkError
()
return
self
.
_role
[
1
]
return
self
.
_role
[
1
]
index
=
property
(
__get_index
,
index
=
property
(
__get_index
,
...
@@ -171,35 +151,39 @@ class ResultBase(object):
...
@@ -171,35 +151,39 @@ class ResultBase(object):
return
self
.
_data
[
0
]
return
self
.
_data
[
0
]
def
__set_data
(
self
,
data
):
def
__set_data
(
self
,
data
):
if
self
.
replaced
:
raise
ResultBase
.
BrokenLinkError
()
if
self
.
replaced
:
if
self
.
constant
:
raise
Exception
(
'cannot set constant ResultBase'
)
raise
BrokenLinkError
()
if
data
is
self
.
_data
[
0
]:
return
if
self
.
constant
:
raise
Exception
(
'cannot set constant ResultBase'
)
if
data
is
None
:
if
data
is
None
:
self
.
_data
[
0
]
=
None
self
.
_data
[
0
]
=
None
self
.
state
=
Empty
self
.
state
=
Empty
return
return
if
data
is
self
or
data
is
self
.
_data
[
0
]:
return
try
:
try
:
self
.
_data
[
0
]
=
self
.
data_filter
(
data
)
self
.
validate
(
data
)
except
AbstractFunctionError
:
#use default behaviour
except
AbstractFunctionError
:
pass
self
.
_data
[
0
]
=
data
self
.
_data
[
0
]
=
data
if
isinstance
(
data
,
ResultBase
):
raise
Exception
()
self
.
state
=
Computed
self
.
state
=
Computed
data
=
property
(
__get_data
,
__set_data
,
data
=
property
(
__get_data
,
__set_data
,
doc
=
"The storage associated with this result"
)
doc
=
"The storage associated with this result"
)
def
data_filter
(
self
,
data
):
def
validate
(
self
,
data
):
"""(abstract) Return an appropriate _data based on data.
"""(abstract) Raise an exception if the data is not of an
acceptable type.
If a subclass overrides this function,
then that overriding
If a subclass overrides this function,
__set_data will use
i
mplementation will be used in __set_data to map the argument to
i
t to check that the argument can be used properly. This gives
self._data. This gives a subclass the opportunity to ensure that
a subclass the opportunity to ensure that the contents of
the contents of
self._data remain sensible.
self._data remain sensible.
"""
"""
raise
AbstractFunctionError
()
raise
AbstractFunctionError
()
#
#
# alloc
# alloc
#
#
...
@@ -229,18 +213,99 @@ class ResultBase(object):
...
@@ -229,18 +213,99 @@ class ResultBase(object):
#
#
def
__get_replaced
(
self
):
def
__get_replaced
(
self
):
return
isinstance
(
self
.
_role
,
ResultBase
.
BrokenLink
)
return
isinstance
(
self
.
_role
,
BrokenLink
)
def
__set_replaced
(
self
,
replace
):
def
__set_replaced
(
self
,
replace
):
if
replace
==
self
.
replaced
:
return
if
replace
==
self
.
replaced
:
return
if
replace
:
if
replace
:
self
.
_role
=
ResultBase
.
BrokenLink
(
self
.
_role
)
self
.
_role
=
BrokenLink
(
self
.
_role
)
else
:
else
:
self
.
_role
=
self
.
_role
.
old_role
self
.
_role
=
self
.
_role
.
old_role
replaced
=
property
(
__get_replaced
,
__set_replaced
,
doc
=
"has this Result been replaced?"
)
replaced
=
property
(
__get_replaced
,
__set_replaced
,
doc
=
"has this Result been replaced?"
)
#
# C code generators
#
def
c_extract
(
self
):
get_from_list
=
"""
PyObject* py_
%(name)
s = PyList_GET_ITEM(
%(name)
s_storage, 0);
Py_XINCREF(py_
%(name)
s);
"""
return
self
.
c_data_extract
()
+
get_from_list
def
c_data_extract
(
self
):
"""
The code returned from this function must be templated using "
%(name)
s",
representing the name that the caller wants to call this Result.
The Python object self.data is in a variable called "py_
%(name)
s" and
this code must declare a variable named "
%(name)
s" of a type appropriate
to manipulate from C. Additional variables and typedefs can be produced.
If the data is improper, set an appropriate error message and insert
"
%(fail)
s".
"""
raise
AbstractFunction
()
def
c_sync
(
self
,
var_name
):
set_in_list
=
"""
PyList_SET_ITEM(
%(name)
s_storage, 0, py_
%(name)
s);
Py_XDECREF(py_
%(name)
s);
"""
return
self
.
c_data_sync
()
+
set_in_list
def
c_data_sync
(
self
):
"""
The code returned from this function must be templated using "
%(name)
s",
representing the name that the caller wants to call this Result.
The returned code may set "py_
%(name)
s" to a PyObject* and that PyObject*
will be accessible from Python via result.data. Do not forget to adjust
reference counts if "py_
%(name)
s" is changed from its original value!
"""
raise
AbstractFunction
()
#
# name
#
def
__get_name
(
self
):
if
self
.
_name
:
return
self
.
_name
elif
self
.
_role
:
return
"
%
s.
%
i"
%
(
self
.
owner
.
__class__
,
self
.
owner
.
outputs
.
index
(
self
))
else
:
return
None
def
__set_name
(
self
,
name
):
if
name
is
not
None
and
not
isinstance
(
name
,
str
):
raise
TypeError
(
"Name is expected to be a string, or None."
)
self
.
_name
=
name
name
=
property
(
__get_name
,
__set_name
,
doc
=
"Name of the Result."
)
#
# String representation
#
def
__str__
(
self
):
name
=
self
.
name
if
name
:
if
self
.
state
is
Computed
:
return
name
+
":"
+
str
(
self
.
data
)
else
:
return
name
elif
self
.
state
is
Computed
:
return
str
(
self
.
data
)
else
:
return
"<?>"
def
__repr__
(
self
):
return
self
.
name
or
"<?>"
#################
#################
# NumpyR Compatibility
# NumpyR Compatibility
...
@@ -252,26 +317,26 @@ class ResultBase(object):
...
@@ -252,26 +317,26 @@ class ResultBase(object):
def
set_value
(
self
,
value
):
def
set_value
(
self
,
value
):
self
.
data
=
value
#may raise exception
self
.
data
=
value
#may raise exception
class
_test_ResultBase
(
unittest
.
TestCase
):
def
test_0
(
self
):
r
=
ResultBase
()
def
test_1
(
self
):
r
=
ResultBase
()
assert
r
.
state
is
Empty
r
.
data
=
0
assert
r
.
data
==
0
assert
r
.
state
is
Computed
r
.
data
=
1
assert
r
.
data
==
1
assert
r
.
state
is
Computed
r
.
data
=
None
assert
r
.
data
==
None
assert
r
.
state
is
Empty
if
__name__
==
'__main__'
:
unittest
.
main
()
# def c_data_extract(self):
# return """
# PyArrayObject* %%(name)s;
# if (py_%%(name)s == Py_None)
# %%(name)s = NULL;
# else
# %%(name)s = (PyArrayObject*)(py_%%(name)s);
# typedef %(dtype)s %%(name)s_dtype;
# """ % dict(dtype = self.dtype)
# def c_data_sync(self):
# return """
# if (!%(name)s) {
# Py_XDECREF(py_%(name));
# py_%(name)s = Py_None;
# }
# else if ((void*)py_%(name)s != (void*)%(name)s) {
# Py_XDECREF(py_%(name));
# py_%(name)s = (PyObject*)%(name)s;
# }
# """
gof/utils.py
浏览文件 @
a5129f65
...
@@ -12,6 +12,15 @@ class AbstractFunctionError(Exception):
...
@@ -12,6 +12,15 @@ class AbstractFunctionError(Exception):
function has been left out of an implementation class.
function has been left out of an implementation class.
"""
"""
def
attr_checker
(
*
attrs
):
def
f
(
candidate
):
for
attr
in
attrs
:
if
not
hasattr
(
candidate
,
attr
):
return
False
return
True
f
.
__doc__
=
"Checks that the candidate has the following attributes:
%
s"
%
", "
.
join
([
"'
%
s'"
%
attr
for
attr
in
attrs
])
return
f
def
all_bases
(
cls
,
accept
):
def
all_bases
(
cls
,
accept
):
rval
=
set
([
cls
])
rval
=
set
([
cls
])
...
@@ -34,21 +43,6 @@ def all_bases_collect(cls, raw_name):
...
@@ -34,21 +43,6 @@ def all_bases_collect(cls, raw_name):
def
uniq_features
(
_features
,
*
_rest
):
"""Return a list such that no element is a subclass of another"""
# used in Env.__init__ to
features
=
[
x
for
x
in
_features
]
for
other
in
_rest
:
features
+=
[
x
for
x
in
other
]
res
=
[]
while
features
:
feature
=
features
.
pop
()
for
feature2
in
features
:
if
issubclass
(
feature2
,
feature
):
break
else
:
res
.
append
(
feature
)
return
res
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论