Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
65e08101
提交
65e08101
authored
5月 04, 2008
作者:
Olivier Breuleux
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
env redone, toolbox redone
上级
646f4c01
隐藏空白字符变更
内嵌
并排
正在显示
9 个修改的文件
包含
676 行增加
和
402 行删除
+676
-402
_test_ext.py
gof/_test_ext.py
+29
-22
_test_toolbox.py
gof/_test_toolbox.py
+14
-14
env.py
gof/env.py
+292
-219
ext.py
gof/ext.py
+105
-44
graph.py
gof/graph.py
+6
-6
opt.py
gof/opt.py
+13
-4
toolbox.py
gof/toolbox.py
+214
-92
utils.py
gof/utils.py
+2
-0
tensor.py
tensor.py
+1
-1
没有找到文件。
gof/_test_ext.py
浏览文件 @
65e08101
...
@@ -2,13 +2,15 @@
...
@@ -2,13 +2,15 @@
import
unittest
import
unittest
from
type
import
Type
from
type
import
Type
import
graph
from
graph
import
Result
,
as_result
,
Apply
from
graph
import
Result
,
as_result
,
Apply
from
op
import
Op
from
op
import
Op
from
opt
import
PatternOptimizer
,
OpSubOptimizer
from
opt
import
PatternOptimizer
,
OpSubOptimizer
from
ext
import
*
from
ext
import
*
from
env
import
Env
,
InconsistencyError
from
env
import
Env
,
InconsistencyError
from
toolbox
import
EquivTool
#from toolbox import EquivTool
from
toolbox
import
ReplaceValidate
from
copy
import
copy
from
copy
import
copy
...
@@ -65,8 +67,11 @@ def inputs():
...
@@ -65,8 +67,11 @@ def inputs():
_Env
=
Env
_Env
=
Env
def
Env
(
inputs
,
outputs
,
validate
=
True
):
def
Env
(
inputs
,
outputs
,
validate
=
True
):
e
=
_Env
(
inputs
,
outputs
)
e
=
_Env
(
inputs
,
outputs
)
e
.
extend
(
EquivTool
(
e
))
##e.extend(EquivTool(e))
e
.
extend
(
DestroyHandler
(
e
),
validate
=
validate
)
e
.
extend
(
DestroyHandler
())
e
.
extend
(
ReplaceValidate
())
if
validate
:
e
.
validate
()
return
e
return
e
...
@@ -108,19 +113,19 @@ class _test_all(unittest.TestCase):
...
@@ -108,19 +113,19 @@ class _test_all(unittest.TestCase):
g
=
Env
([
x
,
y
,
z
],
[
e1
,
e2
])
g
=
Env
([
x
,
y
,
z
],
[
e1
,
e2
])
chk
=
g
.
checkpoint
()
chk
=
g
.
checkpoint
()
assert
g
.
consistent
()
assert
g
.
consistent
()
g
.
replace
(
e1
,
add_in_place
(
x
,
y
))
g
.
replace
_validate
(
e1
,
add_in_place
(
x
,
y
))
assert
g
.
consistent
()
assert
g
.
consistent
()
try
:
try
:
g
.
replace
(
e2
,
add_in_place
(
y
,
x
))
g
.
replace
_validate
(
e2
,
add_in_place
(
y
,
x
))
self
.
fail
()
self
.
fail
()
except
InconsistencyError
:
except
InconsistencyError
:
pass
pass
assert
g
.
consistent
()
assert
g
.
consistent
()
g
.
revert
(
chk
)
g
.
revert
(
chk
)
g
.
replace
(
e2
,
add_in_place
(
y
,
x
))
g
.
replace
_validate
(
e2
,
add_in_place
(
y
,
x
))
assert
g
.
consistent
()
assert
g
.
consistent
()
try
:
try
:
g
.
replace
(
e1
,
add_in_place
(
x
,
y
))
g
.
replace
_validate
(
e1
,
add_in_place
(
x
,
y
))
self
.
fail
()
self
.
fail
()
except
InconsistencyError
:
except
InconsistencyError
:
pass
pass
...
@@ -136,7 +141,7 @@ class _test_all(unittest.TestCase):
...
@@ -136,7 +141,7 @@ class _test_all(unittest.TestCase):
assert
str
(
g
)
!=
"[Dot(Dot(AddInPlace(x, y), AddInPlace(y, z)), AddInPlace(z, x))]"
# we don't want to see that!
assert
str
(
g
)
!=
"[Dot(Dot(AddInPlace(x, y), AddInPlace(y, z)), AddInPlace(z, x))]"
# we don't want to see that!
e2
=
dot
(
dot
(
add_in_place
(
x
,
y
),
add_in_place
(
y
,
z
)),
add_in_place
(
z
,
x
))
e2
=
dot
(
dot
(
add_in_place
(
x
,
y
),
add_in_place
(
y
,
z
)),
add_in_place
(
z
,
x
))
try
:
try
:
g2
=
Env
(
[
x
,
y
,
z
],
[
e2
]
)
g2
=
Env
(
*
graph
.
clone
([
x
,
y
,
z
],
[
e2
])
)
self
.
fail
()
self
.
fail
()
except
InconsistencyError
:
except
InconsistencyError
:
pass
pass
...
@@ -154,16 +159,18 @@ class _test_all(unittest.TestCase):
...
@@ -154,16 +159,18 @@ class _test_all(unittest.TestCase):
e
=
dot
(
aip
,
transpose_view
(
x
))
e
=
dot
(
aip
,
transpose_view
(
x
))
g
=
Env
([
x
,
y
,
z
],
[
e
],
False
)
g
=
Env
([
x
,
y
,
z
],
[
e
],
False
)
assert
not
g
.
consistent
()
assert
not
g
.
consistent
()
g
.
replace
(
aip
,
add
(
x
,
z
))
g
.
replace
_validate
(
aip
,
add
(
x
,
z
))
assert
g
.
consistent
()
assert
g
.
consistent
()
def
test_usage_loop_through_views_2
(
self
):
def
test_usage_loop_through_views_2
(
self
):
x
,
y
,
z
=
inputs
()
x
,
y
,
z
=
inputs
()
e0
=
transpose_view
(
transpose_view
(
transpose_view
(
sigmoid
(
x
)
)))
e0
=
transpose_view
(
transpose_view
(
sigmoid
(
x
)))
e
=
dot
(
add_in_place
(
x
,
y
),
transpose_view
(
e0
))
e
=
dot
(
add_in_place
(
x
,
y
),
transpose_view
(
e0
))
g
=
Env
([
x
,
y
,
z
],
[
e
])
g
=
Env
([
x
,
y
,
z
],
[
e
])
assert
g
.
consistent
()
# because sigmoid can do the copy
assert
g
.
consistent
()
# because sigmoid can do the copy
g
.
replace
(
e0
,
x
,
False
)
# print g
# print g.destroy_handler.children
g
.
replace
(
e0
,
x
)
assert
not
g
.
consistent
()
# we cut off the path to the sigmoid
assert
not
g
.
consistent
()
# we cut off the path to the sigmoid
def
test_usage_loop_insert_views
(
self
):
def
test_usage_loop_insert_views
(
self
):
...
@@ -184,10 +191,10 @@ class _test_all(unittest.TestCase):
...
@@ -184,10 +191,10 @@ class _test_all(unittest.TestCase):
chk
=
g
.
checkpoint
()
chk
=
g
.
checkpoint
()
PatternOptimizer
((
transpose_view
,
(
transpose_view
,
'x'
)),
'x'
)
.
optimize
(
g
)
PatternOptimizer
((
transpose_view
,
(
transpose_view
,
'x'
)),
'x'
)
.
optimize
(
g
)
assert
str
(
g
)
==
"[x]"
assert
str
(
g
)
==
"[x]"
g
.
replace
(
g
.
equiv
(
e
),
add
(
x
,
y
)
)
new_e
=
add
(
x
,
y
)
print
g
g
.
replace_validate
(
x
,
new_e
)
assert
str
(
g
)
==
"[Add(x, y)]"
assert
str
(
g
)
==
"[Add(x, y)]"
g
.
replace
(
g
.
equiv
(
e
),
dot
(
add_in_place
(
x
,
y
),
transpose_view
(
x
)),
False
)
g
.
replace
(
new_e
,
dot
(
add_in_place
(
x
,
y
),
transpose_view
(
x
))
)
assert
str
(
g
)
==
"[Dot(AddInPlace(x, y), TransposeView(x))]"
assert
str
(
g
)
==
"[Dot(AddInPlace(x, y), TransposeView(x))]"
assert
not
g
.
consistent
()
assert
not
g
.
consistent
()
g
.
revert
(
chk
)
g
.
revert
(
chk
)
...
@@ -202,7 +209,7 @@ class _test_all(unittest.TestCase):
...
@@ -202,7 +209,7 @@ class _test_all(unittest.TestCase):
e
=
add_in_place
(
x
,
y
)
e
=
add_in_place
(
x
,
y
)
g
=
Env
([
x
,
y
,
z
],
[
e
],
False
)
g
=
Env
([
x
,
y
,
z
],
[
e
],
False
)
assert
not
g
.
consistent
()
assert
not
g
.
consistent
()
g
.
replace
(
e
,
add
(
x
,
y
))
g
.
replace
_validate
(
e
,
add
(
x
,
y
))
assert
g
.
consistent
()
assert
g
.
consistent
()
def
test_indestructible_through_views
(
self
):
def
test_indestructible_through_views
(
self
):
...
@@ -212,7 +219,7 @@ class _test_all(unittest.TestCase):
...
@@ -212,7 +219,7 @@ class _test_all(unittest.TestCase):
e
=
add_in_place
(
tv
,
y
)
e
=
add_in_place
(
tv
,
y
)
g
=
Env
([
x
,
y
,
z
],
[
e
],
False
)
g
=
Env
([
x
,
y
,
z
],
[
e
],
False
)
assert
not
g
.
consistent
()
assert
not
g
.
consistent
()
g
.
replace
(
tv
,
sigmoid
(
x
))
g
.
replace
_validate
(
tv
,
sigmoid
(
x
))
assert
g
.
consistent
()
assert
g
.
consistent
()
def
test_repair_destroy_path
(
self
):
def
test_repair_destroy_path
(
self
):
...
@@ -223,7 +230,7 @@ class _test_all(unittest.TestCase):
...
@@ -223,7 +230,7 @@ class _test_all(unittest.TestCase):
e4
=
add_in_place
(
e1
,
z
)
e4
=
add_in_place
(
e1
,
z
)
g
=
Env
([
x
,
y
,
z
],
[
e3
,
e4
],
False
)
g
=
Env
([
x
,
y
,
z
],
[
e3
,
e4
],
False
)
assert
not
g
.
consistent
()
assert
not
g
.
consistent
()
g
.
replace
(
e2
,
transpose_view
(
x
)
,
False
)
g
.
replace
(
e2
,
transpose_view
(
x
))
assert
not
g
.
consistent
()
assert
not
g
.
consistent
()
def
test_indirect
(
self
):
def
test_indirect
(
self
):
...
@@ -233,9 +240,9 @@ class _test_all(unittest.TestCase):
...
@@ -233,9 +240,9 @@ class _test_all(unittest.TestCase):
g
=
Env
([
x
,
y
,
z
],
[
e
],
False
)
g
=
Env
([
x
,
y
,
z
],
[
e
],
False
)
assert
not
g
.
consistent
()
assert
not
g
.
consistent
()
new_e0
=
add
(
x
,
y
)
new_e0
=
add
(
x
,
y
)
g
.
replace
(
e0
,
new_e0
,
False
)
g
.
replace
(
e0
,
new_e0
)
assert
g
.
consistent
()
assert
g
.
consistent
()
g
.
replace
(
new_e0
,
add_in_place
(
x
,
y
)
,
False
)
g
.
replace
(
new_e0
,
add_in_place
(
x
,
y
))
assert
not
g
.
consistent
()
assert
not
g
.
consistent
()
def
test_indirect_2
(
self
):
def
test_indirect_2
(
self
):
...
@@ -245,12 +252,12 @@ class _test_all(unittest.TestCase):
...
@@ -245,12 +252,12 @@ class _test_all(unittest.TestCase):
g
=
Env
([
x
,
y
,
z
],
[
e
],
False
)
g
=
Env
([
x
,
y
,
z
],
[
e
],
False
)
assert
not
g
.
consistent
()
assert
not
g
.
consistent
()
new_e0
=
add
(
e0
,
y
)
new_e0
=
add
(
e0
,
y
)
g
.
replace
(
e0
,
new_e0
,
False
)
g
.
replace
(
e0
,
new_e0
)
assert
g
.
consistent
()
assert
g
.
consistent
()
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
unittest
.
main
()
#
unittest.main()
_test_all
(
'test_usage_loop_through_views'
)
.
debug
()
gof/_test_toolbox.py
浏览文件 @
65e08101
...
@@ -59,19 +59,19 @@ def inputs():
...
@@ -59,19 +59,19 @@ def inputs():
return
x
,
y
,
z
return
x
,
y
,
z
class
_test_EquivTool
(
unittest
.
TestCase
):
#
class _test_EquivTool(unittest.TestCase):
def
test_straightforward
(
self
):
#
def test_straightforward(self):
x
,
y
,
z
=
inputs
()
#
x, y, z = inputs()
sx
=
sigmoid
(
x
)
#
sx = sigmoid(x)
e
=
add
(
sx
,
sigmoid
(
y
))
#
e = add(sx, sigmoid(y))
g
=
Env
([
x
,
y
,
z
],
[
e
])
#
g = Env([x, y, z], [e])
g
.
extend
(
EquivTool
(
g
))
#
g.extend(EquivTool(g))
assert
hasattr
(
g
,
'equiv'
)
#
assert hasattr(g, 'equiv')
assert
g
.
equiv
(
sx
)
is
sx
#
assert g.equiv(sx) is sx
g
.
replace
(
sx
,
dot
(
x
,
z
))
#
g.replace(sx, dot(x, z))
assert
g
.
equiv
(
sx
)
is
not
sx
#
assert g.equiv(sx) is not sx
assert
g
.
equiv
(
sx
)
.
owner
.
op
is
dot
#
assert g.equiv(sx).owner.op is dot
class
_test_NodeFinder
(
unittest
.
TestCase
):
class
_test_NodeFinder
(
unittest
.
TestCase
):
...
@@ -81,7 +81,7 @@ class _test_NodeFinder(unittest.TestCase):
...
@@ -81,7 +81,7 @@ class _test_NodeFinder(unittest.TestCase):
e0
=
dot
(
y
,
z
)
e0
=
dot
(
y
,
z
)
e
=
add
(
add
(
sigmoid
(
x
),
sigmoid
(
sigmoid
(
z
))),
dot
(
add
(
x
,
y
),
e0
))
e
=
add
(
add
(
sigmoid
(
x
),
sigmoid
(
sigmoid
(
z
))),
dot
(
add
(
x
,
y
),
e0
))
g
=
Env
([
x
,
y
,
z
],
[
e
])
g
=
Env
([
x
,
y
,
z
],
[
e
])
g
.
extend
(
NodeFinder
(
g
))
g
.
extend
(
NodeFinder
())
assert
hasattr
(
g
,
'get_nodes'
)
assert
hasattr
(
g
,
'get_nodes'
)
for
type
,
num
in
((
add
,
3
),
(
sigmoid
,
3
),
(
dot
,
2
)):
for
type
,
num
in
((
add
,
3
),
(
sigmoid
,
3
),
(
dot
,
2
)):
if
not
len
([
x
for
x
in
g
.
get_nodes
(
type
)])
==
num
:
if
not
len
([
x
for
x
in
g
.
get_nodes
(
type
)])
==
num
:
...
@@ -100,7 +100,7 @@ class _test_NodeFinder(unittest.TestCase):
...
@@ -100,7 +100,7 @@ class _test_NodeFinder(unittest.TestCase):
x
,
y
,
z
=
inputs
()
x
,
y
,
z
=
inputs
()
e
=
add
(
add
(
sigmoid
(
x
),
sigmoid
(
sigmoid
(
z
))),
dot
(
add
(
x
,
y
),
dot
(
y
,
z
)))
e
=
add
(
add
(
sigmoid
(
x
),
sigmoid
(
sigmoid
(
z
))),
dot
(
add
(
x
,
y
),
dot
(
y
,
z
)))
g
=
Env
([
x
,
y
,
z
],
[
e
])
g
=
Env
([
x
,
y
,
z
],
[
e
])
g
.
extend
(
NodeFinder
(
g
))
g
.
extend
(
NodeFinder
())
gen
=
g
.
get_nodes
(
sigmoid
)
# I want to get Sigmoid instances
gen
=
g
.
get_nodes
(
sigmoid
)
# I want to get Sigmoid instances
g
.
replace
(
e
,
add
(
x
,
y
))
# but here I prune them all
g
.
replace
(
e
,
add
(
x
,
y
))
# but here I prune them all
assert
len
([
x
for
x
in
gen
])
==
0
# the generator should not yield them
assert
len
([
x
for
x
in
gen
])
==
0
# the generator should not yield them
...
...
gof/env.py
浏览文件 @
65e08101
...
@@ -16,7 +16,7 @@ class InconsistencyError(Exception):
...
@@ -16,7 +16,7 @@ class InconsistencyError(Exception):
class
Env
(
graph
.
Graph
):
class
Env
(
object
):
#(
graph.Graph):
"""
"""
An Env represents a subgraph bound by a set of input results and a
An Env represents a subgraph bound by a set of input results and a
set of output results. An op is in the subgraph iff it depends on
set of output results. An op is in the subgraph iff it depends on
...
@@ -59,14 +59,19 @@ class Env(graph.Graph):
...
@@ -59,14 +59,19 @@ class Env(graph.Graph):
raise
ValueError
(
"One of the provided inputs is the output of an already existing node. "
\
raise
ValueError
(
"One of the provided inputs is the output of an already existing node. "
\
"If that is okay, either discard that input's owner or use graph.clone."
)
"If that is okay, either discard that input's owner or use graph.clone."
)
self
.
__setup_r__
(
input
)
self
.
__setup_r__
(
input
)
self
.
results
.
add
(
input
)
self
.
outputs
=
outputs
self
.
__import_r__
(
outputs
)
self
.
__import_r__
(
outputs
)
self
.
outputs
=
outputs
for
i
,
output
in
enumerate
(
outputs
):
output
.
clients
.
append
((
'output'
,
i
))
self
.
node_locks
=
{}
self
.
result_locks
=
{}
# List of functions that undo the replace operations performed.
#
# List of functions that undo the replace operations performed.
# e.g. to recover the initial graph one could write: for u in self.history.__reversed__(): u()
#
# e.g. to recover the initial graph one could write: for u in self.history.__reversed__(): u()
self
.
history
=
[]
#
self.history = []
### Setup a Result ###
### Setup a Result ###
...
@@ -99,7 +104,7 @@ class Env(graph.Graph):
...
@@ -99,7 +104,7 @@ class Env(graph.Graph):
"""
"""
r
.
clients
+=
all
r
.
clients
+=
all
def
__remove_clients__
(
self
,
r
,
all
):
def
__remove_clients__
(
self
,
r
,
all
,
prune
=
True
):
"""
"""
r -> result
r -> result
all -> list of (op, i) pairs representing who r is an input of.
all -> list of (op, i) pairs representing who r is an input of.
...
@@ -109,14 +114,24 @@ class Env(graph.Graph):
...
@@ -109,14 +114,24 @@ class Env(graph.Graph):
for
entry
in
all
:
for
entry
in
all
:
r
.
clients
.
remove
(
entry
)
r
.
clients
.
remove
(
entry
)
# remove from orphans?
# remove from orphans?
if
not
r
.
clients
:
if
prune
:
self
.
__prune_r__
([
r
])
return
False
return
True
return
False
### import ###
### import ###
def
__import_r__
(
self
,
results
):
def
__import_r__
(
self
,
results
):
# Imports the owners of the results
# Imports the owners of the results
for
node
in
set
(
r
.
owner
for
r
in
results
if
r
is
not
None
):
for
node
in
set
(
r
.
owner
for
r
in
results
if
r
.
owner
is
not
None
):
self
.
__import__
(
node
)
self
.
__import__
(
node
)
for
r
in
results
:
if
r
.
owner
is
None
and
not
isinstance
(
r
,
graph
.
Value
)
and
r
not
in
self
.
inputs
:
raise
TypeError
(
"Undeclared input"
,
r
)
self
.
results
.
add
(
r
)
def
__import__
(
self
,
node
,
check
=
True
):
def
__import__
(
self
,
node
,
check
=
True
):
# We import the nodes in topological order. We only are interested
# We import the nodes in topological order. We only are interested
...
@@ -127,11 +142,13 @@ class Env(graph.Graph):
...
@@ -127,11 +142,13 @@ class Env(graph.Graph):
if
check
:
if
check
:
for
node
in
new_nodes
:
for
node
in
new_nodes
:
if
hasattr
(
node
,
'env'
)
and
node
.
env
is
not
self
or
\
if
hasattr
(
node
,
'env'
)
and
node
.
env
is
not
self
:
any
(
hasattr
(
r
,
'env'
)
and
r
.
env
is
not
self
or
\
raise
Exception
(
"
%
s is already owned by another env"
%
node
)
r
.
owner
is
None
and
not
isinstance
(
r
,
Value
)
and
r
not
in
self
.
inputs
for
r
in
node
.
inputs
:
for
r
in
node
.
inputs
+
node
.
outputs
):
if
hasattr
(
r
,
'env'
)
and
r
.
env
is
not
self
:
raise
Exception
(
"Could not import
%
s"
%
node
)
raise
Exception
(
"
%
s is already owned by another env"
%
r
)
if
r
.
owner
is
None
and
not
isinstance
(
r
,
graph
.
Value
)
and
r
not
in
self
.
inputs
:
raise
TypeError
(
"Undeclared input"
,
r
)
for
node
in
new_nodes
:
for
node
in
new_nodes
:
self
.
__setup_node__
(
node
)
self
.
__setup_node__
(
node
)
...
@@ -141,9 +158,6 @@ class Env(graph.Graph):
...
@@ -141,9 +158,6 @@ class Env(graph.Graph):
self
.
results
.
add
(
output
)
self
.
results
.
add
(
output
)
for
i
,
input
in
enumerate
(
node
.
inputs
):
for
i
,
input
in
enumerate
(
node
.
inputs
):
if
input
not
in
self
.
results
:
if
input
not
in
self
.
results
:
if
not
isinstance
(
input
,
Value
):
raise
TypeError
(
"The graph to import contains a leaf that is not an input and has no default value "
\
"(graph state is bad now - use check = True)"
,
input
)
self
.
__setup_r__
(
input
)
self
.
__setup_r__
(
input
)
self
.
results
.
add
(
input
)
self
.
results
.
add
(
input
)
self
.
__add_clients__
(
input
,
[(
node
,
i
)])
self
.
__add_clients__
(
input
,
[(
node
,
i
)])
...
@@ -155,7 +169,7 @@ class Env(graph.Graph):
...
@@ -155,7 +169,7 @@ class Env(graph.Graph):
def
__prune_r__
(
self
,
results
):
def
__prune_r__
(
self
,
results
):
# Prunes the owners of the results.
# Prunes the owners of the results.
for
node
in
set
(
r
.
owner
for
r
in
results
if
r
is
not
None
):
for
node
in
set
(
r
.
owner
for
r
in
results
if
r
.
owner
is
not
None
):
self
.
__prune__
(
node
)
self
.
__prune__
(
node
)
for
r
in
results
:
for
r
in
results
:
if
not
r
.
clients
and
r
in
self
.
results
:
if
not
r
.
clients
and
r
in
self
.
results
:
...
@@ -179,78 +193,99 @@ class Env(graph.Graph):
...
@@ -179,78 +193,99 @@ class Env(graph.Graph):
for
i
,
input
in
enumerate
(
node
.
inputs
):
for
i
,
input
in
enumerate
(
node
.
inputs
):
self
.
__remove_clients__
(
input
,
[(
node
,
i
)])
self
.
__remove_clients__
(
input
,
[(
node
,
i
)])
self
.
__prune_r__
(
node
.
inputs
)
#
self.__prune_r__(node.inputs)
### replace ###
### change input ###
def
change_input
(
self
,
node
,
i
,
new_r
):
if
node
==
'output'
:
r
=
self
.
outputs
[
i
]
if
not
r
.
type
==
new_r
.
type
:
raise
TypeError
(
"The type of the replacement must be the same as the type of the original Result."
,
r
,
new_r
)
self
.
outputs
[
i
]
=
new_r
else
:
if
node
.
env
is
not
self
:
raise
Exception
(
"Cannot operate on
%
s because it does not belong to this Env"
%
node
)
r
=
node
.
inputs
[
i
]
if
not
r
.
type
==
new_r
.
type
:
raise
TypeError
(
"The type of the replacement must be the same as the type of the original Result."
,
r
,
new_r
)
node
.
inputs
[
i
]
=
new_r
self
.
__import_r__
([
new_r
])
self
.
__add_clients__
(
new_r
,
[(
node
,
i
)])
prune
=
self
.
__remove_clients__
(
r
,
[(
node
,
i
)],
False
)
self
.
execute_callbacks
(
'on_change_input'
,
node
,
i
,
r
,
new_r
)
if
prune
:
self
.
__prune_r__
([
r
])
def
replace
(
self
,
r
,
new_r
,
consistency_check
=
True
):
### replace ###
def
replace
(
self
,
r
,
new_r
):
"""
"""
This is the main interface to manipulate the subgraph in Env.
This is the main interface to manipulate the subgraph in Env.
For every op that uses r as input, makes it use new_r instead.
For every op that uses r as input, makes it use new_r instead.
This may raise an error if the new result violates type
This may raise an error if the new result violates type
constraints for one of the target nodes. In that case, no
constraints for one of the target nodes. In that case, no
changes are made.
changes are made.
If the replacement makes the graph inconsistent and the value
of consistency_check is True, this function will raise an
InconsistencyError and will undo the operation, leaving the
graph the way it was before the call to replace.
If consistency_check is False, the replacement will succeed
even if there is an inconsistency, unless the replacement
violates hard constraints on the types involved.
"""
"""
if
r
.
env
is
not
self
:
if
r
.
env
is
not
self
:
raise
Exception
(
"Cannot replace
%
s because it does not belong to this Env"
%
r
)
raise
Exception
(
"Cannot replace
%
s because it does not belong to this Env"
%
r
)
if
not
r
.
type
==
new_r
.
type
:
raise
TypeError
(
"The type of the replacement must be the same as the type of the original Result."
,
r
,
new_r
)
assert
r
in
self
.
results
assert
r
in
self
.
results
# Save where we are so we can backtrack
for
node
,
i
in
r
.
clients
:
if
consistency_check
:
assert
node
==
'output'
and
self
.
outputs
[
i
]
is
r
or
node
.
inputs
[
i
]
is
r
chk
=
self
.
checkpoint
()
self
.
change_input
(
node
,
i
,
new_r
)
# # Save where we are so we can backtrack
# if consistency_check:
# chk = self.checkpoint()
# The copy is required so undo can know what clients to move back!
#
# The copy is required so undo can know what clients to move back!
clients
=
copy
(
self
.
clients
(
r
))
#
clients = copy(self.clients(r))
# Messy checks so we know what to do if we are replacing an output
#
# Messy checks so we know what to do if we are replacing an output
# result. Note that if v is an input result, we do nothing at all for
#
# result. Note that if v is an input result, we do nothing at all for
# now (it's not clear what it means to replace an input result).
#
# now (it's not clear what it means to replace an input result).
was_output
=
False
#
was_output = False
if
r
in
self
.
outputs
:
#
if r in self.outputs:
was_output
=
True
#
was_output = True
self
.
outputs
[
self
.
outputs
.
index
(
r
)]
=
new_r
#
self.outputs[self.outputs.index(r)] = new_r
was_input
=
False
#
was_input = False
if
r
in
self
.
inputs
:
#
if r in self.inputs:
was_input
=
True
#
was_input = True
self
.
inputs
[
self
.
inputs
.
index
(
r
)]
=
new_r
#
self.inputs[self.inputs.index(r)] = new_r
# The actual replacement operation occurs here. This might raise
#
# The actual replacement operation occurs here. This might raise
# an error.
#
# an error.
self
.
__move_clients__
(
clients
,
r
,
new_r
)
# not sure how to order this wrt to adjusting the outputs
#
self.__move_clients__(clients, r, new_r) # not sure how to order this wrt to adjusting the outputs
# This function undoes the replacement.
#
# This function undoes the replacement.
def
undo
():
#
def undo():
# Restore self.outputs
#
# Restore self.outputs
if
was_output
:
#
if was_output:
self
.
outputs
[
self
.
outputs
.
index
(
new_r
)]
=
r
#
self.outputs[self.outputs.index(new_r)] = r
# Restore self.inputs
#
# Restore self.inputs
if
was_input
:
#
if was_input:
self
.
inputs
[
self
.
inputs
.
index
(
new_r
)]
=
r
#
self.inputs[self.inputs.index(new_r)] = r
# Move back the clients. This should never raise an error.
#
# Move back the clients. This should never raise an error.
self
.
__move_clients__
(
clients
,
new_r
,
r
)
#
self.__move_clients__(clients, new_r, r)
self
.
history
.
append
(
undo
)
#
self.history.append(undo)
if
consistency_check
:
#
if consistency_check:
try
:
#
try:
self
.
validate
()
#
self.validate()
except
InconsistencyError
,
e
:
#
except InconsistencyError, e:
self
.
revert
(
chk
)
#
self.revert(chk)
raise
#
raise
def
replace_all
(
self
,
d
):
def
replace_all
(
self
,
d
):
"""
"""
...
@@ -259,42 +294,47 @@ class Env(graph.Graph):
...
@@ -259,42 +294,47 @@ class Env(graph.Graph):
graph is not consistent. If an error is raised, the graph is
graph is not consistent. If an error is raised, the graph is
restored to what it was before.
restored to what it was before.
"""
"""
chk
=
self
.
checkpoint
()
for
r
,
new_r
in
d
.
items
():
try
:
self
.
replace
(
r
,
new_r
,
False
)
for
r
,
new_r
in
d
.
items
():
# chk = self.checkpoint()
self
.
replace
(
r
,
new_r
,
False
)
# try:
except
Exception
,
e
:
# for r, new_r in d.items():
self
.
revert
(
chk
)
# self.replace(r, new_r, False)
raise
# except Exception, e:
try
:
# self.revert(chk)
self
.
validate
()
# raise
except
InconsistencyError
,
e
:
# try:
self
.
revert
(
chk
)
# self.validate()
raise
# except InconsistencyError, e:
# self.revert(chk)
# raise
def
checkpoint
(
self
):
#
def checkpoint(self):
"""
#
"""
Returns an object that can be passed to self.revert in order to backtrack
#
Returns an object that can be passed to self.revert in order to backtrack
to a previous state.
#
to a previous state.
"""
#
"""
return
len
(
self
.
history
)
#
return len(self.history)
def
consistent
(
self
):
#
def consistent(self):
"""
#
"""
Returns True iff the subgraph is consistent and does not violate the
#
Returns True iff the subgraph is consistent and does not violate the
constraints set by the listeners.
#
constraints set by the listeners.
"""
#
"""
try
:
#
try:
self
.
validate
()
#
self.validate()
except
InconsistencyError
:
#
except InconsistencyError:
return
False
#
return False
return
True
#
return True
def
extend
(
self
,
feature
,
do_import
=
True
,
validate
=
False
):
### features ###
def
extend
(
self
,
feature
):
"""
"""
@todo out of date
@todo out of date
Adds an instance of the feature_class to this env's supported
Adds an instance of the feature_class to this env's supported
...
@@ -304,17 +344,34 @@ class Env(graph.Graph):
...
@@ -304,17 +344,34 @@ class Env(graph.Graph):
"""
"""
if
feature
in
self
.
_features
:
if
feature
in
self
.
_features
:
return
# the feature is already present
return
# the feature is already present
self
.
__add_feature__
(
feature
,
do_import
)
self
.
_features
.
append
(
feature
)
if
validate
:
attach
=
getattr
(
feature
,
'on_attach'
,
None
)
self
.
validate
()
if
attach
is
not
None
:
try
:
attach
(
self
)
except
:
self
.
_features
.
pop
()
raise
def
remove_feature
(
self
,
feature
):
try
:
self
.
_features
.
remove
(
feature
)
except
:
return
deattach
=
getattr
(
feature
,
'on_deattach'
,
None
)
if
deattach
is
not
None
:
deattach
(
self
)
### callback utils ###
def
execute_callbacks
(
self
,
name
,
*
args
):
def
execute_callbacks
(
self
,
name
,
*
args
):
for
feature
in
self
.
_features
:
for
feature
in
self
.
_features
:
try
:
try
:
fn
=
getattr
(
feature
,
name
)
fn
=
getattr
(
feature
,
name
)
except
AttributeError
:
except
AttributeError
:
continue
continue
fn
(
*
args
)
fn
(
self
,
*
args
)
def
collect_callbacks
(
self
,
name
,
*
args
):
def
collect_callbacks
(
self
,
name
,
*
args
):
d
=
{}
d
=
{}
...
@@ -326,35 +383,9 @@ class Env(graph.Graph):
...
@@ -326,35 +383,9 @@ class Env(graph.Graph):
d
[
feature
]
=
fn
(
*
args
)
d
[
feature
]
=
fn
(
*
args
)
return
d
return
d
def
__add_feature__
(
self
,
feature
,
do_import
):
self
.
_features
.
append
(
feature
)
publish
=
getattr
(
feature
,
'publish'
,
None
)
if
publish
is
not
None
:
publish
()
if
do_import
:
try
:
fn
=
feature
.
on_import
except
AttributeError
:
return
for
node
in
self
.
io_toposort
():
fn
(
node
)
def
__del_feature__
(
self
,
feature
):
try
:
del
self
.
_features
[
feature
]
except
:
pass
unpublish
=
hasattr
(
feature
,
'unpublish'
)
if
unpublish
is
not
None
:
unpublish
()
def
get_feature
(
self
,
feature
):
idx
=
self
.
_features
.
index
(
feature
)
return
self
.
_features
[
idx
]
def
has_feature
(
self
,
feature
):
return
feature
in
self
.
_features
### misc ###
def
nclients
(
self
,
r
):
def
nclients
(
self
,
r
):
"Same as len(self.clients(r))."
"Same as len(self.clients(r))."
return
len
(
self
.
clients
(
r
))
return
len
(
self
.
clients
(
r
))
...
@@ -374,114 +405,156 @@ class Env(graph.Graph):
...
@@ -374,114 +405,156 @@ class Env(graph.Graph):
def
has_node
(
self
,
node
):
def
has_node
(
self
,
node
):
return
node
in
self
.
nodes
return
node
in
self
.
nodes
def
revert
(
self
,
checkpoint
):
def
check_integrity
(
self
):
"""
nodes
=
graph
.
ops
(
self
.
inputs
,
self
.
outputs
)
Reverts the graph to whatever it was at the provided
if
self
.
nodes
!=
nodes
:
checkpoint (undoes all replacements). A checkpoint at any
missing
=
nodes
.
difference
(
self
.
nodes
)
given time can be obtained using self.checkpoint().
excess
=
self
.
nodes
.
difference
(
nodes
)
"""
raise
Exception
(
"The nodes are inappropriately cached. missing, in excess: "
,
missing
,
excess
)
while
len
(
self
.
history
)
>
checkpoint
:
for
node
in
nodes
:
f
=
self
.
history
.
pop
()
if
node
.
env
is
not
self
:
f
()
raise
Exception
(
"Node should belong to the env."
,
node
)
for
i
,
result
in
enumerate
(
node
.
inputs
):
if
result
.
env
is
not
self
:
raise
Exception
(
"Input of node should belong to the env."
,
result
,
(
node
,
i
))
if
(
node
,
i
)
not
in
result
.
clients
:
raise
Exception
(
"Inconsistent clients list."
,
(
node
,
i
),
result
.
clients
)
results
=
graph
.
results
(
self
.
inputs
,
self
.
outputs
)
if
self
.
results
!=
results
:
missing
=
results
.
difference
(
self
.
results
)
excess
=
self
.
results
.
difference
(
results
)
raise
Exception
(
"The results are inappropriately cached. missing, in excess: "
,
missing
,
excess
)
for
result
in
results
:
if
result
.
owner
is
None
and
result
not
in
self
.
inputs
and
not
isinstance
(
result
,
graph
.
Value
):
raise
Exception
(
"Undeclared input."
,
result
)
if
result
.
env
is
not
self
:
raise
Exception
(
"Result should belong to the env."
,
result
)
for
node
,
i
in
result
.
clients
:
if
node
==
'output'
:
if
self
.
outputs
[
i
]
is
not
result
:
raise
Exception
(
"Inconsistent clients list."
,
result
,
self
.
outputs
[
i
])
continue
if
node
not
in
nodes
:
raise
Exception
(
"Client not in env."
,
result
,
(
node
,
i
))
if
node
.
inputs
[
i
]
is
not
result
:
raise
Exception
(
"Inconsistent clients list."
,
result
,
node
.
inputs
[
i
])
def
supplemental_orderings
(
self
):
# def revert(self, checkpoint):
"""
# """
Returns a dictionary of {op: set(prerequisites)} that must
# Reverts the graph to whatever it was at the provided
be satisfied in addition to the order defined by the structure
# checkpoint (undoes all replacements). A checkpoint at any
of the graph (returns orderings that not related to input/output
# given time can be obtained using self.checkpoint().
relationships).
# """
"""
# while len(self.history) > checkpoint:
ords
=
{}
# f = self.history.pop()
for
feature
in
self
.
_features
:
# f()
if
hasattr
(
feature
,
'orderings'
):
for
op
,
prereqs
in
feature
.
orderings
()
.
items
():
# def supplemental_orderings(self):
ords
.
setdefault
(
op
,
set
())
.
update
(
prereqs
)
# """
return
ords
# Returns a dictionary of {op: set(prerequisites)} that must
# be satisfied in addition to the order defined by the structure
def
toposort
(
self
):
# of the graph (returns orderings that not related to input/output
"""
# relationships).
Returns a list of nodes in the order that they must be executed
# """
in order to preserve the semantics of the graph and respect
# ords = {}
the constraints put forward by the listeners.
# for feature in self._features:
"""
# if hasattr(feature, 'orderings'):
ords
=
self
.
supplemental_orderings
()
# for op, prereqs in feature.orderings().items():
order
=
graph
.
io_toposort
(
self
.
inputs
,
self
.
outputs
,
ords
)
# ords.setdefault(op, set()).update(prereqs)
return
order
# return ords
def
validate
(
self
):
# def toposort(self):
"""
# """
Raises an error if the graph is inconsistent.
# Returns a list of nodes in the order that they must be executed
"""
# in order to preserve the semantics of the graph and respect
self
.
execute_callbacks
(
'validate'
)
# the constraints put forward by the listeners.
# for constraint in self._constraints.values():
# """
# constraint.validate()
# ords = self.supplemental_orderings()
return
True
# order = graph.io_toposort(self.inputs, self.outputs, ords)
# return order
# def validate(self):
# """
# Raises an error if the graph is inconsistent.
# """
# self.execute_callbacks('validate')
# # for constraint in self._constraints.values():
# # constraint.validate()
# return True
### Private interface ###
### Private interface ###
def
__move_clients__
(
self
,
clients
,
r
,
new_r
):
#
def __move_clients__(self, clients, r, new_r):
if
not
(
r
.
type
==
new_r
.
type
):
#
if not (r.type == new_r.type):
raise
TypeError
(
"Cannot move clients between Results that have different types."
,
r
,
new_r
)
#
raise TypeError("Cannot move clients between Results that have different types.", r, new_r)
# We import the new result in the fold
self
.
__import_r__
([
new_r
])
for
op
,
i
in
clients
:
op
.
inputs
[
i
]
=
new_r
# try:
# # Try replacing the inputs
# for op, i in clients:
# op.set_input(i, new_r)
# except:
# # Oops!
# for op, i in clients:
# op.set_input(i, r)
# self.__prune_r__([new_r])
# raise
self
.
__remove_clients__
(
r
,
clients
)
self
.
__add_clients__
(
new_r
,
clients
)
# # We import the new result in the fold
# # We import the new result in the fold
# # why was this line AFTER the set_inputs???
# # if we do it here then satisfy in import fucks up...
# self.__import_r__([new_r])
# self.__import_r__([new_r])
self
.
execute_callbacks
(
'on_rewire'
,
clients
,
r
,
new_r
)
# for op, i in clients:
# for listener in self._listeners.values():
# op.inputs[i] = new_r
# try:
# # try:
# listener.on_rewire(clients, r, new_r)
# # # Try replacing the inputs
# except AbstractFunctionError:
# # for op, i in clients:
# pass
# # op.set_input(i, new_r)
# # except:
# We try to get rid of the old one
# # # Oops!
self
.
__prune_r__
([
r
])
# # for op, i in clients:
# # op.set_input(i, r)
# # self.__prune_r__([new_r])
# # raise
# self.__remove_clients__(r, clients)
# self.__add_clients__(new_r, clients)
# # # We import the new result in the fold
# # # why was this line AFTER the set_inputs???
# # # if we do it here then satisfy in import fucks up...
# # self.__import_r__([new_r])
# self.execute_callbacks('on_rewire', clients, r, new_r)
# # for listener in self._listeners.values():
# # try:
# # listener.on_rewire(clients, r, new_r)
# # except AbstractFunctionError:
# # pass
# # We try to get rid of the old one
# self.__prune_r__([r])
def
__str__
(
self
):
def
__str__
(
self
):
return
"[
%
s]"
%
", "
.
join
(
graph
.
as_string
(
self
.
inputs
,
self
.
outputs
))
return
"[
%
s]"
%
", "
.
join
(
graph
.
as_string
(
self
.
inputs
,
self
.
outputs
))
def
clone_get_equiv
(
self
,
clone_inputs
=
True
):
# def clone_get_equiv(self, clone_inputs = True):
equiv
=
graph
.
clone_get_equiv
(
self
.
inputs
,
self
.
outputs
,
clone_inputs
)
# equiv = graph.clone_get_equiv(self.inputs, self.outputs, clone_inputs)
new
=
self
.
__class__
([
equiv
[
input
]
for
input
in
self
.
inputs
],
# new = self.__class__([equiv[input] for input in self.inputs],
[
equiv
[
output
]
for
output
in
self
.
outputs
])
# [equiv[output] for output in self.outputs])
for
feature
in
self
.
_features
:
# for feature in self._features:
new
.
extend
(
feature
)
# new.extend(feature)
return
new
,
equiv
# return new, equiv
# def clone(self, clone_inputs = True):
# equiv = graph.clone_get_equiv(self.inputs, self.outputs, clone_inputs)
# new = self.__class__([equiv[input] for input in self.inputs],
# [equiv[output] for output in self.outputs])
# for feature in self._features:
# new.extend(feature)
# try:
# new.set_equiv(equiv)
# except AttributeError:
# pass
# return new
# def __copy__(self):
# return self.clone()
def
clone
(
self
,
clone_inputs
=
True
):
equiv
=
graph
.
clone_get_equiv
(
self
.
inputs
,
self
.
outputs
,
clone_inputs
)
new
=
self
.
__class__
([
equiv
[
input
]
for
input
in
self
.
inputs
],
[
equiv
[
output
]
for
output
in
self
.
outputs
])
for
feature
in
self
.
_features
:
new
.
extend
(
feature
)
try
:
new
.
set_equiv
(
equiv
)
except
AttributeError
:
pass
return
new
def
__copy__
(
self
):
return
self
.
clone
()
gof/ext.py
浏览文件 @
65e08101
from
features
import
Listener
,
Constraint
,
Orderings
,
Tool
#from features import Listener, Constraint, Orderings, Tool
import
utils
from
utils
import
AbstractFunctionError
from
utils
import
AbstractFunctionError
from
copy
import
copy
from
copy
import
copy
from
env
import
InconsistencyError
from
env
import
InconsistencyError
from
toolbox
import
Bookkeeper
class
DestroyHandler
(
Listener
,
Constraint
,
Orderings
,
Tool
):
from
collections
import
defaultdict
class
DestroyHandler
(
Bookkeeper
):
#(Listener, Constraint, Orderings, Tool):
"""
"""
This feature ensures that an env represents a consistent data flow
This feature ensures that an env represents a consistent data flow
when some Ops overwrite their inputs and/or provide "views" over
when some Ops overwrite their inputs and/or provide "views" over
...
@@ -27,14 +35,32 @@ class DestroyHandler(Listener, Constraint, Orderings, Tool):
...
@@ -27,14 +35,32 @@ class DestroyHandler(Listener, Constraint, Orderings, Tool):
This feature allows some optimizations (eg sub += for +) to be applied
This feature allows some optimizations (eg sub += for +) to be applied
safely.
safely.
"""
"""
def
__init__
(
self
,
env
):
def
__init__
(
self
):
self
.
env
=
None
def
on_attach
(
self
,
env
):
if
self
.
env
is
not
None
:
raise
Exception
(
"A DestroyHandler instance can only serve one Env."
)
for
attr
in
(
'destroyers'
,
'destroy_handler'
):
if
hasattr
(
env
,
attr
):
raise
Exception
(
"DestroyHandler feature is already present or in conflict with another plugin."
)
def
__destroyers
(
r
):
ret
=
self
.
destroyers
.
get
(
r
,
{})
ret
=
ret
.
keys
()
return
ret
env
.
destroyers
=
__destroyers
env
.
destroy_handler
=
self
self
.
env
=
env
# For an Op that has a view_map, {output : input it is a view of}
# For an Op that has a view_map, {output : input it is a view of}
self
.
parent
=
{}
self
.
parent
=
{}
# Reverse mapping of parent: {input : outputs that are a view of it}
# Reverse mapping of parent: {input : outputs that are a view of it}
self
.
children
=
{}
self
.
children
=
defaultdict
(
set
)
# {foundation : {op that destroys it : path }}
# {foundation : {op that destroys it : path }}
# where foundation is a result such that (not self.parent[result])
# where foundation is a result such that (not self.parent[result])
...
@@ -57,25 +83,37 @@ class DestroyHandler(Listener, Constraint, Orderings, Tool):
...
@@ -57,25 +83,37 @@ class DestroyHandler(Listener, Constraint, Orderings, Tool):
# indestructible by the user.
# indestructible by the user.
self
.
illegal
=
set
()
self
.
illegal
=
set
()
self
.
env
=
env
self
.
seen
=
set
()
self
.
seen
=
set
()
# Initialize the children if the inputs and orphans.
Bookkeeper
.
on_attach
(
self
,
env
)
for
input
in
env
.
orphans
.
union
(
env
.
inputs
):
self
.
children
[
input
]
=
set
()
# # Initialize the children if the inputs and orphans.
# for input in env.inputs: # env.orphans.union(env.inputs):
def
publish
(
self
):
# self.children[input] = set()
"""
Publishes the following on the env:
def
on_detach
(
self
,
env
):
- destroyers(r) -> returns all L{Op}s that destroy the result r
del
self
.
parent
- destroy_handler -> self
del
self
.
children
"""
del
self
.
destroyers
def
__destroyers
(
r
):
del
self
.
paths
ret
=
self
.
destroyers
.
get
(
r
,
{})
del
self
.
dups
ret
=
ret
.
keys
()
del
self
.
cycles
return
ret
del
self
.
illegal
self
.
env
.
destroyers
=
__destroyers
del
self
.
seen
self
.
env
.
destroy_handler
=
self
self
.
env
=
None
# def publish(self):
# """
# Publishes the following on the env:
# - destroyers(r) -> returns all L{Op}s that destroy the result r
# - destroy_handler -> self
# """
# def __destroyers(r):
# ret = self.destroyers.get(r, {})
# ret = ret.keys()
# return ret
# self.env.destroyers = __destroyers
# self.env.destroy_handler = self
def
__path__
(
self
,
r
):
def
__path__
(
self
,
r
):
"""
"""
...
@@ -105,12 +143,12 @@ class DestroyHandler(Listener, Constraint, Orderings, Tool):
...
@@ -105,12 +143,12 @@ class DestroyHandler(Listener, Constraint, Orderings, Tool):
"""
"""
children
=
self
.
children
[
r
]
children
=
self
.
children
[
r
]
if
not
children
:
if
not
children
:
return
set
([
r
])
return
[
r
]
else
:
else
:
rval
=
set
([
r
])
rval
=
[
r
]
for
child
in
children
:
for
child
in
children
:
rval
.
update
(
self
.
__views__
(
child
)
)
rval
+=
self
.
__views__
(
child
)
return
rval
return
utils
.
uniq
(
rval
)
def
__users__
(
self
,
r
):
def
__users__
(
self
,
r
):
"""
"""
...
@@ -120,12 +158,12 @@ class DestroyHandler(Listener, Constraint, Orderings, Tool):
...
@@ -120,12 +158,12 @@ class DestroyHandler(Listener, Constraint, Orderings, Tool):
is returned.
is returned.
"""
"""
views
=
self
.
__views__
(
r
)
views
=
self
.
__views__
(
r
)
rval
=
set
()
rval
=
[]
#
set()
for
view
in
views
:
for
view
in
views
:
for
op
,
i
in
self
.
env
.
clients
(
view
):
for
node
,
i
in
view
.
clients
:
#
self.env.clients(view):
if
op
in
self
.
seen
:
if
node
!=
'output'
:
rval
.
update
(
op
.
outputs
)
rval
+=
node
.
outputs
return
rval
return
utils
.
uniq
(
rval
)
def
__pre__
(
self
,
op
):
def
__pre__
(
self
,
op
):
"""
"""
...
@@ -178,7 +216,7 @@ class DestroyHandler(Listener, Constraint, Orderings, Tool):
...
@@ -178,7 +216,7 @@ class DestroyHandler(Listener, Constraint, Orderings, Tool):
just_remove is True, we return immediately after removing the
just_remove is True, we return immediately after removing the
cycles.
cycles.
"""
"""
users
=
se
lf
.
__users__
(
start
)
users
=
se
t
(
self
.
__users__
(
start
)
)
users
.
add
(
start
)
users
.
add
(
start
)
for
user
in
users
:
for
user
in
users
:
for
cycle
in
copy
(
self
.
cycles
):
for
cycle
in
copy
(
self
.
cycles
):
...
@@ -208,13 +246,14 @@ class DestroyHandler(Listener, Constraint, Orderings, Tool):
...
@@ -208,13 +246,14 @@ class DestroyHandler(Listener, Constraint, Orderings, Tool):
dmap
[
node
.
outputs
[
oidx
]]
=
[
node
.
inputs
[
iidx
]
for
iidx
in
iidxs
]
dmap
[
node
.
outputs
[
oidx
]]
=
[
node
.
inputs
[
iidx
]
for
iidx
in
iidxs
]
return
vmap
,
dmap
return
vmap
,
dmap
def
on_import
(
self
,
op
):
def
on_import
(
self
,
env
,
op
):
"""
"""
Recomputes the dependencies and search for inconsistencies given
Recomputes the dependencies and search for inconsistencies given
that we just added an op to the env.
that we just added an op to the env.
"""
"""
self
.
seen
.
add
(
op
)
self
.
seen
.
add
(
op
)
op
.
deps
[
'destroy'
]
=
[]
view_map
,
destroy_map
=
self
.
get_maps
(
op
)
view_map
,
destroy_map
=
self
.
get_maps
(
op
)
for
input
in
op
.
inputs
:
for
input
in
op
.
inputs
:
...
@@ -251,7 +290,7 @@ class DestroyHandler(Listener, Constraint, Orderings, Tool):
...
@@ -251,7 +290,7 @@ class DestroyHandler(Listener, Constraint, Orderings, Tool):
self
.
__detect_cycles_helper__
(
output
,
[])
self
.
__detect_cycles_helper__
(
output
,
[])
def
on_prune
(
self
,
op
):
def
on_prune
(
self
,
env
,
op
):
"""
"""
Recomputes the dependencies and searches for inconsistencies to remove
Recomputes the dependencies and searches for inconsistencies to remove
given that we just removed an op to the env.
given that we just removed an op to the env.
...
@@ -295,6 +334,7 @@ class DestroyHandler(Listener, Constraint, Orderings, Tool):
...
@@ -295,6 +334,7 @@ class DestroyHandler(Listener, Constraint, Orderings, Tool):
del
self
.
children
[
output
]
del
self
.
children
[
output
]
self
.
seen
.
remove
(
op
)
self
.
seen
.
remove
(
op
)
del
op
.
deps
[
'destroy'
]
def
__add_destroyer__
(
self
,
path
):
def
__add_destroyer__
(
self
,
path
):
...
@@ -305,11 +345,18 @@ class DestroyHandler(Listener, Constraint, Orderings, Tool):
...
@@ -305,11 +345,18 @@ class DestroyHandler(Listener, Constraint, Orderings, Tool):
foundation
=
path
[
0
]
foundation
=
path
[
0
]
target
=
path
[
-
1
]
target
=
path
[
-
1
]
op
=
target
.
owner
node
=
target
.
owner
destroyers
=
self
.
destroyers
.
setdefault
(
foundation
,
{})
destroyers
=
self
.
destroyers
.
setdefault
(
foundation
,
{})
path
=
destroyers
.
setdefault
(
op
,
path
)
path
=
destroyers
.
setdefault
(
node
,
path
)
print
"add"
,
path
node
.
deps
[
'destroy'
]
+=
[
user
.
owner
for
user
in
self
.
__users__
(
foundation
)
if
user
not
in
node
.
outputs
]
# for foundation, destroyers in self.destroyers.items():
# for op in destroyers.keys():
# ords.setdefault(op, set()).update([user.owner for user in self.__users__(foundation) if user not in op.outputs])
if
len
(
destroyers
)
>
1
:
if
len
(
destroyers
)
>
1
:
self
.
dups
.
add
(
foundation
)
self
.
dups
.
add
(
foundation
)
...
@@ -325,10 +372,17 @@ class DestroyHandler(Listener, Constraint, Orderings, Tool):
...
@@ -325,10 +372,17 @@ class DestroyHandler(Listener, Constraint, Orderings, Tool):
foundation
=
path
[
0
]
foundation
=
path
[
0
]
target
=
path
[
-
1
]
target
=
path
[
-
1
]
op
=
target
.
owner
node
=
target
.
owner
print
"rm"
,
path
print
node
.
deps
[
'destroy'
]
for
user
in
self
.
__users__
(
foundation
):
print
" -- "
,
user
if
user
not
in
node
.
outputs
:
node
.
deps
[
'destroy'
]
.
remove
(
user
.
owner
)
destroyers
=
self
.
destroyers
[
foundation
]
destroyers
=
self
.
destroyers
[
foundation
]
del
destroyers
[
op
]
del
destroyers
[
node
]
if
not
destroyers
:
if
not
destroyers
:
if
foundation
in
self
.
illegal
:
if
foundation
in
self
.
illegal
:
...
@@ -338,14 +392,18 @@ class DestroyHandler(Listener, Constraint, Orderings, Tool):
...
@@ -338,14 +392,18 @@ class DestroyHandler(Listener, Constraint, Orderings, Tool):
self
.
dups
.
remove
(
foundation
)
self
.
dups
.
remove
(
foundation
)
def
on_rewire
(
self
,
clients
,
r_1
,
r_2
):
def
on_change_input
(
self
,
env
,
node
,
i
,
r
,
new_r
):
if
node
!=
'output'
:
self
.
on_rewire
(
env
,
[(
node
,
i
)],
r
,
new_r
)
def
on_rewire
(
self
,
env
,
clients
,
r_1
,
r_2
):
"""
"""
Recomputes the dependencies and searches for inconsistencies to remove
Recomputes the dependencies and searches for inconsistencies to remove
given that all the clients are moved from r_1 to r_2, clients being
given that all the clients are moved from r_1 to r_2, clients being
a list of (op, i) pairs such that op.inputs[i] used to be r_1 and is
a list of (op, i) pairs such that op.inputs[i] used to be r_1 and is
now r_2.
now r_2.
"""
"""
path_1
=
self
.
__path__
(
r_1
)
path_1
=
self
.
__path__
(
r_1
)
path_2
=
self
.
__path__
(
r_2
)
path_2
=
self
.
__path__
(
r_2
)
...
@@ -396,7 +454,7 @@ class DestroyHandler(Listener, Constraint, Orderings, Tool):
...
@@ -396,7 +454,7 @@ class DestroyHandler(Listener, Constraint, Orderings, Tool):
self
.
children
.
setdefault
(
r_2
,
set
())
self
.
children
.
setdefault
(
r_2
,
set
())
self
.
__detect_cycles__
(
r_2
)
self
.
__detect_cycles__
(
r_2
)
def
validate
(
self
):
def
validate
(
self
,
env
):
"""
"""
Raises an L{InconsistencyError} on any of the following conditions:
Raises an L{InconsistencyError} on any of the following conditions:
- Some results are destroyed by more than one L{Op}
- Some results are destroyed by more than one L{Op}
...
@@ -412,9 +470,9 @@ class DestroyHandler(Listener, Constraint, Orderings, Tool):
...
@@ -412,9 +470,9 @@ class DestroyHandler(Listener, Constraint, Orderings, Tool):
else
:
else
:
return
True
return
True
def
orderings
(
self
):
def
orderings
(
self
,
env
):
"""
"""
Returns a dict of {
op : set(op
s that must be computed before it)} according
Returns a dict of {
node : set(node
s that must be computed before it)} according
to L{DestroyHandler}.
to L{DestroyHandler}.
In particular, all the users of a destroyed result have priority over the
In particular, all the users of a destroyed result have priority over the
L{Op} that destroys the result.
L{Op} that destroys the result.
...
@@ -426,6 +484,8 @@ class DestroyHandler(Listener, Constraint, Orderings, Tool):
...
@@ -426,6 +484,8 @@ class DestroyHandler(Listener, Constraint, Orderings, Tool):
return
ords
return
ords
class
Destroyer
:
class
Destroyer
:
"""
"""
Base class for Ops that destroy one or more of their inputs in an
Base class for Ops that destroy one or more of their inputs in an
...
@@ -493,3 +553,4 @@ def view_roots(r):
...
@@ -493,3 +553,4 @@ def view_roots(r):
return
[
r
]
return
[
r
]
else
:
else
:
return
[
r
]
return
[
r
]
gof/graph.py
浏览文件 @
65e08101
...
@@ -202,7 +202,7 @@ def results_and_orphans(i, o, except_unreachable_input=False):
...
@@ -202,7 +202,7 @@ def results_and_orphans(i, o, except_unreachable_input=False):
"""
"""
results
=
set
()
results
=
set
()
i
=
set
(
i
)
i
=
set
(
i
)
results
.
update
(
i
)
#
results.update(i)
incomplete_paths
=
[]
incomplete_paths
=
[]
reached
=
set
()
reached
=
set
()
...
@@ -287,7 +287,7 @@ def orphans(i, o):
...
@@ -287,7 +287,7 @@ def orphans(i, o):
return
results_and_orphans
(
i
,
o
)[
1
]
return
results_and_orphans
(
i
,
o
)[
1
]
def
clone
(
i
,
o
,
copy_inputs
=
Fals
e
):
def
clone
(
i
,
o
,
copy_inputs
=
Tru
e
):
"""
"""
@type i: list
@type i: list
@param i: input L{Result}s
@param i: input L{Result}s
...
@@ -299,8 +299,8 @@ def clone(i, o, copy_inputs = False):
...
@@ -299,8 +299,8 @@ def clone(i, o, copy_inputs = False):
Copies the subgraph contained between i and o and returns the
Copies the subgraph contained between i and o and returns the
outputs of that copy (corresponding to o).
outputs of that copy (corresponding to o).
"""
"""
equiv
=
clone_get_equiv
(
i
,
o
)
equiv
=
clone_get_equiv
(
i
,
o
,
copy_inputs
)
return
[
equiv
[
output
]
for
output
in
o
]
return
[
equiv
[
input
]
for
input
in
i
],
[
equiv
[
output
]
for
output
in
o
]
def
clone_get_equiv
(
i
,
o
,
copy_inputs_and_orphans
=
False
):
def
clone_get_equiv
(
i
,
o
,
copy_inputs_and_orphans
=
False
):
...
@@ -324,7 +324,7 @@ def clone_get_equiv(i, o, copy_inputs_and_orphans = False):
...
@@ -324,7 +324,7 @@ def clone_get_equiv(i, o, copy_inputs_and_orphans = False):
for
input
in
i
:
for
input
in
i
:
if
copy_inputs_and_orphans
:
if
copy_inputs_and_orphans
:
cpy
=
copy
(
input
)
cpy
=
input
.
clone
(
)
cpy
.
owner
=
None
cpy
.
owner
=
None
cpy
.
index
=
None
cpy
.
index
=
None
d
[
input
]
=
cpy
d
[
input
]
=
cpy
...
@@ -337,7 +337,7 @@ def clone_get_equiv(i, o, copy_inputs_and_orphans = False):
...
@@ -337,7 +337,7 @@ def clone_get_equiv(i, o, copy_inputs_and_orphans = False):
node
=
result
.
owner
node
=
result
.
owner
if
node
is
None
:
# result is an orphan
if
node
is
None
:
# result is an orphan
if
copy_inputs_and_orphans
:
if
copy_inputs_and_orphans
:
cpy
=
copy
(
result
)
cpy
=
result
.
clone
(
)
d
[
result
]
=
cpy
d
[
result
]
=
cpy
else
:
else
:
d
[
result
]
=
result
d
[
result
]
=
result
...
...
gof/opt.py
浏览文件 @
65e08101
...
@@ -115,7 +115,10 @@ class OpSpecificOptimizer(LocalOptimizer):
...
@@ -115,7 +115,10 @@ class OpSpecificOptimizer(LocalOptimizer):
"""
"""
def
add_requirements
(
self
,
env
):
def
add_requirements
(
self
,
env
):
env
.
extend
(
toolbox
.
NodeFinder
(
env
))
try
:
env
.
extend
(
toolbox
.
NodeFinder
())
env
.
extend
(
toolbox
.
ReplaceValidate
())
except
:
pass
def
candidates
(
self
,
env
):
def
candidates
(
self
,
env
):
"""
"""
...
@@ -135,7 +138,10 @@ class OpSubOptimizer(Optimizer):
...
@@ -135,7 +138,10 @@ class OpSubOptimizer(Optimizer):
"""
"""
def
add_requirements
(
self
,
env
):
def
add_requirements
(
self
,
env
):
env
.
extend
(
toolbox
.
NodeFinder
(
env
))
try
:
env
.
extend
(
toolbox
.
NodeFinder
())
env
.
extend
(
toolbox
.
ReplaceValidate
())
except
:
pass
def
__init__
(
self
,
op1
,
op2
,
failure_callback
=
None
):
def
__init__
(
self
,
op1
,
op2
,
failure_callback
=
None
):
"""
"""
...
@@ -163,7 +169,7 @@ class OpSubOptimizer(Optimizer):
...
@@ -163,7 +169,7 @@ class OpSubOptimizer(Optimizer):
repl
=
self
.
op2
.
make_node
(
*
node
.
inputs
)
repl
=
self
.
op2
.
make_node
(
*
node
.
inputs
)
assert
len
(
node
.
outputs
)
==
len
(
repl
.
outputs
)
assert
len
(
node
.
outputs
)
==
len
(
repl
.
outputs
)
for
old
,
new
in
zip
(
node
.
outputs
,
repl
.
outputs
):
for
old
,
new
in
zip
(
node
.
outputs
,
repl
.
outputs
):
env
.
replace
(
old
,
new
)
env
.
replace
_validate
(
old
,
new
)
except
Exception
,
e
:
except
Exception
,
e
:
if
self
.
failure_callback
is
not
None
:
if
self
.
failure_callback
is
not
None
:
self
.
failure_callback
(
node
,
repl
,
e
)
self
.
failure_callback
(
node
,
repl
,
e
)
...
@@ -182,7 +188,10 @@ class OpRemover(Optimizer):
...
@@ -182,7 +188,10 @@ class OpRemover(Optimizer):
"""
"""
def
add_requirements
(
self
,
env
):
def
add_requirements
(
self
,
env
):
env
.
extend
(
toolbox
.
NodeFinder
(
env
))
try
:
env
.
extend
(
toolbox
.
NodeFinder
())
env
.
extend
(
toolbox
.
ReplaceValidate
())
except
:
pass
def
__init__
(
self
,
op
,
failure_callback
=
None
):
def
__init__
(
self
,
op
,
failure_callback
=
None
):
"""
"""
...
...
gof/toolbox.py
浏览文件 @
65e08101
from
random
import
shuffle
from
random
import
shuffle
import
utils
import
utils
from
functools
import
partial
import
graph
class
EquivTool
(
dict
):
class
Bookkeeper
:
def
on_attach
(
self
,
env
):
for
node
in
graph
.
io_toposort
(
env
.
inputs
,
env
.
outputs
):
self
.
on_import
(
env
,
node
)
def
__init__
(
self
,
env
):
def
on_deattach
(
self
,
env
):
self
.
env
=
env
for
node
in
graph
.
io_toposort
(
env
.
inputs
,
env
.
outputs
):
self
.
on_prune
(
env
,
node
)
class
History
:
def
__init__
(
self
):
self
.
history
=
{}
def
on_rewire
(
self
,
clients
,
r
,
new_r
):
def
on_attach
(
self
,
env
):
repl
=
self
(
new_r
)
if
hasattr
(
env
,
'checkpoint'
)
or
hasattr
(
env
,
'revert'
):
if
repl
is
r
:
raise
Exception
(
"History feature is already present or in conflict with another plugin."
)
self
.
ungroup
(
r
,
new_r
)
self
.
history
[
env
]
=
[]
elif
repl
is
not
new_r
:
env
.
checkpoint
=
lambda
:
len
(
self
.
history
[
env
])
raise
Exception
(
"Improper use of EquivTool!"
)
env
.
revert
=
partial
(
self
.
revert
,
env
)
else
:
self
.
group
(
new_r
,
r
)
def
on_deattach
(
self
,
env
):
del
env
.
checkpoint
def
publish
(
self
):
del
env
.
revert
self
.
env
.
equiv
=
self
del
self
.
history
[
env
]
self
.
env
.
set_equiv
=
self
.
set_equiv
def
on_change_input
(
self
,
env
,
node
,
i
,
r
,
new_r
):
def
unpublish
(
self
):
if
self
.
history
[
env
]
is
None
:
del
self
.
env
.
equiv
return
del
self
.
env
.
set_equiv
h
=
self
.
history
[
env
]
h
.
append
(
lambda
:
env
.
change_input
(
node
,
i
,
r
))
def
set_equiv
(
self
,
d
):
self
.
update
(
d
)
def
revert
(
self
,
env
,
checkpoint
):
"""
def
group
(
self
,
main
,
*
keys
):
Reverts the graph to whatever it was at the provided
"Marks all the keys as having been replaced by the Result main."
checkpoint (undoes all replacements). A checkpoint at any
keys
=
[
key
for
key
in
keys
if
key
is
not
main
]
given time can be obtained using self.checkpoint().
if
self
.
has_key
(
main
):
"""
raise
Exception
(
"Only group results that have not been grouped before."
)
h
=
self
.
history
[
env
]
for
key
in
keys
:
self
.
history
[
env
]
=
None
if
self
.
has_key
(
key
):
while
len
(
h
)
>
checkpoint
:
raise
Exception
(
"Only group results that have not been grouped before."
)
f
=
h
.
pop
()
if
key
is
main
:
f
()
continue
self
.
history
[
env
]
=
h
self
.
setdefault
(
key
,
main
)
def
ungroup
(
self
,
main
,
*
keys
):
class
Validator
:
"Undoes group(main, *keys)"
keys
=
[
key
for
key
in
keys
if
key
is
not
main
]
def
on_attach
(
self
,
env
):
for
key
in
keys
:
if
hasattr
(
env
,
'validate'
):
if
self
[
key
]
is
main
:
raise
Exception
(
"Validator feature is already present or in conflict with another plugin."
)
del
self
[
key
]
env
.
validate
=
lambda
:
env
.
execute_callbacks
(
'validate'
)
def
consistent
():
def
__call__
(
self
,
key
):
try
:
"Returns the currently active replacement for the given key."
env
.
validate
()
next
=
self
.
get
(
key
,
None
)
return
True
while
next
:
except
:
key
=
next
return
False
next
=
self
.
get
(
next
,
None
)
env
.
consistent
=
consistent
return
key
def
on_deattach
(
self
,
env
):
del
env
.
validate
class
NodeFinder
(
dict
):
del
env
.
consistent
def
__init__
(
self
,
env
):
class
ReplaceValidate
(
History
,
Validator
):
def
on_attach
(
self
,
env
):
History
.
on_attach
(
self
,
env
)
Validator
.
on_attach
(
self
,
env
)
for
attr
in
(
'replace_validate'
,
'replace_all_validate'
):
if
hasattr
(
env
,
attr
):
raise
Exception
(
"ReplaceValidate feature is already present or in conflict with another plugin."
)
env
.
replace_validate
=
partial
(
self
.
replace_validate
,
env
)
env
.
replace_all_validate
=
partial
(
self
.
replace_all_validate
,
env
)
def
on_deattach
(
self
,
env
):
History
.
on_deattach
(
self
,
env
)
Validator
.
on_deattach
(
self
,
env
)
del
env
.
replace_validate
del
env
.
replace_all_validate
def
replace_validate
(
self
,
env
,
r
,
new_r
):
self
.
replace_all_validate
(
env
,
[(
r
,
new_r
)])
def
replace_all_validate
(
self
,
env
,
replacements
):
chk
=
env
.
checkpoint
()
for
r
,
new_r
in
replacements
:
env
.
replace
(
r
,
new_r
)
try
:
env
.
validate
()
except
:
env
.
revert
(
chk
)
raise
class
NodeFinder
(
dict
,
Bookkeeper
):
def
__init__
(
self
):
self
.
env
=
None
def
on_attach
(
self
,
env
):
if
self
.
env
is
not
None
:
raise
Exception
(
"A NodeFinder instance can only serve one Env."
)
if
hasattr
(
env
,
'get_nodes'
):
raise
Exception
(
"NodeFinder is already present or in conflict with another plugin."
)
self
.
env
=
env
self
.
env
=
env
env
.
get_nodes
=
partial
(
self
.
query
,
env
)
Bookkeeper
.
on_attach
(
self
,
env
)
def
on_import
(
self
,
node
):
def
on_deattach
(
self
,
env
):
if
self
.
env
is
not
env
:
raise
Exception
(
"This NodeFinder instance was not attached to the provided env."
)
self
.
env
=
None
del
env
.
get_nodes
Bookkeeper
.
on_deattach
(
self
,
env
)
def
on_import
(
self
,
env
,
node
):
try
:
try
:
self
.
setdefault
(
node
.
op
,
set
())
.
ad
d
(
node
)
self
.
setdefault
(
node
.
op
,
[])
.
appen
d
(
node
)
except
TypeError
:
except
TypeError
:
#node.op is unhashable
pass
return
def
on_prune
(
self
,
node
):
def
on_prune
(
self
,
env
,
node
):
try
:
try
:
self
[
node
.
op
]
.
remove
(
node
)
nodes
=
self
[
node
.
op
]
except
TypeError
:
except
TypeError
:
#node.op is unhashable
return
return
if
not
self
[
node
.
op
]:
nodes
.
remove
(
node
)
if
not
nodes
:
del
self
[
node
.
op
]
del
self
[
node
.
op
]
def
query
(
self
,
op
):
def
query
(
self
,
env
,
op
):
try
:
try
:
all
=
self
.
get
(
op
,
[])
all
=
self
.
get
(
op
,
[])
except
TypeError
:
except
TypeError
:
raise
TypeError
(
"
%
s in unhashable and cannot be queried by the optimizer"
%
op
)
raise
TypeError
(
"
%
s in unhashable and cannot be queried by the optimizer"
%
op
)
all
=
[
x
for
x
in
all
]
all
=
list
(
all
)
shuffle
(
all
)
# this helps a lot for debugging because the order of the replacements will vary
while
all
:
while
all
:
next
=
all
.
pop
()
next
=
all
.
pop
()
if
self
.
env
.
has_node
(
next
)
:
if
next
in
env
.
nodes
:
yield
next
yield
next
def
publish
(
self
):
self
.
env
.
get_nodes
=
self
.
query
def
__eq__
(
self
,
other
):
class
PrintListener
(
object
):
return
isinstance
(
other
,
NodeFinder
)
and
self
.
env
is
other
.
env
def
__init__
(
self
,
active
=
True
):
self
.
active
=
active
def
on_attach
(
self
,
env
):
if
self
.
active
:
print
"-- attaching to: "
,
env
def
on_deattach
(
self
,
env
):
if
self
.
active
:
print
"-- deattaching from: "
,
env
def
on_import
(
self
,
env
,
node
):
if
self
.
active
:
print
"-- importing:
%
s"
%
node
def
on_prune
(
self
,
env
,
node
):
if
self
.
active
:
print
"-- pruning:
%
s"
%
node
def
on_change_input
(
self
,
env
,
node
,
i
,
r
,
new_r
):
if
self
.
active
:
print
"-- changing (
%
s.inputs[
%
s]) from
%
s to
%
s"
%
(
node
,
i
,
r
,
new_r
)
# class EquivTool(dict):
# def __init__(self, env):
# self.env = env
# def on_rewire(self, clients, r, new_r):
# repl = self(new_r)
# if repl is r:
# self.ungroup(r, new_r)
# elif repl is not new_r:
# raise Exception("Improper use of EquivTool!")
# else:
# self.group(new_r, r)
# def publish(self):
# self.env.equiv = self
# self.env.set_equiv = self.set_equiv
# def unpublish(self):
# del self.env.equiv
# del self.env.set_equiv
# def set_equiv(self, d):
# self.update(d)
# def group(self, main, *keys):
# "Marks all the keys as having been replaced by the Result main."
# keys = [key for key in keys if key is not main]
# if self.has_key(main):
# raise Exception("Only group results that have not been grouped before.")
# for key in keys:
# if self.has_key(key):
# raise Exception("Only group results that have not been grouped before.")
# if key is main:
# continue
# self.setdefault(key, main)
# def ungroup(self, main, *keys):
# "Undoes group(main, *keys)"
# keys = [key for key in keys if key is not main]
# for key in keys:
# if self[key] is main:
# del self[key]
# def __call__(self, key):
# "Returns the currently active replacement for the given key."
# next = self.get(key, None)
# while next:
# key = next
# next = self.get(next, None)
# return key
# class InstanceFinder(Listener, Tool, dict):
# class InstanceFinder(Listener, Tool, dict):
...
@@ -158,28 +302,6 @@ class NodeFinder(dict):
...
@@ -158,28 +302,6 @@ class NodeFinder(dict):
class
PrintListener
(
object
):
def
__init__
(
self
,
env
,
active
=
True
):
self
.
env
=
env
self
.
active
=
active
if
active
:
print
"-- initializing"
def
on_import
(
self
,
node
):
if
self
.
active
:
print
"-- importing:
%
s"
%
node
def
on_prune
(
self
,
node
):
if
self
.
active
:
print
"-- pruning:
%
s"
%
node
def
on_rewire
(
self
,
clients
,
r
,
new_r
):
if
self
.
active
:
if
r
.
owner
is
not
None
:
r
=
r
.
owner
if
new_r
.
owner
is
not
None
:
new_r
=
new_r
.
owner
print
"-- moving from
%
s to
%
s"
%
(
r
,
new_r
)
### UNUSED AND UNTESTED ###
### UNUSED AND UNTESTED ###
...
...
gof/utils.py
浏览文件 @
65e08101
...
@@ -26,6 +26,8 @@ class object2(object):
...
@@ -26,6 +26,8 @@ class object2(object):
if
hasattr
(
self
,
'__eq__'
)
or
hasattr
(
self
,
'__cmp__'
):
if
hasattr
(
self
,
'__eq__'
)
or
hasattr
(
self
,
'__cmp__'
):
raise
TypeError
(
"unhashable object:
%
s"
%
self
)
raise
TypeError
(
"unhashable object:
%
s"
%
self
)
return
id
(
self
)
return
id
(
self
)
def
__ne__
(
self
,
other
):
return
not
self
==
other
class
scratchpad
:
class
scratchpad
:
def
clear
(
self
):
def
clear
(
self
):
...
...
tensor.py
浏览文件 @
65e08101
...
@@ -71,7 +71,7 @@ class Tensor(Type):
...
@@ -71,7 +71,7 @@ class Tensor(Type):
def
__init__
(
self
,
dtype
,
broadcastable
):
def
__init__
(
self
,
dtype
,
broadcastable
):
self
.
dtype
=
str
(
dtype
)
self
.
dtype
=
str
(
dtype
)
self
.
broadcastable
=
broadcastable
self
.
broadcastable
=
tuple
(
broadcastable
)
self
.
dtype_specs
()
# error checking is done there
self
.
dtype_specs
()
# error checking is done there
def
filter
(
self
,
data
,
strict
=
False
):
def
filter
(
self
,
data
,
strict
=
False
):
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论