Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
c0f62796
提交
c0f62796
authored
8月 18, 2015
作者:
abergeron
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #3232 from t13m/merge_assert
make MergeOptimizer merge nodes with assert input
上级
0a7415d7
65dda20e
隐藏空白字符变更
内嵌
并排
正在显示
2 个修改的文件
包含
269 行增加
和
14 行删除
+269
-14
opt.py
theano/gof/opt.py
+143
-14
test_opt.py
theano/gof/tests/test_opt.py
+126
-0
没有找到文件。
theano/gof/opt.py
浏览文件 @
c0f62796
...
@@ -503,7 +503,7 @@ class MergeFeature(object):
...
@@ -503,7 +503,7 @@ class MergeFeature(object):
# we adopt convention to keep the last name
# we adopt convention to keep the last name
if
c
.
name
:
if
c
.
name
:
other_c
.
name
=
c
.
name
other_c
.
name
=
c
.
name
self
.
scheduled
.
append
([[(
c
,
other_c
)]])
self
.
scheduled
.
append
([[(
c
,
other_c
,
'merge'
)]])
else
:
else
:
# this is a new constant
# this is a new constant
self
.
const_sig
[
c
]
=
sig
self
.
const_sig
[
c
]
=
sig
...
@@ -515,6 +515,8 @@ class MergeFeature(object):
...
@@ -515,6 +515,8 @@ class MergeFeature(object):
if
node
in
self
.
nodes_seen
:
if
node
in
self
.
nodes_seen
:
return
return
node_has_assert
=
False
# These asserts ensure that the fgraph has set the clients field
# These asserts ensure that the fgraph has set the clients field
# properly.
# properly.
# The clients should at least contain `node` itself!
# The clients should at least contain `node` itself!
...
@@ -523,6 +525,23 @@ class MergeFeature(object):
...
@@ -523,6 +525,23 @@ class MergeFeature(object):
assert
(
node
,
0
)
in
node
.
inputs
[
0
]
.
clients
assert
(
node
,
0
)
in
node
.
inputs
[
0
]
.
clients
merge_candidates
=
[
c
for
(
c
,
i
)
in
node
.
inputs
[
0
]
.
clients
merge_candidates
=
[
c
for
(
c
,
i
)
in
node
.
inputs
[
0
]
.
clients
if
c
in
self
.
nodes_seen
]
if
c
in
self
.
nodes_seen
]
# Put all clients of Assert inputs (if exist) into merge_candidates
for
i
in
node
.
inputs
:
if
i
.
owner
and
isinstance
(
i
.
owner
.
op
,
theano
.
tensor
.
opt
.
Assert
):
node_has_assert
=
True
assert_clients
=
[
c
for
(
c
,
_
)
in
i
.
owner
.
inputs
[
0
]
.
clients
if
c
in
self
.
nodes_seen
]
for
idx
in
range
(
len
(
assert_clients
)):
client
=
assert_clients
[
idx
]
if
isinstance
(
i
.
owner
.
op
,
theano
.
tensor
.
opt
.
Assert
):
for
c
in
client
.
outputs
[
0
]
.
clients
:
if
c
[
0
]
in
self
.
nodes_seen
:
assert_clients
.
append
(
c
[
0
])
merge_candidates
.
extend
(
assert_clients
)
else
:
else
:
merge_candidates
=
[]
merge_candidates
=
[]
...
@@ -533,19 +552,66 @@ class MergeFeature(object):
...
@@ -533,19 +552,66 @@ class MergeFeature(object):
if
len
(
node
.
inputs
)
!=
len
(
candidate
.
inputs
):
if
len
(
node
.
inputs
)
!=
len
(
candidate
.
inputs
):
continue
continue
cand_has_assert
=
False
# Get input list of the candidate with assert removed
cand_inputs_assert_removed
=
[]
for
i
in
candidate
.
inputs
:
if
i
.
owner
and
isinstance
(
i
.
owner
.
op
,
theano
.
tensor
.
opt
.
Assert
):
cand_has_assert
=
True
cand_inputs_assert_removed
.
append
(
i
.
owner
.
inputs
[
0
])
else
:
cand_inputs_assert_removed
.
append
(
i
)
# Get input list of the node with assert removed
if
node_has_assert
:
node_inputs_assert_removed
=
[]
for
i
in
node
.
inputs
:
if
i
.
owner
and
isinstance
(
i
.
owner
.
op
,
theano
.
tensor
.
opt
.
Assert
):
node_inputs_assert_removed
.
append
(
i
.
owner
.
inputs
[
0
])
else
:
node_inputs_assert_removed
.
append
(
i
)
else
:
node_inputs_assert_removed
=
node
.
inputs
inputs_match
=
all
(
node_in
is
cand_in
inputs_match
=
all
(
node_in
is
cand_in
for
node_in
,
cand_in
in
zip
(
node
.
inputs
,
for
node_in
,
cand_in
candidate
.
inputs
))
in
zip
(
node_inputs_assert_removed
,
cand_inputs_assert_removed
))
if
inputs_match
and
node
.
op
==
candidate
.
op
:
if
inputs_match
and
node
.
op
==
candidate
.
op
:
if
(
node
,
candidate
)
in
self
.
blacklist
:
if
(
node
,
candidate
)
in
self
.
blacklist
:
# They were already tried, and there was an error
# They were already tried, and there was an error
continue
continue
# Schedule transfer of clients from node to candidate
# replace node with candidate
pairs
=
list
(
zip
(
node
.
outputs
,
candidate
.
outputs
))
if
not
(
node_has_assert
or
cand_has_assert
):
# Schedule transfer of clients from node to candidate
pairs
=
list
(
zip
(
node
.
outputs
,
candidate
.
outputs
,
[
'merge'
]
*
len
(
node
.
outputs
)))
# if the current node has assert input, it should not be
# replaced with a candidate node which has no assert input
elif
node_has_assert
and
not
cand_has_assert
:
pairs
=
list
(
zip
(
candidate
.
outputs
,
node
.
outputs
,
[
'merge'
]
*
len
(
node
.
outputs
)))
else
:
new_inputs
=
self
.
get_merged_assert_input
(
node
,
candidate
)
new_node
=
node
.
op
(
*
new_inputs
)
pairs
=
list
(
zip
(
node
.
outputs
,
new_node
.
owner
.
outputs
,
[
'new_node'
]
*
len
(
node
.
outputs
)))
+
\
list
(
zip
(
candidate
.
outputs
,
new_node
.
owner
.
outputs
,
[
'new_node'
]
*
len
(
node
.
outputs
)))
# transfer names
# transfer names
for
node_output
,
cand_output
in
pairs
:
for
pair
in
pairs
:
node_output
,
cand_output
=
pair
[:
2
]
# clobber old name with new one
# clobber old name with new one
# it's arbitrary... one of the names has to go
# it's arbitrary... one of the names has to go
if
node_output
.
name
:
if
node_output
.
name
:
...
@@ -558,6 +624,37 @@ class MergeFeature(object):
...
@@ -558,6 +624,37 @@ class MergeFeature(object):
else
:
else
:
self
.
nodes_seen
.
add
(
node
)
self
.
nodes_seen
.
add
(
node
)
def
get_merged_assert_input
(
self
,
node
,
candidate
):
new_inputs
=
[]
for
node_i
,
cand_i
in
zip
(
node
.
inputs
,
candidate
.
inputs
):
# if node_i is assert
if
(
node_i
.
owner
and
isinstance
(
node_i
.
owner
.
op
,
theano
.
tensor
.
opt
.
Assert
)):
# node_i is assert, cand_i is assert
if
(
cand_i
.
owner
and
isinstance
(
cand_i
.
owner
.
op
,
theano
.
tensor
.
opt
.
Assert
)):
# Here two assert nodes are merged.
# Step 1. Merge conditions of both assert nodes.
# Step 2. Make the new assert node
node_cond
=
node_i
.
owner
.
inputs
[
1
:]
cand_cond
=
cand_i
.
owner
.
inputs
[
1
:]
new_cond
=
list
(
set
(
node_cond
+
cand_cond
))
new_inputs
.
append
(
theano
.
tensor
.
opt
.
assert_op
(
node_i
.
owner
.
inputs
[
0
],
*
new_cond
))
# node_i is assert, cand_i is not assert
else
:
new_inputs
.
append
(
node_i
)
else
:
# if node_i is not an assert node, append cand_i
new_inputs
.
append
(
cand_i
)
return
new_inputs
class
MergeOptimizer
(
Optimizer
):
class
MergeOptimizer
(
Optimizer
):
"""
"""
...
@@ -594,7 +691,7 @@ class MergeOptimizer(Optimizer):
...
@@ -594,7 +691,7 @@ class MergeOptimizer(Optimizer):
while
sched
:
while
sched
:
pairs_list
=
sched
.
pop
()
pairs_list
=
sched
.
pop
()
success
=
True
success
=
True
for
pairs
in
pairs_list
:
for
pairs
_
in
pairs_list
:
# We must check again the equivalence, as the graph
# We must check again the equivalence, as the graph
# can have changed. If so, doing the replacement can
# can have changed. If so, doing the replacement can
# introduce node that depend on itself. Doing the
# introduce node that depend on itself. Doing the
...
@@ -603,17 +700,49 @@ class MergeOptimizer(Optimizer):
...
@@ -603,17 +700,49 @@ class MergeOptimizer(Optimizer):
# doing the full cycle check. The full cycle check is
# doing the full cycle check. The full cycle check is
# skipped by validate() if the graph don't contain
# skipped by validate() if the graph don't contain
# destroyers.
# destroyers.
var
=
pairs
[
0
][
0
]
var
,
candidate
,
merge_mode
=
pairs_
[
0
]
candidate
=
pairs
[
0
][
1
]
if
merge_mode
==
"new_node"
and
hasattr
(
var
,
'fgraph'
):
if
(
not
hasattr
(
var
,
'fgraph'
)
or
pass
not
hasattr
(
candidate
,
'fgraph'
)):
elif
(
not
hasattr
(
var
,
'fgraph'
)
or
not
hasattr
(
candidate
,
'fgraph'
)):
continue
continue
# Keep len(item) == 2 for item in pairs
pairs
=
[
pair
[:
2
]
for
pair
in
pairs_
]
if
var
.
owner
and
candidate
.
owner
:
if
var
.
owner
and
candidate
.
owner
:
node
=
var
.
owner
node
=
var
.
owner
candidate
=
candidate
.
owner
candidate
=
candidate
.
owner
inputs_match
=
all
(
node_in
is
cand_in
for
node_in
,
cand_in
in
zip
(
# Get input list of the candidate node with assert
node
.
inputs
,
candidate
.
inputs
))
# nodes removed
cand_inputs_assert_removed
=
[]
for
i
in
candidate
.
inputs
:
if
i
.
owner
and
isinstance
(
i
.
owner
.
op
,
theano
.
tensor
.
opt
.
Assert
):
cand_inputs_assert_removed
.
append
(
i
.
owner
.
inputs
[
0
])
else
:
cand_inputs_assert_removed
.
append
(
i
)
# Get input list of the node with assert nodes removed
node_inputs_assert_removed
=
[]
for
i
in
node
.
inputs
:
if
i
.
owner
and
isinstance
(
i
.
owner
.
op
,
theano
.
tensor
.
opt
.
Assert
):
node_inputs_assert_removed
.
append
(
i
.
owner
.
inputs
[
0
])
else
:
node_inputs_assert_removed
.
append
(
i
)
if
merge_mode
==
"new_node"
:
inputs_match
=
True
else
:
inputs_match
=
all
(
node_in
is
cand_in
for
node_in
,
cand_in
in
zip
(
node_inputs_assert_removed
,
cand_inputs_assert_removed
))
# No need to compare the op again, as it don't change.
# No need to compare the op again, as it don't change.
if
not
inputs_match
:
if
not
inputs_match
:
continue
continue
...
...
theano/gof/tests/test_opt.py
浏览文件 @
c0f62796
...
@@ -6,6 +6,9 @@ from theano.gof.opt import * # noqa
...
@@ -6,6 +6,9 @@ from theano.gof.opt import * # noqa
from
theano.gof.fg
import
FunctionGraph
as
Env
from
theano.gof.fg
import
FunctionGraph
as
Env
from
theano.gof.toolbox
import
*
# noqa
from
theano.gof.toolbox
import
*
# noqa
from
theano.tensor.opt
import
Assert
from
theano
import
tensor
as
T
def
as_variable
(
x
):
def
as_variable
(
x
):
if
not
isinstance
(
x
,
Variable
):
if
not
isinstance
(
x
,
Variable
):
...
@@ -360,6 +363,129 @@ class TestMergeOptimizer:
...
@@ -360,6 +363,129 @@ class TestMergeOptimizer:
strg
=
str
(
g
)
strg
=
str
(
g
)
assert
strg
==
'[Op1(y, y)]'
or
strg
==
'[Op1(z, z)]'
assert
strg
==
'[Op1(y, y)]'
or
strg
==
'[Op1(z, z)]'
def
test_one_assert_merge
(
self
):
# Merge two nodes, one has assert, the other not.
x1
=
T
.
matrix
(
'x1'
)
x2
=
T
.
matrix
(
'x2'
)
e
=
T
.
dot
(
x1
,
x2
)
+
T
.
dot
(
T
.
opt
.
assert_op
(
x1
,
(
x1
>
x2
)
.
all
()),
x2
)
g
=
Env
([
x1
,
x2
],
[
e
])
MergeOptimizer
()
.
optimize
(
g
)
strg
=
theano
.
printing
.
debugprint
(
g
,
file
=
'str'
)
strref
=
'''Elemwise{add,no_inplace} [@A] '' 4
|dot [@B] '' 3
| |Assert{msg='Theano Assert failed!'} [@C] '' 2
| | |x1 [@D]
| | |All [@E] '' 1
| | |Elemwise{gt,no_inplace} [@F] '' 0
| | |x1 [@D]
| | |x2 [@G]
| |x2 [@G]
|dot [@B] '' 3
'''
assert
strg
==
strref
,
(
strg
,
strref
)
def
test_both_assert_merge_1
(
self
):
# Merge two nodes, both have assert on the same node
# with different conditions.
x1
=
T
.
matrix
(
'x1'
)
x2
=
T
.
matrix
(
'x2'
)
x3
=
T
.
matrix
(
'x3'
)
e
=
T
.
dot
(
T
.
opt
.
assert_op
(
x1
,
(
x1
>
x3
)
.
all
()),
x2
)
+
\
T
.
dot
(
T
.
opt
.
assert_op
(
x1
,
(
x1
>
x2
)
.
all
()),
x2
)
g
=
Env
([
x1
,
x2
,
x3
],
[
e
])
MergeOptimizer
()
.
optimize
(
g
)
strg
=
theano
.
printing
.
debugprint
(
g
,
file
=
'str'
)
strref1
=
'''Elemwise{add,no_inplace} [@A] '' 6
|dot [@B] '' 5
| |Assert{msg='Theano Assert failed!'} [@C] '' 4
| | |x1 [@D]
| | |All [@E] '' 3
| | | |Elemwise{gt,no_inplace} [@F] '' 1
| | | |x1 [@D]
| | | |x3 [@G]
| | |All [@H] '' 2
| | |Elemwise{gt,no_inplace} [@I] '' 0
| | |x1 [@D]
| | |x2 [@J]
| |x2 [@J]
|dot [@B] '' 5
'''
strref2
=
'''Elemwise{add,no_inplace} [@A] '' 6
|dot [@B] '' 5
| |Assert{msg='Theano Assert failed!'} [@C] '' 4
| | |x1 [@D]
| | |All [@E] '' 3
| | | |Elemwise{gt,no_inplace} [@F] '' 1
| | | |x1 [@D]
| | | |x2 [@G]
| | |All [@H] '' 2
| | |Elemwise{gt,no_inplace} [@I] '' 0
| | |x1 [@D]
| | |x3 [@J]
| |x2 [@G]
|dot [@B] '' 5
'''
# print(strg)
assert
strg
==
strref1
or
strg
==
strref2
,
(
strg
,
strref1
,
strref2
)
def
test_both_assert_merge_2
(
self
):
# Merge two nodes, both have assert on different node
x1
=
T
.
matrix
(
'x1'
)
x2
=
T
.
matrix
(
'x2'
)
x3
=
T
.
matrix
(
'x3'
)
e
=
T
.
dot
(
T
.
opt
.
assert_op
(
x1
,
(
x1
>
x3
)
.
all
()),
x2
)
+
\
T
.
dot
(
x1
,
T
.
opt
.
assert_op
(
x2
,
(
x2
>
x3
)
.
all
()))
g
=
Env
([
x1
,
x2
,
x3
],
[
e
])
MergeOptimizer
()
.
optimize
(
g
)
strg
=
theano
.
printing
.
debugprint
(
g
,
file
=
'str'
)
strref
=
'''Elemwise{add,no_inplace} [@A] '' 7
|dot [@B] '' 6
| |Assert{msg='Theano Assert failed!'} [@C] '' 5
| | |x1 [@D]
| | |All [@E] '' 3
| | |Elemwise{gt,no_inplace} [@F] '' 1
| | |x1 [@D]
| | |x3 [@G]
| |Assert{msg='Theano Assert failed!'} [@H] '' 4
| |x2 [@I]
| |All [@J] '' 2
| |Elemwise{gt,no_inplace} [@K] '' 0
| |x2 [@I]
| |x3 [@G]
|dot [@B] '' 6
'''
# print(strg)
assert
strg
==
strref
,
(
strg
,
strref
)
def
test_both_assert_merge_2_reverse
(
self
):
# Test case "test_both_assert_merge_2" but in reverse order
x1
=
T
.
matrix
(
'x1'
)
x2
=
T
.
matrix
(
'x2'
)
x3
=
T
.
matrix
(
'x3'
)
e
=
T
.
dot
(
x1
,
T
.
opt
.
assert_op
(
x2
,
(
x2
>
x3
)
.
all
()))
+
\
T
.
dot
(
T
.
opt
.
assert_op
(
x1
,
(
x1
>
x3
)
.
all
()),
x2
)
g
=
Env
([
x1
,
x2
,
x3
],
[
e
])
MergeOptimizer
()
.
optimize
(
g
)
strg
=
theano
.
printing
.
debugprint
(
g
,
file
=
'str'
)
strref
=
'''Elemwise{add,no_inplace} [@A] '' 7
|dot [@B] '' 6
| |Assert{msg='Theano Assert failed!'} [@C] '' 5
| | |x1 [@D]
| | |All [@E] '' 3
| | |Elemwise{gt,no_inplace} [@F] '' 1
| | |x1 [@D]
| | |x3 [@G]
| |Assert{msg='Theano Assert failed!'} [@H] '' 4
| |x2 [@I]
| |All [@J] '' 2
| |Elemwise{gt,no_inplace} [@K] '' 0
| |x2 [@I]
| |x3 [@G]
|dot [@B] '' 6
'''
print
(
strg
)
assert
strg
==
strref
,
(
strg
,
strref
)
class
TestEquilibrium
(
object
):
class
TestEquilibrium
(
object
):
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论