Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
e8aff982
提交
e8aff982
authored
1月 29, 2014
作者:
Frederic
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Faster opt by not doing useless stuff.
This make the Canonizer take 4x less time in a test case.
上级
2b7b0305
隐藏空白字符变更
内嵌
并排
正在显示
1 个修改的文件
包含
15 行增加
和
13 行删除
+15
-13
opt.py
theano/tensor/opt.py
+15
-13
没有找到文件。
theano/tensor/opt.py
浏览文件 @
e8aff982
...
@@ -2867,12 +2867,13 @@ class Canonizer(gof.LocalOptimizer):
...
@@ -2867,12 +2867,13 @@ class Canonizer(gof.LocalOptimizer):
else
:
else
:
return
v
return
v
def
simplify
(
self
,
num
,
denum
):
def
simplify
(
self
,
num
,
denum
,
out_type
):
"""
"""
Shorthand for:
Shorthand for:
self.simplify_constants(*self.simplify_factors(num, denum))
self.simplify_constants(*self.simplify_factors(num, denum))
"""
"""
rval
=
self
.
simplify_constants
(
*
self
.
simplify_factors
(
num
,
denum
))
rval
=
self
.
simplify_constants
(
*
self
.
simplify_factors
(
num
,
denum
),
out_type
=
out_type
)
for
reason
,
simplifier
in
self
.
external_simplifiers
:
for
reason
,
simplifier
in
self
.
external_simplifiers
:
# TODO: document that 'reason' is associated with this
# TODO: document that 'reason' is associated with this
# simplification to help auditing when things go
# simplification to help auditing when things go
...
@@ -2896,7 +2897,7 @@ class Canonizer(gof.LocalOptimizer):
...
@@ -2896,7 +2897,7 @@ class Canonizer(gof.LocalOptimizer):
denum
.
remove
(
v
)
denum
.
remove
(
v
)
return
num
,
denum
return
num
,
denum
def
simplify_constants
(
self
,
orig_num
,
orig_denum
):
def
simplify_constants
(
self
,
orig_num
,
orig_denum
,
out_type
=
None
):
"""
"""
Finds all constants in orig_num and orig_denum (using
Finds all constants in orig_num and orig_denum (using
...
@@ -2914,7 +2915,6 @@ class Canonizer(gof.LocalOptimizer):
...
@@ -2914,7 +2915,6 @@ class Canonizer(gof.LocalOptimizer):
# Lists representing the numerator and denumerator
# Lists representing the numerator and denumerator
num
,
denum
=
list
(
orig_num
),
list
(
orig_denum
)
num
,
denum
=
list
(
orig_num
),
list
(
orig_denum
)
out_type
=
self
.
merge_num_denum
(
orig_num
,
orig_denum
)
.
type
# Lists representing the *constant* elements of num and denum
# Lists representing the *constant* elements of num and denum
numct
,
denumct
=
[],
[]
numct
,
denumct
=
[],
[]
...
@@ -3001,7 +3001,7 @@ class Canonizer(gof.LocalOptimizer):
...
@@ -3001,7 +3001,7 @@ class Canonizer(gof.LocalOptimizer):
# Here we make the canonical version of the graph around this node
# Here we make the canonical version of the graph around this node
# See the documentation of get_num_denum and simplify
# See the documentation of get_num_denum and simplify
orig_num
,
orig_denum
=
self
.
get_num_denum
(
node
.
outputs
[
0
])
orig_num
,
orig_denum
=
self
.
get_num_denum
(
node
.
outputs
[
0
])
num
,
denum
=
self
.
simplify
(
list
(
orig_num
),
list
(
orig_denum
))
num
,
denum
=
self
.
simplify
(
list
(
orig_num
),
list
(
orig_denum
)
,
out
.
type
)
def
same
(
x
,
y
):
def
same
(
x
,
y
):
return
len
(
x
)
==
len
(
y
)
and
all
(
N
.
all
(
xe
==
ye
)
for
xe
,
ye
in
return
len
(
x
)
==
len
(
y
)
and
all
(
N
.
all
(
xe
==
ye
)
for
xe
,
ye
in
...
@@ -3873,7 +3873,8 @@ register_canonicalize(local_add_canonizer, name='local_add_canonizer')
...
@@ -3873,7 +3873,8 @@ register_canonicalize(local_add_canonizer, name='local_add_canonizer')
##################
##################
def
distribute_greedy
(
pos_pairs
,
neg_pairs
,
num
,
denum
,
minscore
=
0
):
def
distribute_greedy
(
pos_pairs
,
neg_pairs
,
num
,
denum
,
out_type
,
minscore
=
0
):
# each pair in pos_pairs and neg_pairs is a num/denum pair. this
# each pair in pos_pairs and neg_pairs is a num/denum pair. this
# function attempts to add num and denum to the corresponding parts
# function attempts to add num and denum to the corresponding parts
# of each pair, and counts how many multiplications/divisions can
# of each pair, and counts how many multiplications/divisions can
...
@@ -3889,10 +3890,10 @@ def distribute_greedy(pos_pairs, neg_pairs, num, denum, minscore=0):
...
@@ -3889,10 +3890,10 @@ def distribute_greedy(pos_pairs, neg_pairs, num, denum, minscore=0):
# score is number of operations saved, higher is better
# score is number of operations saved, higher is better
score
=
len
(
num
)
+
div_cost
*
len
(
denum
)
score
=
len
(
num
)
+
div_cost
*
len
(
denum
)
new_pos_pairs
=
list
(
itertools
.
starmap
(
local_mul_canonizer
.
simplify
,
new_pos_pairs
=
list
(
itertools
.
starmap
(
local_mul_canonizer
.
simplify
,
[(
n
+
num
,
d
+
denum
)
for
(
n
,
d
)
[(
n
+
num
,
d
+
denum
,
out_type
)
for
(
n
,
d
)
in
pos_pairs
]))
in
pos_pairs
]))
new_neg_pairs
=
list
(
itertools
.
starmap
(
local_mul_canonizer
.
simplify
,
new_neg_pairs
=
list
(
itertools
.
starmap
(
local_mul_canonizer
.
simplify
,
[(
n
+
num
,
d
+
denum
)
for
(
n
,
d
)
[(
n
+
num
,
d
+
denum
,
out_type
)
for
(
n
,
d
)
in
neg_pairs
]))
in
neg_pairs
]))
for
(
n
,
d
),
(
nn
,
dd
)
in
zip
(
pos_pairs
+
neg_pairs
,
new_pos_pairs
+
for
(
n
,
d
),
(
nn
,
dd
)
in
zip
(
pos_pairs
+
neg_pairs
,
new_pos_pairs
+
new_neg_pairs
):
new_neg_pairs
):
...
@@ -3905,7 +3906,7 @@ def distribute_greedy(pos_pairs, neg_pairs, num, denum, minscore=0):
...
@@ -3905,7 +3906,7 @@ def distribute_greedy(pos_pairs, neg_pairs, num, denum, minscore=0):
return
True
,
new_pos_pairs
,
new_neg_pairs
return
True
,
new_pos_pairs
,
new_neg_pairs
def
attempt_distribution
(
factor
,
num
,
denum
):
def
attempt_distribution
(
factor
,
num
,
denum
,
out_type
):
# we try to insert each num and each denum in the factor
# we try to insert each num and each denum in the factor
# returns: changes?, new_factor, new_num, new_denum
# returns: changes?, new_factor, new_num, new_denum
# if there are changes, new_num and new_denum contain all the numerators
# if there are changes, new_num and new_denum contain all the numerators
...
@@ -3918,13 +3919,13 @@ def attempt_distribution(factor, num, denum):
...
@@ -3918,13 +3919,13 @@ def attempt_distribution(factor, num, denum):
change
=
False
change
=
False
for
n
in
list
(
num
):
for
n
in
list
(
num
):
success
,
pos_pairs
,
neg_pairs
=
distribute_greedy
(
pos_pairs
,
success
,
pos_pairs
,
neg_pairs
=
distribute_greedy
(
pos_pairs
,
neg_pairs
,
[
n
],
[])
neg_pairs
,
[
n
],
[]
,
out_type
)
if
success
:
if
success
:
change
=
True
change
=
True
num
.
remove
(
n
)
num
.
remove
(
n
)
for
d
in
list
(
denum
):
for
d
in
list
(
denum
):
success
,
pos_pairs
,
neg_pairs
=
distribute_greedy
(
pos_pairs
,
success
,
pos_pairs
,
neg_pairs
=
distribute_greedy
(
pos_pairs
,
neg_pairs
,
[],
[
d
])
neg_pairs
,
[],
[
d
]
,
out_type
)
if
success
:
if
success
:
change
=
True
change
=
True
denum
.
remove
(
d
)
denum
.
remove
(
d
)
...
@@ -3969,12 +3970,13 @@ def local_greedy_distributor(node):
...
@@ -3969,12 +3970,13 @@ def local_greedy_distributor(node):
change
=
False
change
=
False
out_type
=
out
.
type
for
candidate
in
list
(
num
):
for
candidate
in
list
(
num
):
if
candidate
not
in
num
:
if
candidate
not
in
num
:
continue
continue
num
.
remove
(
candidate
)
num
.
remove
(
candidate
)
_change
,
candidate
,
num
,
denum
=
attempt_distribution
(
candidate
,
_change
,
candidate
,
num
,
denum
=
attempt_distribution
(
candidate
,
num
,
denum
)
num
,
denum
,
out_type
)
change
|=
_change
change
|=
_change
new_num
.
append
(
candidate
)
new_num
.
append
(
candidate
)
...
@@ -3983,7 +3985,7 @@ def local_greedy_distributor(node):
...
@@ -3983,7 +3985,7 @@ def local_greedy_distributor(node):
continue
continue
denum
.
remove
(
candidate
)
denum
.
remove
(
candidate
)
_change
,
candidate
,
denum
,
num
=
attempt_distribution
(
candidate
,
_change
,
candidate
,
denum
,
num
=
attempt_distribution
(
candidate
,
denum
,
num
)
denum
,
num
,
out_type
)
change
|=
_change
change
|=
_change
new_denum
.
append
(
candidate
)
new_denum
.
append
(
candidate
)
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论