testgroup / pytensor
Commits · 2fe4b0b8

Commit 2fe4b0b8, authored Nov 11, 2013 by Pascal Lamblin

Merge pull request #1594 from nouiz/faster_opt

[MRG] Fix test python 2.4 and faster opt

Parents: f07e4644, 9c4cd7dc

Showing 9 changed files with 89 additions and 62 deletions
doc/extending/optimization.txt       +3  −1
theano/gof/fg.py                     +4  −2
theano/gof/graph.py                  +6  −8
theano/gof/opt.py                    +40 −14
theano/gof/toolbox.py                +6  −5
theano/tensor/nnet/nnet.py           +1  −1
theano/tensor/opt.py                 +19 −0
theano/tensor/opt_uncanonicalize.py  +9  −30
theano/tensor/var.py                 +1  −1
doc/extending/optimization.txt

@@ -283,7 +283,9 @@ The local version of the above code would be the following:
 The definition of transform is the inner loop of the global optimizer,
 where the node is given as argument. If no changes are to be made,
 ``False`` must be returned. Else, a list of what to replace the node's
-outputs with must be returned.
+outputs with must be returned. This list must have the same length as
+node.outputs. If one of node.outputs doesn't have clients (it is not used
+in the graph), you can put None in the returned list to remove it.

 In order to apply the local optimizer we must use it in conjunction
 with a :ref:`navigator`. Basically, a :ref:`navigator` is a global
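The return contract described in the doc change above can be sketched outside Theano. The `FakeNode`, `apply_transform`, and attribute names below are hypothetical stand-ins for illustration, not Theano API:

```python
# Toy illustration of the local-optimizer return contract: a transform
# returns False (no change) or a list the same length as node.outputs,
# where None marks an output that is unused and may be removed.
# FakeNode and these names are hypothetical, not Theano API.

class FakeNode:
    def __init__(self, outputs, clients):
        self.outputs = outputs      # output variable names
        self.clients = clients      # output name -> number of users

def apply_transform(node, replacements):
    """Validate a transform's return value per the documented contract."""
    if replacements is False:
        return None                 # nothing to do
    if len(replacements) != len(node.outputs):
        raise ValueError("wrong number of replacements")
    for out, new in zip(node.outputs, replacements):
        if new is None and node.clients.get(out, 0) > 0:
            raise ValueError("tried to remove an output that is used")
    # keep only real substitutions
    return [(out, new) for out, new in zip(node.outputs, replacements)
            if new is not None and new != out]

node = FakeNode(["max", "argmax"], {"max": 2, "argmax": 0})
pairs = apply_transform(node, ["new_max", None])   # argmax is unused
```

Returning `None` for the dead `argmax` output is what lets the rewritten graph drop the unused computation entirely.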
theano/gof/fg.py

@@ -93,6 +93,7 @@ class FunctionGraph(utils.object2):
         inputs, outputs = graph.clone(inputs, outputs)
         self.execute_callbacks_time = 0
+        self.execute_callbacks_times = {}
         if features is None:
             features = []
@@ -507,7 +508,7 @@ class FunctionGraph(utils.object2):
                 attach(self)
             except toolbox.AlreadyThere:
                 return
+        self.execute_callbacks_times.setdefault(feature, 0)
         #it would be nice if we could require a specific class instead of
         #a "workalike" so we could do actual error checking
         #if not isinstance(feature, toolbox.Feature):
@@ -549,8 +550,9 @@ class FunctionGraph(utils.object2):
                 # try; the AttributeError reall must come from feature.${name}
                 # not existing
                 continue
+            tf0 = time.time()
             fn(self, *args, **kwargs)
+            self.execute_callbacks_times[feature] += time.time() - tf0
         self.execute_callbacks_time += time.time() - t0

     def collect_callbacks(self, name, *args):
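The fg.py change above adds a per-feature timing dict next to the existing global counter. A minimal sketch of the pattern, with illustrative names (not the FunctionGraph API):

```python
import time

# A sketch of per-feature callback timing: besides a single global
# execute_callbacks_time, keep a dict mapping each feature to the time
# spent in its callbacks, so profiles can name the slow feature.

class TimedCallbacks:
    def __init__(self):
        self.execute_callbacks_time = 0.0
        self.execute_callbacks_times = {}   # feature -> seconds

    def attach_feature(self, feature):
        # seed the entry so += below never hits a missing key
        self.execute_callbacks_times.setdefault(feature, 0.0)

    def execute_callbacks(self, features, *args):
        t0 = time.time()
        for feature in features:
            tf0 = time.time()
            feature(*args)                  # the callback itself
            self.execute_callbacks_times[feature] += time.time() - tf0
        self.execute_callbacks_time += time.time() - t0
```

The `setdefault` at attach time mirrors the commit: it guarantees every attached feature has an entry even if its callbacks never fire.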
theano/gof/graph.py

@@ -495,14 +495,14 @@ def stack_search(start, expand, mode='bfs', build_inv=False):
     :param start: search from these nodes
     :type expand: callable
     :param expand:
-        when we get to a node, add expand(node) to the list of nodes to visit.  This function
-        should return a list, or None
+        when we get to a node, add expand(node) to the list of nodes to visit.
+        This function should return a list, or None
     :rtype: list of `Variable` or `Apply` instances (depends on `expend`)
     :return: the list of nodes in order of traversal.
     :note:
-        a node will appear at most once in the return value, even if it appears multiple times
-        in the start parameter.
+        a node will appear at most once in the return value, even if it
+        appears multiple times in the start parameter.
     :postcondition: every element of start is transferred to the returned list.
     :postcondition: start is empty.
@@ -549,9 +549,7 @@ def ancestors(variable_list, blockers=None):
     """
     def expand(r):
         if r.owner and (not blockers or r not in blockers):
-            l = list(r.owner.inputs)
-            l.reverse()
-            return l
+            return reversed(r.owner.inputs)
     dfs_variables = stack_search(deque(variable_list), expand, 'dfs')
     return dfs_variables
@@ -801,7 +799,7 @@ def io_toposort(inputs, outputs, orderings=None):
         if isinstance(obj, Variable):
             if obj.owner:
                 rval = [obj.owner]
-        if isinstance(obj, Apply):
+        elif isinstance(obj, Apply):
             rval = list(obj.inputs)
             rval.extend(orderings.get(obj, []))
         else:
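The `ancestors` change above swaps a copy-and-reverse for a lazy `reversed()` view, which works because the search only needs something iterable to push onto its stack. A self-contained sketch of that DFS shape (the toy graph and names are illustrative, not Theano's):

```python
from collections import deque

# Sketch of an ancestors()-style DFS: expand() may return any iterable,
# so returning reversed(inputs) avoids allocating and reversing a copy.

def stack_search(start, expand):
    """Depth-first walk that yields each reachable node at most once."""
    rval = []
    seen = set()
    while start:
        node = start.pop()              # pop from the right end
        if node in seen:
            continue
        seen.add(node)
        rval.append(node)
        extension = expand(node)
        if extension is not None:
            start.extend(extension)     # accepts any iterable
    return rval

# toy dependency graph: node -> inputs
graph = {"z": ["x", "y"], "x": ["a"], "y": ["a"], "a": []}
result = stack_search(deque(["z"]), lambda n: reversed(graph[n]))
```

Pushing the inputs reversed makes the right-end `pop` visit them in their original left-to-right order, preserving the traversal order of the explicit-copy version.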
theano/gof/opt.py

@@ -514,7 +514,8 @@ class MergeFeature(object):
                 continue
             inputs_match = all(node_in is cand_in
                                for node_in, cand_in in zip(node.inputs, candidate.inputs))
             if inputs_match and node.op == candidate.op:
                 if (node, candidate) in self.blacklist:
                     # They were already tried, and there was an error
@@ -566,6 +567,8 @@ class MergeOptimizer(Optimizer):
         if fgraph.profile:
             validate_before = fgraph.profile.validate_time
             callback_before = fgraph.execute_callbacks_time
+            callbacks_before = fgraph.execute_callbacks_times.copy()
+
         nb_merged = 0
         nb_constant = 0
         while sched:
@@ -589,20 +592,28 @@ class MergeOptimizer(Optimizer):
         if fgraph.profile:
             validate_time = fgraph.profile.validate_time - validate_before
             callback_time = fgraph.execute_callbacks_time - callback_before
+            callbacks_time = {}
+            for k, v in fgraph.execute_callbacks_times.iteritems():
+                if k in callbacks_before:
+                    callbacks_time[k] = v - callbacks_before[k]
+                else:
+                    callbacks_time[k] = v
         else:
             validate_time = None
             callback_time = None
+            callbacks_time = {}
         # clear blacklist
         fgraph.merge_feature.blacklist = []
         return (nb_fail, time.time() - t0, validate_time,
-                callback_time, nb_merged, nb_constant)
+                callback_time, callbacks_time, nb_merged, nb_constant)

     def __str__(self):
         return self.__class__.__name__

     @staticmethod
     def print_profile(stream, prof, level=0):
-        nb_fail, replace_time, validate_time, callback_time, nb_merged, nb_constant = prof
+        (nb_fail, replace_time, validate_time,
+         callback_time, callbacks_time, nb_merged, nb_constant) = prof
         blanc = ('    ' * level)
         print >> stream, blanc, "MergeOptimizer"
@@ -610,6 +621,7 @@ class MergeOptimizer(Optimizer):
         print >> stream, blanc, "  replace_time", replace_time
         print >> stream, blanc, "  validate_time", validate_time
         print >> stream, blanc, "  callback_time", callback_time
+        print >> stream, blanc, "  callback_times", callbacks_time
         print >> stream, blanc, "  nb_merged", nb_merged
         print >> stream, blanc, "  nb_constant", nb_constant
@@ -800,8 +812,8 @@ class LocalOptGroup(LocalOptimizer):
     def __str__(self):
         return getattr(self, '__name__',
                        ('<theano.gof.opt.LocalOptGroup instance>'
                         + str([str(o) for o in self.opts])))

     def transform(self, node):
         for opt in self.opts:
@@ -979,9 +991,9 @@ class PatternSub(LocalOptimizer):
         else:
             raise TypeError("The pattern to search for must start with "
                             "a specific Op instance.")
         self.__doc__ = (self.__class__.__doc__
                         + "\n\nThis instance does: "
                         + str(self) + "\n")
         self.allow_multiple_clients = allow_multiple_clients
         self.skip_identities_fn = skip_identities_fn
         if name:
@@ -1275,7 +1287,8 @@ class NavigatorOptimizer(Optimizer):
         except Exception, e:
             if self.failure_callback is not None:
                 self.failure_callback(e, self,
                                       [(x, None) for x in node.outputs],
                                       lopt)
                 return False
             else:
                 raise
@@ -1287,10 +1300,16 @@ class NavigatorOptimizer(Optimizer):
         if len(node.outputs) != len(replacements):
             raise ValueError('Optimizer %s gave wrong number of replacements'
                              % lopt)
+        # None in the replacement mean that this variable isn't used
+        # and we want to remove it
+        for r, rnew in zip(node.outputs, replacements):
+            if rnew is None and len(r.clients) > 0:
+                raise ValueError(
+                    "A local optimizer tried to remove a Variable that is used")
         # If an output would be replaced by itself, no need to perform
         # the replacement
         repl_pairs = [(r, rnew) for r, rnew in zip(node.outputs, replacements)
-                      if rnew is not r]
+                      if rnew is not r and rnew is not None]
         if len(repl_pairs) == 0:
             return False
         try:
@@ -1513,6 +1532,7 @@ class EquilibriumOptimizer(NavigatorOptimizer):
         max_use_abort = False
         opt_name = None
         global_process_count = {}
+        start_nb_nodes = len(fgraph.apply_nodes)
         max_nb_nodes = len(fgraph.apply_nodes)
         max_use = max_nb_nodes * self.max_use_ratio
@@ -1597,13 +1617,16 @@ class EquilibriumOptimizer(NavigatorOptimizer):
             loop_process_count.append(process_count)
             loop_timing.append(float(time.time() - t0))

+        end_nb_nodes = len(fgraph.apply_nodes)
+
         if max_use_abort:
             _logger.error("EquilibriumOptimizer max'ed out by '%s'" % opt_name
                           + ". You can safely raise the current threshold of "
                           + "%f with the theano flag 'optdb.max_use_ratio'." %
                           config.optdb.max_use_ratio)
-        return (self, loop_timing, loop_process_count, max_nb_nodes,
+        return (self, loop_timing, loop_process_count,
+                (start_nb_nodes, end_nb_nodes, max_nb_nodes),
                 global_opt_timing, nb_nodes, time_opts, io_toposort_timing)

     def print_summary(self, stream=sys.stdout, level=0, depth=-1):
@@ -1617,15 +1640,18 @@ class EquilibriumOptimizer(NavigatorOptimizer):
     @staticmethod
     def print_profile(stream, prof, level=0):
-        (opt, loop_timing, loop_process_count, max_nb_nodes,
+        (opt, loop_timing, loop_process_count,
+         (start_nb_nodes, end_nb_nodes, max_nb_nodes),
          global_opt_timing, nb_nodes, time_opts, io_toposort_timing) = prof
         blanc = ('    ' * level)

         print >> stream, blanc, "EquilibriumOptimizer",
         print >> stream, blanc, getattr(opt, "name",
                                         getattr(opt, "__name__", ""))
-        print >> stream, blanc, "  time %.3fs for %d passes, %d nodes max" % (
-            sum(loop_timing), len(loop_timing), max_nb_nodes)
+        print >> stream, blanc, "  time %.3fs for %d passes" % (
+            sum(loop_timing), len(loop_timing))
+        print >> stream, blanc, "  nb nodes (start, end, max) %d %d %d" % (
+            start_nb_nodes, end_nb_nodes, max_nb_nodes)
         print >> stream, blanc, "  time io_toposort %.3fs" % sum(
             io_toposort_timing)
         s = sum([time_opts[o] for o in opt.local_optimizers])
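The MergeOptimizer profiling change snapshots the per-feature times before a pass and reports only the delta afterwards. That bookkeeping can be sketched with plain dicts standing in for the fgraph state (names illustrative):

```python
# Sketch of the per-callback delta computation: given a snapshot of
# cumulative per-feature times taken before a pass, report only the
# time spent during the pass. Features attached mid-pass have no
# snapshot entry, so their full cumulative time is the delta.

def callback_deltas(before, after):
    """Time spent per feature during one pass (cumulative minus snapshot)."""
    deltas = {}
    for k, v in after.items():
        if k in before:
            deltas[k] = v - before[k]
        else:
            deltas[k] = v               # feature attached mid-pass
    return deltas

before = {"shape_feature": 1.5}
after = {"shape_feature": 2.0, "merge_feature": 0.25}
per_pass = callback_deltas(before, after)
```

Copying the snapshot (`execute_callbacks_times.copy()` in the diff) matters: without it, later `+=` updates would mutate the "before" dict and every delta would read zero.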
theano/gof/toolbox.py

@@ -248,10 +248,11 @@ class ReplaceValidate(History, Validator):
             raise ReplacementDidntRemovedError()


-class NodeFinder(dict, Bookkeeper):
+class NodeFinder(Bookkeeper):

     def __init__(self):
         self.fgraph = None
+        self.d = {}

     def on_attach(self, fgraph):
         if self.fgraph is not None:
@@ -273,7 +274,7 @@ class NodeFinder(Bookkeeper):
     def on_import(self, fgraph, node, reason):
         try:
-            self.setdefault(node.op, []).append(node)
+            self.d.setdefault(node.op, []).append(node)
         except TypeError:  # node.op is unhashable
             return
         except Exception, e:
@@ -286,16 +287,16 @@ class NodeFinder(Bookkeeper):
     def on_prune(self, fgraph, node, reason):
         try:
-            nodes = self[node.op]
+            nodes = self.d[node.op]
         except TypeError:  # node.op is unhashable
             return
         nodes.remove(node)
         if not nodes:
-            del self[node.op]
+            del self.d[node.op]

     def query(self, fgraph, op):
         try:
-            all = self.get(op, [])
+            all = self.d.get(op, [])
         except TypeError:
             raise TypeError("%s in unhashable and cannot be queried by the"
                             " optimizer" % op)
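The toolbox.py change moves NodeFinder from subclassing `dict` to holding an internal dict (`self.d`). A minimal sketch of the resulting shape, with plain strings standing in for ops and nodes (illustrative, not the Theano feature API):

```python
# Sketch of NodeFinder after the change: the op -> nodes index lives in
# an internal dict rather than the feature itself being a dict, keeping
# the feature's own attributes separate from the index it maintains.

class NodeFinder:
    def __init__(self):
        self.d = {}                     # op -> list of nodes using it

    def on_import(self, op, node):
        try:
            self.d.setdefault(op, []).append(node)
        except TypeError:               # unhashable op: skip indexing
            return

    def on_prune(self, op, node):
        nodes = self.d[op]
        nodes.remove(node)
        if not nodes:
            del self.d[op]              # drop empty buckets

finder = NodeFinder()
finder.on_import("add", "n1")
finder.on_import("add", "n2")
finder.on_prune("add", "n1")
```

Composition also keeps `isinstance(feature, dict)` checks and dict method lookups from leaking into code that only meant to treat the object as a graph feature.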
theano/tensor/nnet/nnet.py

@@ -353,7 +353,7 @@ class Softmax(gof.Op):
         x = tensor.as_tensor_variable(x)
         if x.type.ndim not in (1, 2) \
                 or x.type.dtype not in tensor.float_dtypes:
-            raise ValueError('x must be 1-d or 2-d tensor of floats')
+            raise ValueError('x must be 1-d or 2-d tensor of floats. Got ',
+                             x.type)
         if x.ndim == 1:
             x = tensor.shape_padleft(x, n_ones=1)
         return Apply(self, [x], [x.type()])
theano/tensor/opt.py

@@ -915,6 +915,13 @@ class ShapeFeature(object):
             # If no info is known on r's shape, use other_shape
             self.set_shape(r, other_shape)
             return
+        if (other_r.owner and r.owner and
+                other_r.owner.inputs == r.owner.inputs and
+                other_r.owner.op == r.owner.op):
+            # We are doing a merge. So the 2 shapes graph will be the
+            # same. This is only a speed optimization to call
+            # ancestors() less frequently.
+            return
         # Merge other_shape with r_shape, giving the priority to other_shape
         merged_shape = []
@@ -928,6 +935,18 @@ class ShapeFeature(object):
             # - Shape_i(i)(other_r);
             # - Shape_i(i)(r).
             merged_shape.append(r_shape[i])
+        elif isinstance(r_shape[i], (Constant, int)):
+            # We do this to call less often ancestors and make
+            # sure we have the simplest shape possible.
+            merged_shape.append(r_shape[i])
+        elif isinstance(other_shape[i], (Constant, int)):
+            # We do this to call less often ancestors and make
+            # sure we have the simplest shape possible.
+            merged_shape.append(other_shape[i])
+        elif other_shape[i] == r_shape[i]:
+            # This mean the shape is equivalent
+            # We do not want to do the ancestor check in those cases
+            merged_shape.append(r_shape[i])
         elif r_shape[i] in theano.gof.graph.ancestors([other_shape[i]]):
             # Another case where we want to use r_shape[i] is when
             # other_shape[i] actually depends on r_shape[i]. In that case,
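The ShapeFeature change above prefers already-constant shape entries and equal entries so the expensive `ancestors()` check runs less often. The merge preference can be sketched with ints standing in for constant dimensions and strings for symbolic ones (a toy, not the ShapeFeature API):

```python
# Sketch of the shape-merge preference: keep a known constant dimension
# when either side has one, keep equal symbolic entries as-is, and only
# fall back to the "other" shape when nothing cheaper applies (the real
# code then also consults the graph's ancestors, omitted here).

def merge_shapes(r_shape, other_shape):
    merged = []
    for rs, os in zip(r_shape, other_shape):
        if isinstance(rs, int):         # constant on r's side: keep it
            merged.append(rs)
        elif isinstance(os, int):       # constant on the other side
            merged.append(os)
        elif rs == os:                  # equivalent symbolic entries
            merged.append(rs)
        else:                           # default priority: other_shape
            merged.append(os)
    return merged

shape = merge_shapes([3, "n", "m"], ["s0", 7, "m"])
```

Checking constants and equality first is purely an ordering of cheap tests before an expensive graph traversal; the merged result is the same, it is just computed faster.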
theano/tensor/opt_uncanonicalize.py

@@ -26,54 +26,33 @@ import logging
 _logger = logging.getLogger('theano.tensor.opt')

 from theano import gof
-from theano.compat.python2x import deque
 from theano.tensor.elemwise import CAReduce
 from theano.tensor import basic as T
-from theano.gof.opt import Optimizer
-from theano.gof import InconsistencyError, toolbox
 from theano.tensor.basic import (get_scalar_constant_value,
                                  NotScalarConstantError)
 from theano.tensor.opt import register_uncanonicalize
 from theano import scalar as scal


-class MaxAndArgmaxOptimizer(Optimizer):
-    """Replace MaxAndArgmax by CAReduce when the argmax is not used
-
-    This is faster as MaxAndArgmax don't have c code and execute it
-    in two pass.
+@register_uncanonicalize
+@gof.local_optimizer([T._max_and_argmax])
+def local_max_and_argmax(node):
+    """
+    If we don't use the argmax, change it to a max only.
     """
-    def add_requirements(self, fgraph):
-        fgraph.attach_feature(toolbox.ReplaceValidate())
-
-    def apply(self, fgraph):
-        did_something = True
-        while did_something:
-            nodelist = fgraph.toposort()
-            did_something = False
-            for node in nodelist:
-                if node.op == T._max_and_argmax:
-                    if len(node.outputs[1].clients) == 0:
-                        try:
-                            axis = get_scalar_constant_value(node.inputs[1])
-                        except NotScalarConstantError:
-                            return False
-                        new = CAReduce(scal.maximum, axis)(node.inputs[0])
-                        try:
-                            fgraph.replace_all_validate(
-                                ((node.outputs[0], new),),
-                                reason=self.__class__.__name__)
-                            did_something = True
-                            break
-                        except InconsistencyError, e:
-                            pass
-
-register_uncanonicalize(MaxAndArgmaxOptimizer(),
-                        name='MaxAndArgmaxOptimizer')
+    if node.op == T._max_and_argmax:
+        if len(node.outputs[1].clients) == 0:
+            #MaxAndArgmax support variable axis,
+            #but CAReduce support only constant axis.
+            try:
+                axis = get_scalar_constant_value(node.inputs[1])
+            except NotScalarConstantError:
+                return False
+            new = CAReduce(scal.maximum, axis)(node.inputs[0])
+            return [new, None]


 @register_uncanonicalize
 @gof.local_optimizer([T._shape])
theano/tensor/var.py

@@ -3,7 +3,7 @@ import copy
 import numpy

 import theano
-from theano.compat import PY3
+from theano.compat import all, PY3
 from theano.scalar import ComplexError, IntegerDivisionError
 from theano.gof import Constant, Variable
 from theano.gof.utils import hashtype