Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
ea32b4db
提交
ea32b4db
authored
5月 05, 2008
作者:
Olivier Breuleux
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
too many things to list
上级
d2cf55aa
全部展开
显示空白字符变更
内嵌
并排
正在显示
19 个修改的文件
包含
221 行增加
和
320 行删除
+221
-320
_test_compile.py
_test_compile.py
+0
-0
_test_gradient.py
_test_gradient.py
+0
-0
_test_sparse.py
_test_sparse.py
+0
-0
_test_tensor.py
_test_tensor.py
+0
-0
_test_cc.py
gof/_test_cc.py
+7
-1
_test_ext.py
gof/_test_ext.py
+1
-2
_test_graph.py
gof/_test_graph.py
+3
-3
_test_link.py
gof/_test_link.py
+13
-7
_test_opt.py
gof/_test_opt.py
+0
-1
cc.py
gof/cc.py
+34
-13
env.py
gof/env.py
+15
-211
ext.py
gof/ext.py
+5
-13
graph.py
gof/graph.py
+58
-42
link.py
gof/link.py
+9
-5
opt.py
gof/opt.py
+11
-7
toolbox.py
gof/toolbox.py
+45
-0
scalar.py
scalar.py
+5
-8
sparse.py
sparse.py
+0
-0
tensor.py
tensor.py
+15
-7
没有找到文件。
_test_compile.py
浏览文件 @
ea32b4db
差异被折叠。
点击展开。
_test_gradient.py
浏览文件 @
ea32b4db
差异被折叠。
点击展开。
_test_sparse.py
浏览文件 @
ea32b4db
差异被折叠。
点击展开。
_test_tensor.py
浏览文件 @
ea32b4db
差异被折叠。
点击展开。
gof/_test_cc.py
浏览文件 @
ea32b4db
...
...
@@ -6,7 +6,8 @@ from cc import *
from
type
import
Type
from
graph
import
Result
,
as_result
,
Apply
,
Constant
from
op
import
Op
from
env
import
Env
import
env
import
toolbox
class
TDouble
(
Type
):
def
filter
(
self
,
data
):
...
...
@@ -125,6 +126,11 @@ def inputs():
return
x
,
y
,
z
def
Env
(
inputs
,
outputs
):
e
=
env
.
Env
(
inputs
,
outputs
)
return
e
class
_test_CLinker
(
unittest
.
TestCase
):
def
test_straightforward
(
self
):
...
...
gof/_test_ext.py
浏览文件 @
ea32b4db
...
...
@@ -257,7 +257,6 @@ class _test_all(unittest.TestCase):
if
__name__
==
'__main__'
:
#unittest.main()
_test_all
(
'test_usage_loop_through_views'
)
.
debug
()
unittest
.
main
()
gof/_test_graph.py
浏览文件 @
ea32b4db
...
...
@@ -161,14 +161,14 @@ class _test_clone(unittest.TestCase):
def
test_accurate
(
self
):
r1
,
r2
=
MyResult
(
1
),
MyResult
(
2
)
node
=
MyOp
.
make_node
(
r1
,
r2
)
new
=
clone
([
r1
,
r2
],
node
.
outputs
)
_
,
new
=
clone
([
r1
,
r2
],
node
.
outputs
,
False
)
assert
self
.
str
([
r1
,
r2
],
new
)
==
[
"MyOp(1, 2)"
]
def
test_copy
(
self
):
r1
,
r2
,
r5
=
MyResult
(
1
),
MyResult
(
2
),
MyResult
(
5
)
node
=
MyOp
.
make_node
(
r1
,
r2
)
node2
=
MyOp
.
make_node
(
node
.
outputs
[
0
],
r5
)
new
=
clone
([
r1
,
r2
,
r5
],
node2
.
outputs
)
_
,
new
=
clone
([
r1
,
r2
,
r5
],
node2
.
outputs
,
False
)
assert
node2
.
outputs
[
0
]
.
type
==
new
[
0
]
.
type
and
node2
.
outputs
[
0
]
is
not
new
[
0
]
# the new output is like the old one but not the same object
assert
node2
is
not
new
[
0
]
.
owner
# the new output has a new owner
assert
new
[
0
]
.
owner
.
inputs
[
1
]
is
r5
# the inputs are not copied
...
...
@@ -178,7 +178,7 @@ class _test_clone(unittest.TestCase):
# Checks that manipulating a cloned graph leaves the original unchanged.
r1
,
r2
,
r5
=
MyResult
(
1
),
MyResult
(
2
),
MyResult
(
5
)
node
=
MyOp
.
make_node
(
MyOp
.
make_node
(
r1
,
r2
)
.
outputs
[
0
],
r5
)
new
=
clone
([
r1
,
r2
,
r5
],
node
.
outputs
)
_
,
new
=
clone
([
r1
,
r2
,
r5
],
node
.
outputs
,
False
)
new_node
=
new
[
0
]
.
owner
new_node
.
inputs
=
MyResult
(
7
),
MyResult
(
8
)
...
...
gof/_test_link.py
浏览文件 @
ea32b4db
...
...
@@ -2,10 +2,12 @@
import
unittest
from
graph
import
Result
,
as_result
,
Apply
import
graph
from
graph
import
Result
,
as_result
,
Apply
,
Constant
from
type
import
Type
from
op
import
Op
from
env
import
Env
import
env
import
toolbox
from
link
import
*
...
...
@@ -67,6 +69,10 @@ def perform_linker(env):
lnk
=
PerformLinker
(
env
)
return
lnk
def
Env
(
inputs
,
outputs
):
e
=
env
.
Env
(
inputs
,
outputs
)
return
e
class
_test_PerformLinker
(
unittest
.
TestCase
):
...
...
@@ -94,16 +100,14 @@ class _test_PerformLinker(unittest.TestCase):
def
test_input_output_same
(
self
):
x
,
y
,
z
=
inputs
()
a
,
d
=
add
(
x
,
y
),
div
(
x
,
y
)
e
=
mul
(
a
,
d
)
fn
=
perform_linker
(
Env
([
e
],
[
e
]))
.
make_function
()
fn
=
perform_linker
(
Env
([
x
],
[
x
]))
.
make_function
()
self
.
failUnless
(
1.0
is
fn
(
1.0
))
def
test_input_dependency0
(
self
):
x
,
y
,
z
=
inputs
()
a
,
d
=
add
(
x
,
y
),
div
(
x
,
y
)
e
=
mul
(
a
,
d
)
fn
=
perform_linker
(
Env
(
[
x
,
y
,
a
],
[
e
]
))
.
make_function
()
fn
=
perform_linker
(
Env
(
*
graph
.
clone
([
x
,
y
,
a
],
[
e
])
))
.
make_function
()
self
.
failUnless
(
fn
(
1.0
,
2.0
,
9.0
)
==
4.5
)
def
test_skiphole
(
self
):
...
...
@@ -111,9 +115,11 @@ class _test_PerformLinker(unittest.TestCase):
a
=
add
(
x
,
y
)
r
=
raise_err
(
a
)
e
=
add
(
r
,
a
)
fn
=
perform_linker
(
Env
(
[
x
,
y
,
r
],
[
e
]
))
.
make_function
()
fn
=
perform_linker
(
Env
(
*
graph
.
clone
([
x
,
y
,
r
],
[
e
])
))
.
make_function
()
self
.
failUnless
(
fn
(
1.0
,
2.0
,
4.5
)
==
7.5
)
# def test_disconnected_input_output(self):
# x,y,z = inputs()
# a = add(x,y)
...
...
gof/_test_opt.py
浏览文件 @
ea32b4db
...
...
@@ -415,4 +415,3 @@ if __name__ == '__main__':
unittest
.
main
()
gof/cc.py
浏览文件 @
ea32b4db
from
graph
import
Constant
import
graph
from
graph
import
Constant
,
Value
from
link
import
Linker
,
LocalLinker
,
raise_with_op
,
Filter
,
map_storage
,
PerformLinker
from
copy
import
copy
from
utils
import
AbstractFunctionError
...
...
@@ -284,10 +285,11 @@ def apply_policy(policy, r, name, sub):
@type r: L{Result}
@return: C{policy[0](r) + policy[1](r) + ...}
"""
if
isinstance
(
r
,
(
list
,
tuple
)):
if
isinstance
(
policy
,
(
list
,
tuple
)):
ret
=
""
for
sub_policy
in
policy
:
ret
+=
sub_policy
(
r
,
name
,
sub
)
return
ret
return
policy
(
r
,
name
,
sub
)
...
...
@@ -345,7 +347,7 @@ class CLinker(Linker):
self
.
outputs
=
env
.
outputs
self
.
results
=
list
(
env
.
results
)
# The orphans field is listified to ensure a consistent order.
self
.
orphans
=
list
(
env
.
orphans
.
difference
(
self
.
outputs
))
self
.
orphans
=
list
(
r
for
r
in
self
.
results
if
isinstance
(
r
,
Value
)
and
r
not
in
self
.
inputs
)
#list(
env.orphans.difference(self.outputs))
self
.
temps
=
list
(
set
(
self
.
results
)
.
difference
(
self
.
inputs
)
.
difference
(
self
.
outputs
)
.
difference
(
self
.
orphans
))
self
.
node_order
=
env
.
toposort
()
...
...
@@ -403,8 +405,9 @@ class CLinker(Linker):
policy
=
[[
get_nothing
,
get_nothing
,
get_nothing
],
[
get_c_declare
,
get_c_extract
,
get_c_cleanup
]]
elif
result
in
self
.
orphans
:
if
not
isinstance
(
result
,
Constant
):
raise
TypeError
(
"All orphans to CLinker must be Constant."
,
result
)
if
not
isinstance
(
result
,
Value
):
raise
TypeError
(
"All orphans to CLinker must be Value instances."
,
result
)
if
isinstance
(
result
,
Constant
):
try
:
symbol
[
result
]
=
"("
+
result
.
type
.
c_literal
(
result
.
data
)
+
")"
consts
.
append
(
result
)
...
...
@@ -428,7 +431,6 @@ class CLinker(Linker):
elif
result
in
self
.
outputs
:
# outputs don't need to be extracted from Python, so we call c_init rather than c_extract
if
result
.
type
.
c_is_simple
()
or
result
in
no_recycling
:
policy
=
[[
get_nothing
,
get_nothing
,
get_nothing
],
[
get_c_declare
,
get_c_init
,
(
get_c_sync
,
get_c_cleanup
)]]
else
:
...
...
@@ -599,7 +601,12 @@ class CLinker(Linker):
if
input_storage
is
None
:
input_storage
=
[[
None
]
for
result
in
self
.
inputs
]
if
output_storage
is
None
:
output_storage
=
[[
None
]
for
result
in
self
.
outputs
]
map
=
{}
output_storage
=
[]
for
result
in
self
.
outputs
:
if
result
not
in
map
:
map
[
result
]
=
[
None
]
output_storage
.
append
(
map
[
result
])
thunk
=
self
.
cthunk_factory
(
error_storage
,
input_storage
,
output_storage
)
...
...
@@ -642,13 +649,13 @@ class CLinker(Linker):
if
not
getattr
(
self
,
'instantiate'
,
False
):
self
.
code_gen
()
module_name
=
self
.
hash
# Eliminate duplicate inputs and outputs from the storage that we will pass to instantiate
out_storage
=
[
x
for
i
,
x
in
enumerate
(
out_storage
)
if
(
i
+
len
(
in_storage
))
not
in
self
.
dupidx
]
in_storage
=
[
x
for
i
,
x
in
enumerate
(
in_storage
)
if
i
not
in
self
.
dupidx
]
cthunk
=
object
()
# dummy so weave can get the type
module_name
=
self
.
hash
mod
=
weave
.
ext_tools
.
ext_module
(
module_name
)
argnames
=
[
"i
%
i"
%
i
for
i
in
xrange
(
len
(
in_storage
))]
\
...
...
@@ -710,8 +717,11 @@ class CLinker(Linker):
# Eliminate duplicate inputs and outputs from the storage that we will pass to instantiate
out_storage
=
[
x
for
i
,
x
in
enumerate
(
out_storage
)
if
(
i
+
len
(
in_storage
))
not
in
self
.
dupidx
]
in_storage
=
[
x
for
i
,
x
in
enumerate
(
in_storage
)
if
i
not
in
self
.
dupidx
]
module_name
=
self
.
hash
module
=
__import__
(
"
%
s"
%
(
module_name
),
{},
{},
[
module_name
])
ret
=
module
.
instantiate
(
error_storage
,
*
(
in_storage
+
out_storage
+
[
orphan
.
data
for
orphan
in
self
.
orphans
]))
orphd
=
[[
orphan
.
data
]
for
orphan
in
self
.
orphans
]
ret
=
module
.
instantiate
(
error_storage
,
*
(
in_storage
+
out_storage
+
orphd
))
assert
sys
.
getrefcount
(
ret
)
==
2
# refcount leak check
return
ret
...
...
@@ -751,7 +761,9 @@ class OpWiseCLinker(LocalLinker):
node_input_storage
=
[
storage_map
[
r
]
for
r
in
node
.
inputs
]
node_output_storage
=
[
storage_map
[
r
]
for
r
in
node
.
outputs
]
try
:
cl
=
CLinker
(
Env
(
node
.
inputs
,
node
.
outputs
))
e
=
Env
(
*
graph
.
clone
(
node
.
inputs
,
node
.
outputs
))
e
.
toposort
=
lambda
:
e
.
nodes
cl
=
CLinker
(
e
,
[
r
for
r
,
r2
in
zip
(
e
.
outputs
,
node
.
outputs
)
if
r2
in
no_recycling
])
thunk
,
node_input_filters
,
node_output_filters
=
cl
.
make_thunk
(
input_storage
=
node_input_storage
,
output_storage
=
node_output_storage
)
...
...
@@ -823,7 +835,7 @@ class DualLinker(Linker):
function.
"""
def
__init__
(
self
,
env
,
checker
=
_default_checker
):
def
__init__
(
self
,
env
,
checker
=
_default_checker
,
no_recycling
=
[]
):
"""
Initialize a DualLinker.
...
...
@@ -844,6 +856,7 @@ class DualLinker(Linker):
"""
self
.
env
=
env
self
.
checker
=
checker
self
.
no_recycling
=
no_recycling
def
make_thunk
(
self
,
**
kwargs
):
# if inplace:
...
...
@@ -865,8 +878,10 @@ class DualLinker(Linker):
# thunks2 = [c_make_thunk(op) for op in op_order_2]
env
=
self
.
env
_f
,
i1
,
o1
,
thunks1
,
order1
=
PerformLinker
(
env
)
.
make_all
(
**
kwargs
)
_f
,
i2
,
o2
,
thunks2
,
order2
=
OpWiseCLinker
(
env
)
.
make_all
(
**
kwargs
)
no_recycling
=
self
.
no_recycling
_f
,
i1
,
o1
,
thunks1
,
order1
=
PerformLinker
(
env
,
no_recycling
=
no_recycling
)
.
make_all
(
**
kwargs
)
_f
,
i2
,
o2
,
thunks2
,
order2
=
OpWiseCLinker
(
env
,
no_recycling
=
no_recycling
)
.
make_all
(
**
kwargs
)
def
f
():
for
input1
,
input2
in
zip
(
i1
,
i2
):
...
...
@@ -874,6 +889,12 @@ class DualLinker(Linker):
# the copy is necessary in order for inplace ops not to interfere
input2
.
storage
[
0
]
=
copy
(
input1
.
storage
[
0
])
for
thunk1
,
thunk2
,
node1
,
node2
in
zip
(
thunks1
,
thunks2
,
order1
,
order2
):
for
output
,
storage
in
zip
(
node1
.
outputs
,
thunk1
.
outputs
):
if
output
in
no_recycling
:
storage
[
0
]
=
None
for
output
,
storage
in
zip
(
node2
.
outputs
,
thunk2
.
outputs
):
if
output
in
no_recycling
:
storage
[
0
]
=
None
try
:
thunk1
()
thunk2
()
...
...
gof/env.py
浏览文件 @
ea32b4db
...
...
@@ -26,15 +26,7 @@ class Env(object): #(graph.Graph):
The Env supports the replace operation which allows to replace a
result in the subgraph by another, e.g. replace (x + x).out by (2
* x).out. This is the basis for optimization in omega.
Regarding inputs and orphans:
In the context of a computation graph, the inputs and orphans are
both results that are the source nodes of computation. Those
results that are named as inputs will be assumed to contain fresh.
In other words, the backward search from outputs will stop at any
node that has been explicitly named as an input.
* x).out. This is the basis for optimization in theano.
"""
### Special ###
...
...
@@ -69,10 +61,6 @@ class Env(object): #(graph.Graph):
self
.
node_locks
=
{}
self
.
result_locks
=
{}
# # List of functions that undo the replace operations performed.
# # e.g. to recover the initial graph one could write: for u in self.history.__reversed__(): u()
# self.history = []
### Setup a Result ###
...
...
@@ -237,99 +225,13 @@ class Env(object): #(graph.Graph):
raise
TypeError
(
"The type of the replacement must be the same as the type of the original Result."
,
r
,
new_r
)
assert
r
in
self
.
results
for
node
,
i
in
r
.
clients
:
for
node
,
i
in
list
(
r
.
clients
)
:
assert
node
==
'output'
and
self
.
outputs
[
i
]
is
r
or
node
.
inputs
[
i
]
is
r
self
.
change_input
(
node
,
i
,
new_r
)
# # Save where we are so we can backtrack
# if consistency_check:
# chk = self.checkpoint()
# # The copy is required so undo can know what clients to move back!
# clients = copy(self.clients(r))
# # Messy checks so we know what to do if we are replacing an output
# # result. Note that if v is an input result, we do nothing at all for
# # now (it's not clear what it means to replace an input result).
# was_output = False
# if r in self.outputs:
# was_output = True
# self.outputs[self.outputs.index(r)] = new_r
# was_input = False
# if r in self.inputs:
# was_input = True
# self.inputs[self.inputs.index(r)] = new_r
# # The actual replacement operation occurs here. This might raise
# # an error.
# self.__move_clients__(clients, r, new_r) # not sure how to order this wrt to adjusting the outputs
# # This function undoes the replacement.
# def undo():
# # Restore self.outputs
# if was_output:
# self.outputs[self.outputs.index(new_r)] = r
# # Restore self.inputs
# if was_input:
# self.inputs[self.inputs.index(new_r)] = r
# # Move back the clients. This should never raise an error.
# self.__move_clients__(clients, new_r, r)
# self.history.append(undo)
# if consistency_check:
# try:
# self.validate()
# except InconsistencyError, e:
# self.revert(chk)
# raise
def
replace_all
(
self
,
d
):
"""
For (r, new_r) in d.items(), replaces r with new_r. Checks for
consistency at the end and raises an InconsistencyError if the
graph is not consistent. If an error is raised, the graph is
restored to what it was before.
"""
for
r
,
new_r
in
d
.
items
():
self
.
replace
(
r
,
new_r
,
False
)
# chk = self.checkpoint()
# try:
# for r, new_r in d.items():
# self.replace(r, new_r, False)
# except Exception, e:
# self.revert(chk)
# raise
# try:
# self.validate()
# except InconsistencyError, e:
# self.revert(chk)
# raise
# def checkpoint(self):
# """
# Returns an object that can be passed to self.revert in order to backtrack
# to a previous state.
# """
# return len(self.history)
# def consistent(self):
# """
# Returns True iff the subgraph is consistent and does not violate the
# constraints set by the listeners.
# """
# try:
# self.validate()
# except InconsistencyError:
# return False
# return True
def
replace_all
(
self
,
pairs
):
for
r
,
new_r
in
pairs
:
self
.
replace
(
r
,
new_r
)
### features ###
...
...
@@ -386,6 +288,16 @@ class Env(object): #(graph.Graph):
### misc ###
def
toposort
(
self
):
env
=
self
ords
=
{}
for
feature
in
env
.
_features
:
if
hasattr
(
feature
,
'orderings'
):
for
op
,
prereqs
in
feature
.
orderings
(
env
)
.
items
():
ords
.
setdefault
(
op
,
set
())
.
update
(
prereqs
)
order
=
graph
.
io_toposort
(
env
.
inputs
,
env
.
outputs
,
ords
)
return
order
def
nclients
(
self
,
r
):
"Same as len(self.clients(r))."
return
len
(
self
.
clients
(
r
))
...
...
@@ -439,117 +351,9 @@ class Env(object): #(graph.Graph):
if
node
.
inputs
[
i
]
is
not
result
:
raise
Exception
(
"Inconsistent clients list."
,
result
,
node
.
inputs
[
i
])
# def revert(self, checkpoint):
# """
# Reverts the graph to whatever it was at the provided
# checkpoint (undoes all replacements). A checkpoint at any
# given time can be obtained using self.checkpoint().
# """
# while len(self.history) > checkpoint:
# f = self.history.pop()
# f()
# def supplemental_orderings(self):
# """
# Returns a dictionary of {op: set(prerequisites)} that must
# be satisfied in addition to the order defined by the structure
# of the graph (returns orderings that not related to input/output
# relationships).
# """
# ords = {}
# for feature in self._features:
# if hasattr(feature, 'orderings'):
# for op, prereqs in feature.orderings().items():
# ords.setdefault(op, set()).update(prereqs)
# return ords
# def toposort(self):
# """
# Returns a list of nodes in the order that they must be executed
# in order to preserve the semantics of the graph and respect
# the constraints put forward by the listeners.
# """
# ords = self.supplemental_orderings()
# order = graph.io_toposort(self.inputs, self.outputs, ords)
# return order
# def validate(self):
# """
# Raises an error if the graph is inconsistent.
# """
# self.execute_callbacks('validate')
# # for constraint in self._constraints.values():
# # constraint.validate()
# return True
### Private interface ###
# def __move_clients__(self, clients, r, new_r):
# if not (r.type == new_r.type):
# raise TypeError("Cannot move clients between Results that have different types.", r, new_r)
# # We import the new result in the fold
# self.__import_r__([new_r])
# for op, i in clients:
# op.inputs[i] = new_r
# # try:
# # # Try replacing the inputs
# # for op, i in clients:
# # op.set_input(i, new_r)
# # except:
# # # Oops!
# # for op, i in clients:
# # op.set_input(i, r)
# # self.__prune_r__([new_r])
# # raise
# self.__remove_clients__(r, clients)
# self.__add_clients__(new_r, clients)
# # # We import the new result in the fold
# # # why was this line AFTER the set_inputs???
# # # if we do it here then satisfy in import fucks up...
# # self.__import_r__([new_r])
# self.execute_callbacks('on_rewire', clients, r, new_r)
# # for listener in self._listeners.values():
# # try:
# # listener.on_rewire(clients, r, new_r)
# # except AbstractFunctionError:
# # pass
# # We try to get rid of the old one
# self.__prune_r__([r])
def
__str__
(
self
):
return
"[
%
s]"
%
", "
.
join
(
graph
.
as_string
(
self
.
inputs
,
self
.
outputs
))
# def clone_get_equiv(self, clone_inputs = True):
# equiv = graph.clone_get_equiv(self.inputs, self.outputs, clone_inputs)
# new = self.__class__([equiv[input] for input in self.inputs],
# [equiv[output] for output in self.outputs])
# for feature in self._features:
# new.extend(feature)
# return new, equiv
# def clone(self, clone_inputs = True):
# equiv = graph.clone_get_equiv(self.inputs, self.outputs, clone_inputs)
# new = self.__class__([equiv[input] for input in self.inputs],
# [equiv[output] for output in self.outputs])
# for feature in self._features:
# new.extend(feature)
# try:
# new.set_equiv(equiv)
# except AttributeError:
# pass
# return new
# def __copy__(self):
# return self.clone()
...
...
gof/ext.py
浏览文件 @
ea32b4db
#from features import Listener, Constraint, Orderings, Tool
import
graph
import
utils
from
utils
import
AbstractFunctionError
...
...
@@ -253,7 +256,6 @@ class DestroyHandler(Bookkeeper): #(Listener, Constraint, Orderings, Tool):
"""
self
.
seen
.
add
(
op
)
op
.
deps
[
'destroy'
]
=
[]
view_map
,
destroy_map
=
self
.
get_maps
(
op
)
for
input
in
op
.
inputs
:
...
...
@@ -334,7 +336,6 @@ class DestroyHandler(Bookkeeper): #(Listener, Constraint, Orderings, Tool):
del
self
.
children
[
output
]
self
.
seen
.
remove
(
op
)
del
op
.
deps
[
'destroy'
]
def
__add_destroyer__
(
self
,
path
):
...
...
@@ -350,9 +351,6 @@ class DestroyHandler(Bookkeeper): #(Listener, Constraint, Orderings, Tool):
destroyers
=
self
.
destroyers
.
setdefault
(
foundation
,
{})
path
=
destroyers
.
setdefault
(
node
,
path
)
print
"add"
,
path
node
.
deps
[
'destroy'
]
+=
[
user
.
owner
for
user
in
self
.
__users__
(
foundation
)
if
user
not
in
node
.
outputs
]
# for foundation, destroyers in self.destroyers.items():
# for op in destroyers.keys():
# ords.setdefault(op, set()).update([user.owner for user in self.__users__(foundation) if user not in op.outputs])
...
...
@@ -361,7 +359,7 @@ class DestroyHandler(Bookkeeper): #(Listener, Constraint, Orderings, Tool):
self
.
dups
.
add
(
foundation
)
# results marked 'indestructible' must not be destroyed.
if
getattr
(
foundation
,
'indestructible'
,
False
):
if
getattr
(
foundation
,
'indestructible'
,
False
)
or
isinstance
(
foundation
,
graph
.
Constant
)
:
self
.
illegal
.
add
(
foundation
)
...
...
@@ -374,13 +372,6 @@ class DestroyHandler(Bookkeeper): #(Listener, Constraint, Orderings, Tool):
target
=
path
[
-
1
]
node
=
target
.
owner
print
"rm"
,
path
print
node
.
deps
[
'destroy'
]
for
user
in
self
.
__users__
(
foundation
):
print
" -- "
,
user
if
user
not
in
node
.
outputs
:
node
.
deps
[
'destroy'
]
.
remove
(
user
.
owner
)
destroyers
=
self
.
destroyers
[
foundation
]
del
destroyers
[
node
]
...
...
@@ -477,6 +468,7 @@ class DestroyHandler(Bookkeeper): #(Listener, Constraint, Orderings, Tool):
In particular, all the users of a destroyed result have priority over the
L{Op} that destroys the result.
"""
self
.
validate
(
env
)
ords
=
{}
for
foundation
,
destroyers
in
self
.
destroyers
.
items
():
for
op
in
destroyers
.
keys
():
...
...
gof/graph.py
浏览文件 @
ea32b4db
...
...
@@ -163,7 +163,6 @@ def as_apply(x):
@deprecated
def
inputs
(
o
):
"""
...
...
@@ -173,7 +172,6 @@ def inputs(o):
Returns the set of inputs necessary to compute the outputs in o
such that input.owner is None.
"""
print
'gof.graph.inputs deprecated: April 29'
results
=
set
()
def
seek
(
r
):
op
=
r
.
owner
...
...
@@ -187,53 +185,71 @@ def inputs(o):
return
results
def
results_and_orphans
(
i
,
o
,
except_unreachable_input
=
False
):
"""
@type i: list
@param i: input L{Result}s
@type o: list
@param o: output L{Result}s
# def results_and_orphans(i, o, except_unreachable_input=False):
# """
# @type i: list
# @param i: input L{Result}s
# @type o: list
# @param o: output L{Result}s
# Returns the pair (results, orphans). The former is the set of
# L{Result}s that are involved in the subgraph that lies between i and
# o. This includes i, o, orphans(i, o) and all results of all
# intermediary steps from i to o. The second element of the returned
# pair is orphans(i, o).
# """
# results = set()
# i = set(i)
# # results.update(i)
# incomplete_paths = []
# reached = set()
# def helper(r, path):
# if r in i:
# reached.add(r)
# results.update(path)
# elif r.owner is None:
# incomplete_paths.append(path)
# else:
# op = r.owner
# for r2 in op.inputs:
# helper(r2, path + [r2])
Returns the pair (results, orphans). The former is the set of
L{Result}s that are involved in the subgraph that lies between i and
o. This includes i, o, orphans(i, o) and all results of all
intermediary steps from i to o. The second element of the returned
pair is orphans(i, o).
"""
results
=
set
()
i
=
set
(
i
)
# results.update(i)
incomplete_paths
=
[]
reached
=
set
()
def
helper
(
r
,
path
):
if
r
in
i
:
reached
.
add
(
r
)
results
.
update
(
path
)
elif
r
.
owner
is
None
:
incomplete_paths
.
append
(
path
)
else
:
op
=
r
.
owner
for
r2
in
op
.
inputs
:
helper
(
r2
,
path
+
[
r2
])
# for output in o:
# helper(output, [output])
for
output
in
o
:
helper
(
output
,
[
output
])
# orphans = set()
# for path in incomplete_paths:
# for r in path:
# if r not in results:
# orphans.add(r)
# break
orphans
=
set
()
for
path
in
incomplete_paths
:
for
r
in
path
:
if
r
not
in
results
:
orphans
.
add
(
r
)
break
# if except_unreachable_input and len(i) != len(reached):
# raise Exception(results_and_orphans.E_unreached)
if
except_unreachable_input
and
len
(
i
)
!=
len
(
reached
):
raise
Exception
(
results_and_orphans
.
E_unreached
)
# results.update(orphans)
results
.
update
(
orphans
)
# return results, orphans
# results_and_orphans.E_unreached = 'there were unreachable inputs'
def
results_and_orphans
(
i
,
o
):
results
=
set
()
orphans
=
set
()
def
helper
(
r
):
if
r
in
results
:
return
results
.
add
(
r
)
if
r
.
owner
is
None
:
if
r
not
in
i
:
orphans
.
add
(
r
)
else
:
for
r2
in
r
.
owner
.
inputs
:
helper
(
r2
)
for
output
in
o
:
helper
(
output
)
return
results
,
orphans
results_and_orphans
.
E_unreached
=
'there were unreachable inputs'
def
ops
(
i
,
o
):
...
...
gof/link.py
浏览文件 @
ea32b4db
...
...
@@ -2,7 +2,7 @@
from
utils
import
AbstractFunctionError
import
utils
from
graph
import
Constant
from
graph
import
Value
import
sys
import
traceback
...
...
@@ -135,16 +135,20 @@ def map_storage(env, order, input_storage, output_storage):
storage_map
=
{}
for
r
,
storage
in
zip
(
env
.
inputs
,
input_storage
):
storage_map
[
r
]
=
storage
for
orphan
in
env
.
orphans
:
if
not
isinstance
(
orphan
,
Constant
):
raise
TypeError
(
"Cannot link a graph with non-constant orphans."
,
orphan
)
storage_map
[
orphan
]
=
[
orphan
.
data
]
#
for orphan in env.orphans:
#
if not isinstance(orphan, Constant):
#
raise TypeError("Cannot link a graph with non-constant orphans.", orphan)
#
storage_map[orphan] = [orphan.data]
if
output_storage
is
not
None
:
assert
len
(
env
.
outputs
)
==
len
(
output_storage
)
for
r
,
storage
in
zip
(
env
.
outputs
,
output_storage
):
storage_map
[
r
]
=
storage
thunks
=
[]
for
node
in
order
:
for
r
in
node
.
inputs
:
if
r
not
in
storage_map
:
assert
isinstance
(
r
,
Value
)
storage_map
[
r
]
=
[
r
.
data
]
for
r
in
node
.
outputs
:
storage_map
.
setdefault
(
r
,
[
None
])
...
...
gof/opt.py
浏览文件 @
ea32b4db
...
...
@@ -430,11 +430,16 @@ class MergeOptimizer(Optimizer):
are constant.
"""
def
add_requirements
(
self
,
env
):
try
:
env
.
extend
(
toolbox
.
ReplaceValidate
())
except
:
pass
def
apply
(
self
,
env
):
cid
=
_metadict
()
#result -> result.desc() (for constants)
inv_cid
=
_metadict
()
#desc -> result (for constants)
for
i
,
r
in
enumerate
(
env
.
orphans
.
union
(
env
.
inputs
)):
if
isinstance
(
r
,
Constant
):
for
i
,
r
in
enumerate
(
[
r
for
r
in
env
.
results
if
isinstance
(
r
,
Constant
)]):
#
env.orphans.union(env.inputs)):
#
if isinstance(r, Constant):
sig
=
r
.
signature
()
other_r
=
inv_cid
.
get
(
sig
,
None
)
if
other_r
is
not
None
:
...
...
@@ -446,20 +451,19 @@ class MergeOptimizer(Optimizer):
# and it's more efficient to give them an integer cid like the other Results
cid
.
clear
()
inv_cid
.
clear
()
for
i
,
r
in
enumerate
(
env
.
orphans
.
union
(
env
.
inputs
)
):
for
i
,
r
in
enumerate
(
r
for
r
in
env
.
results
if
r
.
owner
is
None
):
cid
[
r
]
=
i
inv_cid
[
i
]
=
r
for
node
in
env
.
io_toposort
(
):
for
node
in
graph
.
io_toposort
(
env
.
inputs
,
env
.
outputs
):
node_cid
=
(
node
.
op
,
tuple
([
cid
[
input
]
for
input
in
node
.
inputs
]))
dup
=
inv_cid
.
get
(
node_cid
,
None
)
success
=
False
if
dup
is
not
None
:
success
=
True
d
=
dict
(
zip
(
node
.
outputs
,
dup
.
outputs
))
try
:
env
.
replace_all
(
d
)
except
Exception
,
e
:
env
.
replace_all
_validate
(
zip
(
node
.
outputs
,
dup
.
outputs
)
)
except
InconsistencyError
,
e
:
success
=
False
if
not
success
:
cid
[
node
]
=
node_cid
...
...
gof/toolbox.py
浏览文件 @
ea32b4db
...
...
@@ -16,6 +16,51 @@ class Bookkeeper:
self
.
on_prune
(
env
,
node
)
class
Toposorter
:
def
on_attach
(
self
,
env
):
if
hasattr
(
env
,
'toposort'
):
raise
Exception
(
"Toposorter feature is already present or in conflict with another plugin."
)
env
.
toposort
=
partial
(
self
.
toposort
,
env
)
def
on_deattach
(
self
,
env
):
del
env
.
toposort
def
toposort
(
self
,
env
):
ords
=
{}
for
feature
in
env
.
_features
:
if
hasattr
(
feature
,
'orderings'
):
for
op
,
prereqs
in
feature
.
orderings
(
env
)
.
items
():
ords
.
setdefault
(
op
,
set
())
.
update
(
prereqs
)
order
=
graph
.
io_toposort
(
env
.
inputs
,
env
.
outputs
,
ords
)
return
order
# def supplemental_orderings(self):
# """
# Returns a dictionary of {op: set(prerequisites)} that must
# be satisfied in addition to the order defined by the structure
# of the graph (returns orderings that not related to input/output
# relationships).
# """
# ords = {}
# for feature in self._features:
# if hasattr(feature, 'orderings'):
# for op, prereqs in feature.orderings().items():
# ords.setdefault(op, set()).update(prereqs)
# return ords
# def toposort(self):
# """
# Returns a list of nodes in the order that they must be executed
# in order to preserve the semantics of the graph and respect
# the constraints put forward by the listeners.
# """
# ords = self.supplemental_orderings()
# order = graph.io_toposort(self.inputs, self.outputs, ords)
# return order
class
History
:
def
__init__
(
self
):
...
...
scalar.py
浏览文件 @
ea32b4db
...
...
@@ -25,10 +25,6 @@ def as_scalar(x, name = None):
if
not
isinstance
(
x
.
type
,
Scalar
):
raise
TypeError
(
"Result type field must be a Scalar."
,
x
,
x
.
type
)
return
x
if
isinstance
(
x
,
Constant
):
if
not
isinstance
(
x
.
type
,
Scalar
):
raise
TypeError
(
"Constant type field must be a Scalar."
,
x
,
x
.
type
)
return
x
try
:
return
constant
(
x
)
except
TypeError
:
...
...
@@ -582,7 +578,7 @@ tanh = Tanh(upgrade_to_float, name = 'tanh')
class
Composite
(
ScalarOp
):
def
__init__
(
self
,
inputs
,
outputs
):
env
=
Env
(
inputs
,
outputs
)
.
clone
(
)
env
=
Env
(
*
gof
.
graph
.
clone
(
inputs
,
outputs
)
)
inputs
,
outputs
=
env
.
inputs
,
env
.
outputs
for
node
in
env
.
nodes
:
...
...
@@ -594,7 +590,8 @@ class Composite(ScalarOp):
zip
(
outputs
,
[
"
%%
(o
%
i)s"
%
i
for
i
in
range
(
len
(
outputs
))]))
for
orphan
in
env
.
orphans
:
for
orphan
in
env
.
results
:
#env.orphans:
if
orphan
.
owner
is
None
and
orphan
not
in
env
.
inputs
:
if
isinstance
(
orphan
,
Constant
):
subd
[
orphan
]
=
orphan
.
type
.
c_literal
(
orphan
.
data
)
else
:
...
...
@@ -611,7 +608,7 @@ class Composite(ScalarOp):
name
=
"V
%%(id)
s_tmp
%
i"
%
i
subd
[
output
]
=
name
_c_code
+=
"
%
s
%
s;
\n
"
%
(
output
.
type
.
dtype_specs
()[
1
],
name
)
_c_code
+=
node
.
op
.
c_code
(
node
.
inputs
,
_c_code
+=
node
.
op
.
c_code
(
node
,
"
%(name)
s"
,
[
subd
[
input
]
for
input
in
node
.
inputs
],
[
subd
[
output
]
for
output
in
node
.
outputs
],
...
...
@@ -629,7 +626,7 @@ class Composite(ScalarOp):
if
r
in
env
.
inputs
:
idx
=
env
.
inputs
.
index
(
r
)
return
lambda
inputs
:
inputs
[
idx
]
elif
r
in
env
.
orphans
:
elif
r
.
owner
is
None
:
#
in env.orphans:
return
lambda
inputs
:
r
.
data
node
=
r
.
owner
producers
=
[
compose_impl
(
input
)
for
input
in
node
.
inputs
]
...
...
sparse.py
浏览文件 @
ea32b4db
差异被折叠。
点击展开。
tensor.py
浏览文件 @
ea32b4db
...
...
@@ -6,7 +6,7 @@ import numpy
from
copy
import
copy
from
gof
import
Result
,
Op
,
utils
,
Destroyer
,
Viewer
,
AbstractFunctionError
,
Type
,
Result
,
Constant
,
Apply
from
gof
import
Result
,
Op
,
utils
,
Destroyer
,
Viewer
,
AbstractFunctionError
,
Type
,
Result
,
Constant
,
Apply
,
Value
import
gof
import
blas
# for gemm, dot
...
...
@@ -27,14 +27,9 @@ def as_tensor(x, name = None):
if
not
isinstance
(
x
.
type
,
Tensor
):
raise
TypeError
(
"Result type field must be a Tensor."
,
x
,
x
.
type
)
return
x
if
isinstance
(
x
,
Constant
):
if
not
isinstance
(
x
.
type
,
Tensor
):
raise
TypeError
(
"Constant type field must be a Tensor."
,
x
,
x
.
type
)
return
x
try
:
return
constant
(
x
)
except
TypeError
:
raise
raise
TypeError
(
"Cannot convert
%
s to Tensor"
%
x
,
type
(
x
))
# this has a different name, because _as_tensor is the function which ops use
# to upcast their arguments... this internal-use function is a good place to put debugging stuff, better than the global astensor.
...
...
@@ -48,10 +43,19 @@ def constant(x):
return
TensorConstant
(
Tensor
(
dtype
=
x
.
dtype
,
broadcastable
=
[
d
==
1
for
d
in
x
.
shape
]),
x
)
except
:
raise
raise
TypeError
(
"Could not convert
%
s to Tensor"
%
_x
,
type
(
_x
))
def
value
(
x
):
if
not
isinstance
(
x
,
numpy
.
ndarray
):
x
=
numpy
.
asarray
(
x
)
try
:
return
TensorValue
(
Tensor
(
dtype
=
x
.
dtype
,
broadcastable
=
[
d
==
1
for
d
in
x
.
shape
]),
x
)
except
:
raise
TypeError
(
"Could not convert
%
s to Tensor"
%
_x
,
type
(
_x
))
class
Tensor
(
Type
):
"""
L{Type} representing L{numpy.ndarray} in Theano.
...
...
@@ -342,10 +346,14 @@ class TensorResult(Result, _tensor_py_operators):
class
TensorConstant
(
Constant
,
_tensor_py_operators
):
pass
class
TensorValue
(
Value
,
_tensor_py_operators
):
pass
s2t
.
as_tensor
=
as_tensor
s2t
.
Tensor
=
Tensor
s2t
.
TensorResult
=
TensorResult
s2t
.
TensorConstant
=
TensorConstant
s2t
.
TensorValue
=
TensorValue
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论