Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
bcb1a525
提交
bcb1a525
authored
3月 20, 2008
作者:
Olivier Breuleux
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
documented and added a lot of tests
上级
e2db78da
全部展开
隐藏空白字符变更
内嵌
并排
正在显示
10 个修改的文件
包含
94 行增加
和
176 行删除
+94
-176
_test_cc.py
gof/_test_cc.py
+9
-6
_test_ext.py
gof/_test_ext.py
+0
-0
_test_graph.py
gof/_test_graph.py
+25
-24
_test_link.py
gof/_test_link.py
+5
-6
_test_op.py
gof/_test_op.py
+1
-116
_test_opt.py
gof/_test_opt.py
+4
-4
_test_result.py
gof/_test_result.py
+3
-2
_test_toolbox.py
gof/_test_toolbox.py
+30
-3
cc.py
gof/cc.py
+2
-1
toolbox.py
gof/toolbox.py
+15
-14
没有找到文件。
gof/_test_cc.py
浏览文件 @
bcb1a525
...
...
@@ -169,6 +169,14 @@ class _test_CLinker(unittest.TestCase):
self
.
failUnless
(
fn
(
2.0
,
2.0
)
==
4
)
# note: for now the behavior of fn(2.0, 7.0) is undefined
def
test_dups_inner
(
self
):
# Testing that duplicates are allowed inside the graph
x
,
y
,
z
=
inputs
()
e
=
add
(
mul
(
y
,
y
),
add
(
x
,
z
))
lnk
=
CLinker
(
env
([
x
,
y
,
z
],
[
e
]))
fn
=
lnk
.
make_function
()
self
.
failUnless
(
fn
(
1.0
,
2.0
,
3.0
)
==
8.0
)
class
_test_OpWiseCLinker
(
unittest
.
TestCase
):
...
...
@@ -180,9 +188,4 @@ class _test_OpWiseCLinker(unittest.TestCase):
self
.
failUnless
(
fn
(
2.0
,
2.0
,
2.0
)
==
2.0
)
if
__name__
==
'__main__'
:
# unittest.main()
x
,
y
,
z
=
inputs
()
e
=
add
(
mul
(
add
(
x
,
y
),
div
(
x
,
y
)),
sub
(
sub
(
x
,
y
),
z
))
lnk
=
CLinker
(
env
([
x
,
y
,
z
],
[
e
]))
fn
=
lnk
.
make_function
()
fn
(
2.0
,
0.0
,
2.0
)
unittest
.
main
()
gof/_test_ext.py
浏览文件 @
bcb1a525
差异被折叠。
点击展开。
gof/_test_graph.py
浏览文件 @
bcb1a525
...
...
@@ -39,37 +39,38 @@ class MyOp(Op):
class
_test_inputs
(
unittest
.
TestCase
):
def
test_
0
(
self
):
def
test_
straightforward
(
self
):
r1
,
r2
=
MyResult
(
1
),
MyResult
(
2
)
op
=
MyOp
(
r1
,
r2
)
assert
inputs
(
op
.
outputs
)
==
set
([
r1
,
r2
])
def
test_
1
(
self
):
def
test_
deep
(
self
):
r1
,
r2
,
r5
=
MyResult
(
1
),
MyResult
(
2
),
MyResult
(
5
)
op
=
MyOp
(
r1
,
r2
)
op2
=
MyOp
(
op
.
outputs
[
0
],
r5
)
assert
inputs
(
op2
.
outputs
)
==
set
([
r1
,
r2
,
r5
])
class
_test_orphans
(
unittest
.
TestCase
):
def
test_0
(
self
):
r1
,
r2
,
r5
=
MyResult
(
1
),
MyResult
(
2
),
MyResult
(
5
)
op
=
MyOp
(
r1
,
r2
)
op2
=
MyOp
(
op
.
outputs
[
0
],
r5
)
assert
orphans
([
r1
,
r2
],
op2
.
outputs
)
==
set
([
r5
])
def
test_1
(
self
):
def
test_unreached_inputs
(
self
):
r1
,
r2
,
r5
=
MyResult
(
1
),
MyResult
(
2
),
MyResult
(
5
)
op
=
MyOp
(
r1
,
r2
)
op2
=
MyOp
(
op
.
outputs
[
0
],
r5
)
try
:
# function doesn't raise if we put False instead of True
ro
=
results_and_orphans
([
r1
,
r2
,
op2
.
outputs
[
0
]],
op
.
outputs
,
True
)
self
.
fail
()
except
Exception
,
e
:
if
e
[
0
]
is
results_and_orphans
.
E_unreached
:
return
raise
class
_test_orphans
(
unittest
.
TestCase
):
def
test_straightforward
(
self
):
r1
,
r2
,
r5
=
MyResult
(
1
),
MyResult
(
2
),
MyResult
(
5
)
op
=
MyOp
(
r1
,
r2
)
op2
=
MyOp
(
op
.
outputs
[
0
],
r5
)
assert
orphans
([
r1
,
r2
],
op2
.
outputs
)
==
set
([
r5
])
class
_test_as_string
(
unittest
.
TestCase
):
...
...
@@ -78,24 +79,24 @@ class _test_as_string(unittest.TestCase):
node_formatter
=
lambda
op
,
argstrings
:
"
%
s(
%
s)"
%
(
op
.
__class__
.
__name__
,
", "
.
join
(
argstrings
))
def
test_
0
(
self
):
def
test_
straightforward
(
self
):
r1
,
r2
=
MyResult
(
1
),
MyResult
(
2
)
op
=
MyOp
(
r1
,
r2
)
assert
as_string
([
r1
,
r2
],
op
.
outputs
)
==
[
"MyOp(1, 2)"
]
def
test_
1
(
self
):
def
test_
deep
(
self
):
r1
,
r2
,
r5
=
MyResult
(
1
),
MyResult
(
2
),
MyResult
(
5
)
op
=
MyOp
(
r1
,
r2
)
op2
=
MyOp
(
op
.
outputs
[
0
],
r5
)
assert
as_string
([
r1
,
r2
,
r5
],
op2
.
outputs
)
==
[
"MyOp(MyOp(1, 2), 5)"
]
def
test_
2
(
self
):
def
test_
multiple_references
(
self
):
r1
,
r2
,
r5
=
MyResult
(
1
),
MyResult
(
2
),
MyResult
(
5
)
op
=
MyOp
(
r1
,
r2
)
op2
=
MyOp
(
op
.
outputs
[
0
],
op
.
outputs
[
0
])
assert
as_string
([
r1
,
r2
,
r5
],
op2
.
outputs
)
==
[
"MyOp(*1 -> MyOp(1, 2), *1)"
]
def
test_
3
(
self
):
def
test_
cutoff
(
self
):
r1
,
r2
,
r5
=
MyResult
(
1
),
MyResult
(
2
),
MyResult
(
5
)
op
=
MyOp
(
r1
,
r2
)
op2
=
MyOp
(
op
.
outputs
[
0
],
op
.
outputs
[
0
])
...
...
@@ -105,24 +106,24 @@ class _test_as_string(unittest.TestCase):
class
_test_clone
(
unittest
.
TestCase
):
def
test_
0
(
self
):
def
test_
accurate
(
self
):
r1
,
r2
=
MyResult
(
1
),
MyResult
(
2
)
op
=
MyOp
(
r1
,
r2
)
new
=
clone
([
r1
,
r2
],
op
.
outputs
)
assert
as_string
([
r1
,
r2
],
new
)
==
[
"MyOp(1, 2)"
]
def
test_
1
(
self
):
def
test_
copy
(
self
):
r1
,
r2
,
r5
=
MyResult
(
1
),
MyResult
(
2
),
MyResult
(
5
)
op
=
MyOp
(
r1
,
r2
)
op2
=
MyOp
(
op
.
outputs
[
0
],
r5
)
new
=
clone
([
r1
,
r2
,
r5
],
op2
.
outputs
)
assert
op2
.
outputs
[
0
]
==
new
[
0
]
and
op2
.
outputs
[
0
]
is
not
new
[
0
]
assert
op2
is
not
new
[
0
]
.
owner
assert
new
[
0
]
.
owner
.
inputs
[
1
]
is
r5
assert
new
[
0
]
.
owner
.
inputs
[
0
]
==
op
.
outputs
[
0
]
and
new
[
0
]
.
owner
.
inputs
[
0
]
is
not
op
.
outputs
[
0
]
assert
op2
.
outputs
[
0
]
==
new
[
0
]
and
op2
.
outputs
[
0
]
is
not
new
[
0
]
# the new output is like the old one but not the same object
assert
op2
is
not
new
[
0
]
.
owner
# the new output has a new owner
assert
new
[
0
]
.
owner
.
inputs
[
1
]
is
r5
# the inputs are not copied
assert
new
[
0
]
.
owner
.
inputs
[
0
]
==
op
.
outputs
[
0
]
and
new
[
0
]
.
owner
.
inputs
[
0
]
is
not
op
.
outputs
[
0
]
# check that we copied deeper too
def
test_
2
(
self
):
"Checks that manipulating a cloned graph leaves the original unchanged."
def
test_
not_destructive
(
self
):
# Checks that manipulating a cloned graph leaves the original unchanged.
r1
,
r2
,
r5
=
MyResult
(
1
),
MyResult
(
2
),
MyResult
(
5
)
op
=
MyOp
(
MyOp
(
r1
,
r2
)
.
outputs
[
0
],
r5
)
new
=
clone
([
r1
,
r2
,
r5
],
op
.
outputs
)
...
...
gof/_test_link.py
浏览文件 @
bcb1a525
...
...
@@ -64,8 +64,6 @@ def inputs():
return
x
,
y
,
z
def
env
(
inputs
,
outputs
,
validate
=
True
,
features
=
[]):
# inputs = [input.r for input in inputs]
# outputs = [output.r for output in outputs]
return
Env
(
inputs
,
outputs
,
features
=
features
,
consistency_check
=
validate
)
def
perform_linker
(
env
):
...
...
@@ -75,26 +73,27 @@ def perform_linker(env):
class
_test_PerformLinker
(
unittest
.
TestCase
):
def
test_
0
(
self
):
def
test_
thunk_inplace
(
self
):
x
,
y
,
z
=
inputs
()
e
=
mul
(
add
(
x
,
y
),
div
(
x
,
y
))
fn
,
i
,
o
=
perform_linker
(
env
([
x
,
y
,
z
],
[
e
]))
.
make_thunk
(
True
)
fn
()
assert
e
.
data
==
1.5
def
test_
1
(
self
):
def
test_
thunk_not_inplace
(
self
):
x
,
y
,
z
=
inputs
()
e
=
mul
(
add
(
x
,
y
),
div
(
x
,
y
))
fn
,
i
,
o
=
perform_linker
(
env
([
x
,
y
,
z
],
[
e
]))
.
make_thunk
(
False
)
fn
()
assert
o
[
0
]
.
data
==
1.5
assert
e
.
data
!=
1.5
def
test_
2
(
self
):
def
test_
function
(
self
):
x
,
y
,
z
=
inputs
()
e
=
mul
(
add
(
x
,
y
),
div
(
x
,
y
))
fn
=
perform_linker
(
env
([
x
,
y
,
z
],
[
e
]))
.
make_function
()
assert
fn
(
1.0
,
2.0
,
3.0
)
==
1.5
assert
e
.
data
!=
1.5
assert
e
.
data
!=
1.5
# not inplace
def
test_input_output_same
(
self
):
x
,
y
,
z
=
inputs
()
...
...
gof/_test_op.py
浏览文件 @
bcb1a525
...
...
@@ -2,7 +2,7 @@
import
unittest
from
copy
import
copy
from
op
import
*
from
result
import
ResultBase
#, BrokenLinkError
from
result
import
ResultBase
class
MyResult
(
ResultBase
):
...
...
@@ -34,27 +34,6 @@ class MyOp(Op):
self
.
inputs
=
inputs
self
.
outputs
=
[
MyResult
(
sum
([
input
.
thingy
for
input
in
inputs
]))]
# def validate_update(self):
# for input in self.inputs:
# if not isinstance(input, MyResult):
# raise Exception("Error 1")
# if self.outputs is None:
# self.outputs = [MyResult(sum([input.thingy for input in self.inputs]))]
# return True
# else:
# old_thingy = self.outputs[0].thingy
# new_thingy = sum([input.thingy for input in self.inputs])
# self.outputs[0].thingy = new_thingy
# return old_thingy != new_thingy
# class MyOp(Op):
# def validate_update(self):
# for input in self.inputs:
# if not isinstance(input, MyResult):
# raise Exception("Error 1")
# self.outputs = [MyResult(sum([input.thingy for input in self.inputs]))]
class
_test_Op
(
unittest
.
TestCase
):
...
...
@@ -75,100 +54,6 @@ class _test_Op(unittest.TestCase):
else
:
raise
Exception
(
"Expected an exception"
)
# # Setting inputs and outputs
# def test_set_inputs(self):
# r1, r2 = MyResult(1), MyResult(2)
# op = MyOp(r1, r2)
# r3 = op.outputs[0]
# op.inputs = MyResult(4), MyResult(5)
# op.validate_update()
# assert op.outputs == [MyResult(9)] # check if the output changed to what I expect
# # assert r3.data is op.outputs[0].data # check if the data was properly transferred by set_output
# def test_set_bad_inputs(self):
# op = MyOp(MyResult(1), MyResult(2))
# try:
# op.inputs = MyResult(4), ResultBase()
# op.validate_update()
# except Exception, e:
# assert str(e) == "Error 1"
# else:
# raise Exception("Expected an exception")
# def test_set_outputs(self):
# r1, r2 = MyResult(1), MyResult(2)
# op = MyOp(r1, r2) # here we only make one output
# try:
# op.outputs = MyResult(10), MyResult(11) # setting two outputs should fail
# except TypeError, e:
# assert str(e) == "The new outputs must be exactly as many as the previous outputs."
# else:
# raise Exception("Expected an exception")
# # Tests about broken links
# def test_create_broken_link(self):
# r1, r2 = MyResult(1), MyResult(2)
# op = MyOp(r1, r2)
# r3 = op.out
# op.inputs = MyResult(3), MyResult(4)
# assert r3 not in op.outputs
# assert r3.replaced
# def test_cannot_copy_when_input_is_broken_link(self):
# r1, r2 = MyResult(1), MyResult(2)
# op = MyOp(r1, r2)
# r3 = op.out
# op2 = MyOp(r3)
# op.inputs = MyResult(3), MyResult(4)
# try:
# copy(op2)
# except BrokenLinkError:
# pass
# else:
# raise Exception("Expected an exception")
# def test_get_input_broken_link(self):
# r1, r2 = MyResult(1), MyResult(2)
# op = MyOp(r1, r2)
# r3 = op.out
# op2 = MyOp(r3)
# op.inputs = MyResult(3), MyResult(4)
# try:
# op2.get_input(0)
# except BrokenLinkError:
# pass
# else:
# raise Exception("Expected an exception")
# def test_get_inputs_broken_link(self):
# r1, r2 = MyResult(1), MyResult(2)
# op = MyOp(r1, r2)
# r3 = op.out
# op2 = MyOp(r3)
# op.inputs = MyResult(3), MyResult(4)
# try:
# op2.get_inputs()
# except BrokenLinkError:
# pass
# else:
# raise Exception("Expected an exception")
# def test_repair_broken_link(self):
# r1, r2 = MyResult(1), MyResult(2)
# op = MyOp(r1, r2)
# r3 = op.out
# op2 = MyOp(r3, MyResult(10))
# op.inputs = MyResult(3), MyResult(4)
# op2.repair()
# assert op2.outputs == [MyResult(17)]
# # Tests about string representation
# def test_create_broken_link(self):
# r1, r2 = MyResult(1), MyResult(2)
# op = MyOp(r1, r2)
# assert str(op) == "MyOp(1, 2)"
if
__name__
==
'__main__'
:
...
...
gof/_test_opt.py
浏览文件 @
bcb1a525
...
...
@@ -260,7 +260,7 @@ class _test_MergeOptimizer(unittest.TestCase):
class
_test_ConstantFinder
(
unittest
.
TestCase
):
def
test_
0
(
self
):
def
test_
straightforward
(
self
):
x
,
y
,
z
=
inputs
()
y
.
data
=
2
z
.
data
=
2
...
...
@@ -272,7 +272,7 @@ class _test_ConstantFinder(unittest.TestCase):
assert
str
(
g
)
==
"[Op1(x, y, y)]"
\
or
str
(
g
)
==
"[Op1(x, z, z)]"
def
test_
1
(
self
):
def
test_
deep
(
self
):
x
,
y
,
z
=
inputs
()
y
.
data
=
2
z
.
data
=
2
...
...
@@ -284,11 +284,11 @@ class _test_ConstantFinder(unittest.TestCase):
assert
str
(
g
)
==
"[Op1(*1 -> Op2(x, y), *1, *1)]"
\
or
str
(
g
)
==
"[Op1(*1 -> Op2(x, z), *1, *1)]"
def
test_
2
(
self
):
def
test_
destroyed_orphan_not_constant
(
self
):
x
,
y
,
z
=
inputs
()
y
.
data
=
2
z
.
data
=
2
e
=
op_d
(
x
,
op2
(
y
,
z
))
e
=
op_d
(
x
,
op2
(
y
,
z
))
# here x is destroyed by op_d
g
=
env
([
y
],
[
e
])
ConstantFinder
()
.
optimize
(
g
)
assert
not
getattr
(
x
,
'constant'
,
False
)
and
z
.
constant
...
...
gof/_test_result.py
浏览文件 @
bcb1a525
...
...
@@ -36,9 +36,10 @@ class MyResult(ResultBase):
class
_test_ResultBase
(
unittest
.
TestCase
):
def
test_
0
(
self
):
def
test_
trivial
(
self
):
r
=
ResultBase
()
def
test_1
(
self
):
def
test_state
(
self
):
r
=
ResultBase
()
assert
r
.
state
is
Empty
...
...
gof/_test_toolbox.py
浏览文件 @
bcb1a525
...
...
@@ -54,14 +54,12 @@ def inputs():
return
x
,
y
,
z
def
env
(
inputs
,
outputs
,
validate
=
True
,
features
=
[]):
# inputs = [input.r for input in inputs]
# outputs = [output.r for output in outputs]
return
Env
(
inputs
,
outputs
,
features
=
features
,
consistency_check
=
validate
)
class
_test_EquivTool
(
unittest
.
TestCase
):
def
test_
0
(
self
):
def
test_
straightforward
(
self
):
x
,
y
,
z
=
inputs
()
sx
=
sigmoid
(
x
)
e
=
add
(
sx
,
sigmoid
(
y
))
...
...
@@ -72,6 +70,35 @@ class _test_EquivTool(unittest.TestCase):
assert
isinstance
(
g
.
equiv
(
sx
)
.
owner
,
Dot
)
class
_test_InstanceFinder
(
unittest
.
TestCase
):
def
test_straightforward
(
self
):
x
,
y
,
z
=
inputs
()
e0
=
dot
(
y
,
z
)
e
=
add
(
add
(
sigmoid
(
x
),
sigmoid
(
sigmoid
(
z
))),
dot
(
add
(
x
,
y
),
e0
))
g
=
env
([
x
,
y
,
z
],
[
e
],
features
=
[
InstanceFinder
])
for
type
,
num
in
((
Add
,
3
),
(
Sigmoid
,
3
),
(
Dot
,
2
)):
if
not
len
([
x
for
x
in
g
.
get_instances_of
(
type
)])
==
num
:
self
.
fail
((
type
,
num
))
new_e0
=
add
(
y
,
z
)
assert
e0
.
owner
in
g
.
get_instances_of
(
Dot
)
assert
new_e0
.
owner
not
in
g
.
get_instances_of
(
Add
)
g
.
replace
(
e0
,
new_e0
)
assert
e0
.
owner
not
in
g
.
get_instances_of
(
Dot
)
assert
new_e0
.
owner
in
g
.
get_instances_of
(
Add
)
for
type
,
num
in
((
Add
,
4
),
(
Sigmoid
,
3
),
(
Dot
,
1
)):
if
not
len
([
x
for
x
in
g
.
get_instances_of
(
type
)])
==
num
:
self
.
fail
((
type
,
num
))
def
test_robustness
(
self
):
x
,
y
,
z
=
inputs
()
e
=
add
(
add
(
sigmoid
(
x
),
sigmoid
(
sigmoid
(
z
))),
dot
(
add
(
x
,
y
),
dot
(
y
,
z
)))
g
=
env
([
x
,
y
,
z
],
[
e
],
features
=
[
InstanceFinder
])
gen
=
g
.
get_instances_of
(
Sigmoid
)
# I want to get Sigmoid instances
g
.
replace
(
e
,
add
(
x
,
y
))
# but here I prune them all
assert
len
([
x
for
x
in
gen
])
==
0
# the generator should not yield them
if
__name__
==
'__main__'
:
unittest
.
main
()
...
...
gof/cc.py
浏览文件 @
bcb1a525
...
...
@@ -512,7 +512,8 @@ class CLinker(Linker):
# List of indices that should be ignored when passing the arguments
# (basically, everything that the previous call to uniq eliminated)
self
.
dupidx
=
[
i
for
i
,
x
in
enumerate
(
all
)
if
all
.
count
(
x
)
>
1
and
all
.
index
(
x
)
!=
i
]
return
self
.
struct_code
def
find_task
(
self
,
failure_code
):
"""
Maps a failure code to the task that is associated to it.
...
...
gof/toolbox.py
浏览文件 @
bcb1a525
...
...
@@ -113,25 +113,26 @@ class PrintListener(Listener):
print
"-- moving from
%
s to
%
s"
%
(
r
,
new_r
)
### UNUSED AND UNTESTED ###
class
ChangeListener
(
Listener
):
#
class ChangeListener(Listener):
def
__init__
(
self
,
env
):
self
.
change
=
False
#
def __init__(self, env):
#
self.change = False
def
on_import
(
self
,
op
):
self
.
change
=
True
#
def on_import(self, op):
#
self.change = True
def
on_prune
(
self
,
op
):
self
.
change
=
True
#
def on_prune(self, op):
#
self.change = True
def
on_rewire
(
self
,
clients
,
r
,
new_r
):
self
.
change
=
True
#
def on_rewire(self, clients, r, new_r):
#
self.change = True
def
__call__
(
self
,
value
=
"get"
):
if
value
==
"get"
:
return
self
.
change
else
:
self
.
change
=
value
#
def __call__(self, value = "get"):
#
if value == "get":
#
return self.change
#
else:
#
self.change = value
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论