Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
a351e3db
提交
a351e3db
authored
9月 03, 2013
作者:
Pascal Lamblin
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #1482 from nouiz/rnade
Scan crash fix
上级
e523802f
66a9e220
隐藏空白字符变更
内嵌
并排
正在显示
20 个修改的文件
包含
403 行增加
和
239 行删除
+403
-239
config.txt
doc/library/config.txt
+14
-0
debugmode.py
theano/compile/debugmode.py
+9
-5
function_module.py
theano/compile/function_module.py
+1
-1
configdefaults.py
theano/configdefaults.py
+13
-1
destroyhandler.py
theano/gof/destroyhandler.py
+4
-4
fg.py
theano/gof/fg.py
+94
-79
op.py
theano/gof/op.py
+9
-2
opt.py
theano/gof/opt.py
+8
-8
toolbox.py
theano/gof/toolbox.py
+18
-14
basic_ops.py
theano/sandbox/cuda/basic_ops.py
+0
-1
ops.py
theano/sandbox/linalg/ops.py
+2
-2
scan_utils.py
theano/sandbox/scan_module/scan_utils.py
+1
-1
basic.py
theano/scalar/basic.py
+18
-0
scan.py
theano/scan_module/scan.py
+3
-3
scan_op.py
theano/scan_module/scan_op.py
+24
-2
scan_opt.py
theano/scan_module/scan_opt.py
+88
-69
scan_utils.py
theano/scan_module/scan_utils.py
+1
-1
test_scan.py
theano/scan_module/tests/test_scan.py
+44
-25
basic.py
theano/tensor/basic.py
+2
-2
opt.py
theano/tensor/opt.py
+50
-19
没有找到文件。
doc/library/config.txt
浏览文件 @
a351e3db
...
...
@@ -482,6 +482,14 @@ import theano and print the config variable, as in:
This flag's value cannot be modified during the program execution.
.. attribute:: optimizer_verbose
Bool value: either True or False
Default: False
When True, each optimization that is applied is printed to stdout.
.. attribute:: nocleanup
Bool value: either True or False
...
...
@@ -630,6 +638,12 @@ import theano and print the config variable, as in:
this Op
- ``'raise'`` will raise an Exception
.. attribute:: config.compute_test_value_opt
Same as ``compute_test_value``, but it is the value used during Theano's
optimization phase. Theano users do not need to use this. It exists
to help debug shape errors in Theano optimizations.
.. attribute:: config.exception_verbosity
String Value: ``'low'``, ``'high'``.
...
...
theano/compile/debugmode.py
浏览文件 @
a351e3db
...
...
@@ -1428,21 +1428,25 @@ class _VariableEquivalenceTracker(object):
self
.
reasons
=
{}
self
.
replaced_by
=
{}
self
.
event_list
=
[]
for
node
in
fgraph
.
toposort
():
self
.
on_import
(
fgraph
,
node
,
"on_attach"
)
def
on_detach
(
self
,
fgraph
):
assert
fgraph
is
self
.
fgraph
self
.
fgraph
=
None
def
on_prune
(
self
,
fgraph
,
node
):
self
.
event_list
.
append
(
_FunctionGraphEvent
(
'prune'
,
node
))
def
on_prune
(
self
,
fgraph
,
node
,
reason
):
self
.
event_list
.
append
(
_FunctionGraphEvent
(
'prune'
,
node
,
reason
=
reason
))
#print 'PRUNING NODE', node, id(node)
assert
node
in
self
.
active_nodes
assert
node
not
in
self
.
inactive_nodes
self
.
active_nodes
.
remove
(
node
)
self
.
inactive_nodes
.
add
(
node
)
def
on_import
(
self
,
fgraph
,
node
):
self
.
event_list
.
append
(
_FunctionGraphEvent
(
'import'
,
node
))
def
on_import
(
self
,
fgraph
,
node
,
reason
):
self
.
event_list
.
append
(
_FunctionGraphEvent
(
'import'
,
node
,
reason
=
reason
))
#print 'NEW NODE', node, id(node)
assert
node
not
in
self
.
active_nodes
...
...
@@ -2114,7 +2118,7 @@ class _Maker(FunctionMaker): # inheritance buys a few helper functions
# optimize the fgraph
compute_test_value_orig
=
theano
.
config
.
compute_test_value
try
:
theano
.
config
.
compute_test_value
=
"off"
theano
.
config
.
compute_test_value
=
theano
.
config
.
compute_test_value_opt
optimizer
(
fgraph
)
theano
.
compile
.
function_module
.
insert_deepcopy
(
fgraph
,
inputs
,
...
...
theano/compile/function_module.py
浏览文件 @
a351e3db
...
...
@@ -1018,7 +1018,7 @@ class FunctionMaker(object):
compute_test_value_orig
=
theano
.
config
.
compute_test_value
add_stack_trace_on_call
=
gof
.
Op
.
add_stack_trace_on_call
try
:
theano
.
config
.
compute_test_value
=
"off"
theano
.
config
.
compute_test_value
=
theano
.
config
.
compute_test_value_opt
gof
.
Op
.
add_stack_trace_on_call
=
False
start_optimizer
=
time
.
time
()
optimizer_profile
=
optimizer
(
fgraph
)
...
...
theano/configdefaults.py
浏览文件 @
a351e3db
...
...
@@ -157,6 +157,11 @@ AddConfigVar('optimizer',
EnumStr
(
'fast_run'
,
'merge'
,
'fast_compile'
,
'None'
),
in_c_key
=
False
)
AddConfigVar
(
'optimizer_verbose'
,
"If True, we print all optimization being applied"
,
BoolParam
(
False
),
in_c_key
=
False
)
AddConfigVar
(
'on_opt_error'
,
(
"What to do when an optimization crashes: warn and skip it, raise "
"the exception, or fall into the pdb debugger."
),
...
...
@@ -379,10 +384,17 @@ AddConfigVar('compute_test_value',
"Constants, SharedVariables and the tag 'test_value' as inputs "
"to the function. This helps the user track down problems in the "
"graph before it gets optimized."
),
EnumStr
(
'off'
,
'ignore'
,
'warn'
,
'raise'
),
EnumStr
(
'off'
,
'ignore'
,
'warn'
,
'raise'
,
'pdb'
),
in_c_key
=
False
)
AddConfigVar
(
'compute_test_value_opt'
,
(
"For debugging Theano optimization only."
" Same as compute_test_value, but is used"
" during Theano optimization"
),
EnumStr
(
'off'
,
'ignore'
,
'warn'
,
'raise'
,
'pdb'
),
in_c_key
=
False
)
"""Note to developers:
Generally your exceptions should use an apply node's __str__
method when exception_verbosity == 'low'. When exception_verbosity
...
...
theano/gof/destroyhandler.py
浏览文件 @
a351e3db
...
...
@@ -380,7 +380,7 @@ if 0:
delattr
(
self
.
fgraph
,
'destroy_handler'
)
self
.
fgraph
=
None
def
on_import
(
self
,
fgraph
,
app
):
def
on_import
(
self
,
fgraph
,
app
,
reason
):
"""Add Apply instance to set which must be computed"""
#if app in self.debug_all_apps: raise ProtocolError("double import")
...
...
@@ -410,7 +410,7 @@ if 0:
self
.
stale_droot
=
True
def
on_prune
(
self
,
fgraph
,
app
):
def
on_prune
(
self
,
fgraph
,
app
,
reason
):
"""Remove Apply instance from set which must be computed"""
#if app not in self.debug_all_apps: raise ProtocolError("prune without import")
#self.debug_all_apps.remove(app)
...
...
@@ -765,7 +765,7 @@ class DestroyHandler(toolbox.Bookkeeper):
delattr
(
self
.
fgraph
,
'destroy_handler'
)
self
.
fgraph
=
None
def
on_import
(
self
,
fgraph
,
app
):
def
on_import
(
self
,
fgraph
,
app
,
reason
):
"""Add Apply instance to set which must be computed"""
if
app
in
self
.
debug_all_apps
:
raise
ProtocolError
(
"double import"
)
...
...
@@ -795,7 +795,7 @@ class DestroyHandler(toolbox.Bookkeeper):
self
.
stale_droot
=
True
def
on_prune
(
self
,
fgraph
,
app
):
def
on_prune
(
self
,
fgraph
,
app
,
reason
):
"""Remove Apply instance from set which must be computed"""
if
app
not
in
self
.
debug_all_apps
:
raise
ProtocolError
(
"prune without import"
)
self
.
debug_all_apps
.
remove
(
app
)
...
...
theano/gof/fg.py
浏览文件 @
a351e3db
...
...
@@ -5,6 +5,7 @@ Contains the FunctionGraph class and exception
types that it can raise
"""
import
sys
from
theano.gof
import
graph
from
theano.gof
import
utils
from
theano.gof
import
toolbox
...
...
@@ -16,6 +17,7 @@ NullType = None
from
theano.gof.python25
import
OrderedDict
from
theano.misc.ordered_set
import
OrderedSet
class
InconsistencyError
(
Exception
):
"""
This exception should be thrown by listeners to FunctionGraph when the
...
...
@@ -82,7 +84,8 @@ class FunctionGraph(utils.object2):
# so I probably am) this should be a set.
self
.
_features
=
[]
# All apply nodes in the subgraph defined by inputs and outputs are cached in this field
# All apply nodes in the subgraph defined by inputs and
# outputs are cached in this field
self
.
apply_nodes
=
set
()
# Ditto for variable nodes
...
...
@@ -104,7 +107,7 @@ class FunctionGraph(utils.object2):
self
.
__setup_r__
(
input
)
self
.
variables
.
add
(
input
)
self
.
__import_r__
(
outputs
)
self
.
__import_r__
(
outputs
,
reason
=
"init"
)
for
i
,
output
in
enumerate
(
outputs
):
output
.
clients
.
append
((
'output'
,
i
))
...
...
@@ -112,12 +115,12 @@ class FunctionGraph(utils.object2):
self
.
variable_locks
=
{}
self
.
profile
=
None
### Setup a Variable ###
def
__setup_r__
(
self
,
r
):
# sets up r so it belongs to this fgraph
if
hasattr
(
r
,
'fgraph'
)
and
r
.
fgraph
is
not
None
and
r
.
fgraph
is
not
self
:
if
(
hasattr
(
r
,
'fgraph'
)
and
r
.
fgraph
is
not
None
and
r
.
fgraph
is
not
self
):
raise
Exception
(
"
%
s is already owned by another fgraph"
%
r
)
r
.
fgraph
=
self
r
.
clients
=
[]
...
...
@@ -165,13 +168,13 @@ class FunctionGraph(utils.object2):
self
.
inputs
=
None
self
.
outputs
=
None
### clients ###
def
clients
(
self
,
r
):
"""
Set of all the (node, i) pairs such that node.inputs[i] is r.
Tell differently, a list of (node,i) such that each node have r as input at index i.
Tell differently, a list of (node,i) such that each node have
r as input at index i.
"""
return
r
.
clients
...
...
@@ -184,12 +187,15 @@ class FunctionGraph(utils.object2):
"""
if
set
(
r
.
clients
)
.
intersection
(
set
(
new_clients
)):
print
>>
sys
.
stderr
,
'ERROR: clients intersect!'
print
>>
sys
.
stderr
,
' RCLIENTS of'
,
r
,
[(
n
,
i
,
type
(
n
),
id
(
n
))
for
n
,
i
in
r
.
clients
]
print
>>
sys
.
stderr
,
' NCLIENTS of'
,
r
,
[(
n
,
i
,
type
(
n
),
id
(
n
))
for
n
,
i
in
new_clients
]
print
>>
sys
.
stderr
,
' RCLIENTS of'
,
r
,
[(
n
,
i
,
type
(
n
),
id
(
n
))
for
n
,
i
in
r
.
clients
]
print
>>
sys
.
stderr
,
' NCLIENTS of'
,
r
,
[(
n
,
i
,
type
(
n
),
id
(
n
))
for
n
,
i
in
new_clients
]
assert
not
set
(
r
.
clients
)
.
intersection
(
set
(
new_clients
))
r
.
clients
+=
new_clients
def
__remove_clients__
(
self
,
r
,
clients_to_remove
,
prune
=
True
):
def
__remove_clients__
(
self
,
r
,
clients_to_remove
,
prune
=
True
,
reason
=
None
):
""" WRITEME
r -> variable
clients_to_remove -> list of (op, i) pairs such that node.inputs[i] is not r anymore.
...
...
@@ -202,18 +208,16 @@ class FunctionGraph(utils.object2):
print
>>
sys
.
stderr
,
'ERROR: DUPLICATE CLIENT ENTRY...'
print
>>
sys
.
stderr
,
' ENTRY'
,
repr
(
entry
),
type
(
entry
[
0
])
print
>>
sys
.
stderr
,
' CLIENTS'
,
repr
(
r
.
clients
)
assert
entry
not
in
r
.
clients
# an op,i pair should be unique
assert
entry
not
in
r
.
clients
# an op,i pair should be unique
if
not
r
.
clients
:
if
prune
:
self
.
__prune_r__
([
r
])
self
.
__prune_r__
([
r
]
,
reason
)
return
False
return
True
return
False
### import ###
def
__import_r__
(
self
,
variables
):
def
__import_r__
(
self
,
variables
,
reason
):
global
NullType
if
NullType
is
None
:
from
null_type
import
NullType
...
...
@@ -222,17 +226,18 @@ class FunctionGraph(utils.object2):
for
apply_node
in
[
r
.
owner
for
r
in
variables
if
r
.
owner
is
not
None
]:
if
apply_node
not
in
r_owner_done
:
r_owner_done
.
add
(
apply_node
)
self
.
__import__
(
apply_node
)
self
.
__import__
(
apply_node
,
reason
=
reason
)
for
r
in
variables
:
if
r
.
owner
is
None
and
not
isinstance
(
r
,
graph
.
Constant
)
and
r
not
in
self
.
inputs
:
if
isinstance
(
r
.
type
,
NullType
):
raise
TypeError
(
"Computation graph contains a NaN. "
+
r
.
type
.
why_null
)
if
isinstance
(
r
.
type
,
NullType
):
raise
TypeError
(
"Computation graph contains a NaN. "
+
r
.
type
.
why_null
)
raise
MissingInputError
(
"Undeclared input"
,
r
)
if
not
getattr
(
r
,
'fgraph'
,
None
)
is
self
:
self
.
__setup_r__
(
r
)
self
.
variables
.
add
(
r
)
def
__import__
(
self
,
apply_node
,
check
=
Tru
e
):
def
__import__
(
self
,
apply_node
,
check
=
True
,
reason
=
Non
e
):
node
=
apply_node
# We import the nodes in topological order. We only are interested
...
...
@@ -248,7 +253,9 @@ class FunctionGraph(utils.object2):
for
r
in
node
.
inputs
:
if
hasattr
(
r
,
'fgraph'
)
and
r
.
fgraph
is
not
self
:
raise
Exception
(
"
%
s is already owned by another fgraph"
%
r
)
if
r
.
owner
is
None
and
not
isinstance
(
r
,
graph
.
Constant
)
and
r
not
in
self
.
inputs
:
if
(
r
.
owner
is
None
and
not
isinstance
(
r
,
graph
.
Constant
)
and
r
not
in
self
.
inputs
):
#Verbose error message
#Show a complete chain of variables from the missing input to an output
...
...
@@ -328,20 +335,18 @@ class FunctionGraph(utils.object2):
self
.
variables
.
add
(
input
)
self
.
__add_clients__
(
input
,
[(
node
,
i
)])
assert
node
.
fgraph
is
self
self
.
execute_callbacks
(
'on_import'
,
node
)
self
.
execute_callbacks
(
'on_import'
,
node
,
reason
)
### prune ###
def
__prune_r__
(
self
,
variables
):
def
__prune_r__
(
self
,
variables
,
reason
=
None
):
# Prunes the owners of the variables.
for
node
in
set
(
r
.
owner
for
r
in
variables
if
r
.
owner
is
not
None
):
self
.
__prune__
(
node
)
self
.
__prune__
(
node
,
reason
)
for
r
in
variables
:
if
not
r
.
clients
and
r
in
self
.
variables
:
self
.
variables
.
remove
(
r
)
def
__prune__
(
self
,
apply_node
):
def
__prune__
(
self
,
apply_node
,
reason
=
None
):
node
=
apply_node
if
node
not
in
self
.
apply_nodes
:
raise
Exception
(
"
%
s does not belong to this FunctionGraph and cannot be pruned."
%
node
)
...
...
@@ -356,16 +361,13 @@ class FunctionGraph(utils.object2):
return
self
.
apply_nodes
.
remove
(
node
)
self
.
variables
.
difference_update
(
node
.
outputs
)
self
.
execute_callbacks
(
'on_prune'
,
node
)
self
.
execute_callbacks
(
'on_prune'
,
node
,
reason
)
for
i
,
input
in
enumerate
(
node
.
inputs
):
self
.
__remove_clients__
(
input
,
[(
node
,
i
)])
self
.
__remove_clients__
(
input
,
[(
node
,
i
)]
,
reason
=
reason
)
#self.__prune_r__(node.inputs)
### change input ###
def
change_input
(
self
,
node
,
i
,
new_r
,
reason
=
None
):
"""WRITEME
Changes node.inputs[i] to new_r.
...
...
@@ -381,42 +383,45 @@ class FunctionGraph(utils.object2):
r
=
self
.
outputs
[
i
]
if
not
r
.
type
==
new_r
.
type
:
raise
TypeError
(
"The type of the replacement must be the"
" same as the type of the original Variable."
,
r
,
new_r
)
" same as the type of the original Variable."
,
r
,
new_r
)
self
.
outputs
[
i
]
=
new_r
else
:
if
node
.
fgraph
is
not
self
:
raise
Exception
(
"Cannot operate on
%
s because it does not"
" belong to this FunctionGraph"
%
node
)
" belong to this FunctionGraph"
%
node
)
r
=
node
.
inputs
[
i
]
if
not
r
.
type
==
new_r
.
type
:
raise
TypeError
(
"The type of the replacement must be the"
" same as the type of the original Variable."
,
r
,
new_r
)
" same as the type of the original Variable."
,
r
,
new_r
)
node
.
inputs
[
i
]
=
new_r
if
r
is
new_r
:
return
self
.
__import_r__
([
new_r
])
self
.
__import_r__
([
new_r
]
,
reason
=
reason
)
self
.
__add_clients__
(
new_r
,
[(
node
,
i
)])
prune
=
self
.
__remove_clients__
(
r
,
[(
node
,
i
)],
False
)
# Precondition: the substitution is semantically valid
# However it may introduce cycles to the graph, in which case the
# transaction will be reverted later.
self
.
execute_callbacks
(
'on_change_input'
,
node
,
i
,
r
,
new_r
,
reason
=
reason
)
self
.
execute_callbacks
(
'on_change_input'
,
node
,
i
,
r
,
new_r
,
reason
=
reason
)
if
prune
:
self
.
__prune_r__
([
r
])
self
.
__prune_r__
([
r
],
reason
=
reason
)
### replace ###
def
replace
(
self
,
r
,
new_r
,
reason
=
None
):
def
replace
(
self
,
r
,
new_r
,
reason
=
None
,
verbose
=
None
):
""" WRITEME
This is the main interface to manipulate the subgraph in FunctionGraph.
For every node that uses r as input, makes it use new_r instead.
"""
if
verbose
is
None
:
verbose
=
config
.
optimizer_verbose
if
verbose
:
print
reason
,
r
,
new_r
if
r
.
fgraph
is
not
self
:
raise
Exception
(
"Cannot replace
%
s because it does not belong to this FunctionGraph"
%
r
,
str
(
reason
))
if
not
r
.
type
==
new_r
.
type
:
...
...
@@ -426,7 +431,7 @@ class FunctionGraph(utils.object2):
# because it makes it easier to implement some optimizations for multiple-output ops
return
for
node
,
i
in
list
(
r
.
clients
):
# copy the client list for iteration
for
node
,
i
in
list
(
r
.
clients
):
# copy the client list for iteration
assert
(
node
==
'output'
and
self
.
outputs
[
i
]
is
r
)
or
(
node
.
inputs
[
i
]
is
r
)
self
.
change_input
(
node
,
i
,
new_r
,
reason
=
reason
)
...
...
@@ -440,11 +445,9 @@ class FunctionGraph(utils.object2):
for
r
,
new_r
in
pairs
:
self
.
replace
(
r
,
new_r
,
reason
=
reason
)
def
extend
(
self
,
feature
):
warnings
.
warn
(
"FunctionGraph.extend is deprecatd. It has been "
"renamed to FunctionGraph.attach_feature"
)
"renamed to FunctionGraph.attach_feature"
)
return
self
.
attach_feature
(
feature
)
def
attach_feature
(
self
,
feature
):
...
...
@@ -455,7 +458,7 @@ class FunctionGraph(utils.object2):
# Filter out literally identical features
if
feature
in
self
.
_features
:
return
# the feature is already present
return
# the feature is already present
# Filter out functionally identical features.
# Features may use their on_attach method to raise
...
...
@@ -481,7 +484,9 @@ class FunctionGraph(utils.object2):
"""WRITEME
Removes the feature from the graph.
Calls feature.on_detach(function_graph) if an on_detach method is defined.
Calls feature.on_detach(function_graph) if an on_detach method
is defined.
"""
try
:
self
.
_features
.
remove
(
feature
)
...
...
@@ -491,9 +496,7 @@ class FunctionGraph(utils.object2):
if
detach
is
not
None
:
detach
(
self
)
### callback utils ###
def
execute_callbacks
(
self
,
name
,
*
args
,
**
kwargs
):
"""WRITEME
Calls
...
...
@@ -518,7 +521,6 @@ class FunctionGraph(utils.object2):
else
:
raise
def
collect_callbacks
(
self
,
name
,
*
args
):
"""WRITEME
Returns a dictionary d such that:
...
...
@@ -534,9 +536,7 @@ class FunctionGraph(utils.object2):
d
[
feature
]
=
fn
(
*
args
)
return
d
### misc ###
def
toposort
(
self
):
"""WRITEME
Returns an ordering of the graph's Apply nodes such that:
...
...
@@ -552,8 +552,8 @@ class FunctionGraph(utils.object2):
if
len
(
self
.
apply_nodes
)
<
2
:
# optimization
# when there are 0 or 1 nodes, no sorting is necessary
# This special case happens a lot because the OpWiseCLinker
produces
# 1-element graphs.
# This special case happens a lot because the OpWiseCLinker
#
produces
1-element graphs.
return
list
(
self
.
apply_nodes
)
fg
=
self
...
...
@@ -568,30 +568,33 @@ class FunctionGraph(utils.object2):
Return dict d s.t. d[node] is a list of nodes that must be evaluated
before node itself can be evaluated.
This is used primarily by the destroy_handler feature to ensure that all
clients of any destroyed inputs have already computed their outputs.
This is used primarily by the destroy_handler feature to ensure that
all clients of any destroyed inputs have already computed their
outputs.
:note: This only calls the orderings() fct on all features. It does not
take care of computing dependencies by itself.
"""
ords
=
OrderedDict
()
ords
=
OrderedDict
()
assert
isinstance
(
self
.
_features
,
list
)
for
feature
in
self
.
_features
:
if
hasattr
(
feature
,
'orderings'
):
orderings
=
feature
.
orderings
(
self
)
if
not
isinstance
(
orderings
,
OrderedDict
):
raise
TypeError
(
"Non-deterministic return value from "
\
+
str
(
feature
.
orderings
)
\
+
". Nondeterministic object is "
+
str
(
orderings
))
raise
TypeError
(
"Non-deterministic return value from "
+
str
(
feature
.
orderings
)
+
". Nondeterministic object is "
+
str
(
orderings
))
for
node
,
prereqs
in
orderings
.
items
():
if
not
isinstance
(
prereqs
,
(
list
,
OrderedSet
)):
raise
TypeError
(
"prereqs must be a type with a "
"deterministic iteration order, or toposort "
" will be non-deterministic."
)
raise
TypeError
(
"prereqs must be a type with a "
"deterministic iteration order, or toposort "
" will be non-deterministic."
)
ords
.
setdefault
(
node
,
[])
.
extend
(
prereqs
)
# eliminate duplicate prereqs
for
(
node
,
prereqs
)
in
ords
.
items
():
for
(
node
,
prereqs
)
in
ords
.
items
():
ords
[
node
]
=
list
(
OrderedSet
(
prereqs
))
return
ords
...
...
@@ -624,34 +627,48 @@ class FunctionGraph(utils.object2):
if
self
.
apply_nodes
!=
nodes
:
missing
=
nodes
.
difference
(
self
.
apply_nodes
)
excess
=
self
.
apply_nodes
.
difference
(
nodes
)
raise
Exception
(
"The nodes are inappropriately cached. missing, in excess: "
,
missing
,
excess
)
raise
Exception
(
"The nodes are inappropriately cached. missing, in excess: "
,
missing
,
excess
)
for
node
in
nodes
:
if
node
.
fgraph
is
not
self
:
raise
Exception
(
"Node should belong to the FunctionGraph."
,
node
)
raise
Exception
(
"Node should belong to the FunctionGraph."
,
node
)
for
i
,
variable
in
enumerate
(
node
.
inputs
):
if
variable
.
fgraph
is
not
self
:
raise
Exception
(
"Input of node should belong to the FunctionGraph."
,
variable
,
(
node
,
i
))
raise
Exception
(
"Input of node should belong to the FunctionGraph."
,
variable
,
(
node
,
i
))
if
(
node
,
i
)
not
in
variable
.
clients
:
raise
Exception
(
"Inconsistent clients list."
,
(
node
,
i
),
variable
.
clients
)
raise
Exception
(
"Inconsistent clients list."
,
(
node
,
i
),
variable
.
clients
)
variables
=
set
(
graph
.
variables
(
self
.
inputs
,
self
.
outputs
))
if
set
(
self
.
variables
)
!=
variables
:
missing
=
variables
.
difference
(
self
.
variables
)
excess
=
self
.
variables
.
difference
(
variables
)
raise
Exception
(
"The variables are inappropriately cached. missing, in excess: "
,
missing
,
excess
)
raise
Exception
(
"The variables are inappropriately cached. missing, in excess: "
,
missing
,
excess
)
for
variable
in
variables
:
if
variable
.
owner
is
None
and
variable
not
in
self
.
inputs
and
not
isinstance
(
variable
,
graph
.
Constant
):
if
(
variable
.
owner
is
None
and
variable
not
in
self
.
inputs
and
not
isinstance
(
variable
,
graph
.
Constant
)):
raise
Exception
(
"Undeclared input."
,
variable
)
if
variable
.
fgraph
is
not
self
:
raise
Exception
(
"Variable should belong to the FunctionGraph."
,
variable
)
raise
Exception
(
"Variable should belong to the FunctionGraph."
,
variable
)
for
node
,
i
in
variable
.
clients
:
if
node
==
'output'
:
if
self
.
outputs
[
i
]
is
not
variable
:
raise
Exception
(
"Inconsistent clients list."
,
variable
,
self
.
outputs
[
i
])
raise
Exception
(
"Inconsistent clients list."
,
variable
,
self
.
outputs
[
i
])
continue
if
node
not
in
nodes
:
raise
Exception
(
"Client not in FunctionGraph."
,
variable
,
(
node
,
i
))
raise
Exception
(
"Client not in FunctionGraph."
,
variable
,
(
node
,
i
))
if
node
.
inputs
[
i
]
is
not
variable
:
raise
Exception
(
"Inconsistent clients list."
,
variable
,
node
.
inputs
[
i
])
raise
Exception
(
"Inconsistent clients list."
,
variable
,
node
.
inputs
[
i
])
def
__str__
(
self
):
return
"[
%
s]"
%
", "
.
join
(
graph
.
as_string
(
self
.
inputs
,
self
.
outputs
))
...
...
@@ -659,9 +676,7 @@ class FunctionGraph(utils.object2):
def
__repr__
(
self
):
return
self
.
__str__
()
### clone ###
def
clone
(
self
):
"""WRITEME"""
return
self
.
clone_get_equiv
()[
0
]
...
...
@@ -671,7 +686,7 @@ class FunctionGraph(utils.object2):
equiv
=
graph
.
clone_get_equiv
(
self
.
inputs
,
self
.
outputs
)
self
.
check_integrity
()
e
=
FunctionGraph
([
equiv
[
i
]
for
i
in
self
.
inputs
],
[
equiv
[
o
]
for
o
in
self
.
outputs
])
[
equiv
[
o
]
for
o
in
self
.
outputs
])
e
.
check_integrity
()
for
feature
in
self
.
_features
:
e
.
attach_feature
(
feature
)
...
...
theano/gof/op.py
浏览文件 @
a351e3db
...
...
@@ -13,6 +13,7 @@ __contact__ = "theano-dev <theano-dev@googlegroups.com>"
__docformat__
=
"restructuredtext en"
import
logging
import
sys
import
warnings
import
theano
...
...
@@ -408,6 +409,9 @@ class PureOp(object):
elif
config
.
compute_test_value
==
'ignore'
:
# silently skip test
run_perform
=
False
elif
config
.
compute_test_value
==
'pdb'
:
import
pdb
pdb
.
post_mortem
(
sys
.
exc_info
()[
2
])
else
:
raise
ValueError
(
'
%
s is invalid for option config.compute_Test_value'
%
config
.
compute_test_value
)
...
...
@@ -638,8 +642,11 @@ def get_test_value(v):
For a Shared variable, it is the internal value.
For another Variable, it is the content of v.tag.test_value.
"""
v_tensor
=
theano
.
tensor
.
as_tensor_variable
(
v
)
return
PureOp
.
_get_test_value
(
v_tensor
)
if
not
isinstance
(
v
,
graph
.
Variable
):
v_var
=
theano
.
tensor
.
as_tensor_variable
(
v
)
else
:
v_var
=
v
return
PureOp
.
_get_test_value
(
v_var
)
def
missing_test_message
(
msg
):
...
...
theano/gof/opt.py
浏览文件 @
a351e3db
...
...
@@ -421,7 +421,7 @@ class MergeFeature(object):
self
.
blacklist
=
[]
for
node
in
fgraph
.
toposort
():
self
.
on_import
(
fgraph
,
node
)
self
.
on_import
(
fgraph
,
node
,
"on_attach"
)
def
on_change_input
(
self
,
fgraph
,
node
,
i
,
r
,
new_r
):
# If inputs to node change, it is not guaranteed that it is distinct
...
...
@@ -433,14 +433,14 @@ class MergeFeature(object):
if
isinstance
(
new_r
,
graph
.
Constant
):
self
.
process_constant
(
fgraph
,
new_r
)
def
on_import
(
self
,
fgraph
,
node
):
def
on_import
(
self
,
fgraph
,
node
,
reason
):
for
c
in
node
.
inputs
:
if
isinstance
(
c
,
graph
.
Constant
):
self
.
process_constant
(
fgraph
,
c
)
self
.
process_node
(
fgraph
,
node
)
def
on_prune
(
self
,
fgraph
,
node
):
def
on_prune
(
self
,
fgraph
,
node
,
reason
):
self
.
nodes_seen
.
discard
(
node
)
for
c
in
node
.
inputs
:
if
isinstance
(
c
,
graph
.
Constant
)
and
(
len
(
c
.
clients
)
<=
1
):
...
...
@@ -548,7 +548,7 @@ class MergeOptimizer(Optimizer):
except
InconsistencyError
:
success
=
False
fgraph
.
merge_feature
.
blacklist
.
append
(
(
pairs
[
0
][
0
]
.
owner
,
pairs
[
0
][
1
]
.
owner
))
(
pairs
[
0
][
0
]
.
owner
,
pairs
[
0
][
1
]
.
owner
))
if
success
:
break
...
...
@@ -1027,7 +1027,7 @@ class PatternSub(LocalOptimizer):
else
:
return
pattern
.
clone
()
u
=
match
(
self
.
in_pattern
,
node
.
out
,
unify
.
Unification
(),
True
,
self
.
pdb
)
self
.
pdb
)
if
u
:
p
=
self
.
out_pattern
new
=
build
(
p
,
u
)
...
...
@@ -1165,10 +1165,10 @@ class NavigatorOptimizer(Optimizer):
class
Updater
:
if
importer
is
not
None
:
def
on_import
(
self
,
fgraph
,
node
):
def
on_import
(
self
,
fgraph
,
node
,
reason
):
importer
(
node
)
if
pruner
is
not
None
:
def
on_prune
(
self
,
fgraph
,
node
):
def
on_prune
(
self
,
fgraph
,
node
,
reason
):
pruner
(
node
)
if
chin
is
not
None
:
def
on_change_input
(
self
,
fgraph
,
node
,
i
,
r
,
new_r
):
...
...
@@ -1357,7 +1357,7 @@ class ChangeTracker:
def
__init__
(
self
):
self
.
changed
=
False
def
on_import
(
self
,
fgraph
,
node
):
def
on_import
(
self
,
fgraph
,
node
,
reason
):
self
.
changed
=
True
def
on_change_input
(
self
,
fgraph
,
node
,
i
,
r
,
new_r
):
...
...
theano/gof/toolbox.py
浏览文件 @
a351e3db
import
sys
import
time
from
theano
import
config
from
theano.gof.python25
import
partial
from
theano.gof.python25
import
OrderedDict
from
theano.gof
import
graph
class
AlreadyThere
(
Exception
):
"""Raised by a Feature's on_attach callback method if the FunctionGraph
attempting to attach the feature already has a functionally identical
...
...
@@ -57,7 +56,7 @@ class Feature(object):
functionality that it installed into the function_graph.
"""
def
on_import
(
self
,
function_graph
,
node
):
def
on_import
(
self
,
function_graph
,
node
,
reason
):
"""
Called whenever a node is imported into function_graph, which is
just before the node is actually connected to the graph.
...
...
@@ -66,7 +65,7 @@ class Feature(object):
you should do this by implementing on_attach.
"""
def
on_prune
(
self
,
function_graph
,
node
):
def
on_prune
(
self
,
function_graph
,
node
,
reason
):
"""
Called whenever a node is pruned (removed) from the function_graph,
after it is disconnected from the graph.
...
...
@@ -98,11 +97,11 @@ class Bookkeeper(Feature):
def
on_attach
(
self
,
fgraph
):
for
node
in
graph
.
io_toposort
(
fgraph
.
inputs
,
fgraph
.
outputs
):
self
.
on_import
(
fgraph
,
node
)
self
.
on_import
(
fgraph
,
node
,
"on_attach"
)
def
on_detach
(
self
,
fgraph
):
for
node
in
graph
.
io_toposort
(
fgraph
.
inputs
,
fgraph
.
outputs
):
self
.
on_prune
(
fgraph
,
node
)
self
.
on_prune
(
fgraph
,
node
,
'Bookkeeper.detach'
)
class
History
(
Feature
):
...
...
@@ -199,11 +198,14 @@ class ReplaceValidate(History, Validator):
def
replace_validate
(
self
,
fgraph
,
r
,
new_r
,
reason
=
None
):
self
.
replace_all_validate
(
fgraph
,
[(
r
,
new_r
)],
reason
=
reason
)
def
replace_all_validate
(
self
,
fgraph
,
replacements
,
reason
=
None
):
def
replace_all_validate
(
self
,
fgraph
,
replacements
,
reason
=
None
,
verbose
=
None
):
chk
=
fgraph
.
checkpoint
()
if
verbose
is
None
:
verbose
=
config
.
optimizer_verbose
for
r
,
new_r
in
replacements
:
try
:
fgraph
.
replace
(
r
,
new_r
,
reason
=
reason
)
fgraph
.
replace
(
r
,
new_r
,
reason
=
reason
,
verbose
=
False
)
except
Exception
,
e
:
if
(
'The type of the replacement must be the same'
not
in
str
(
e
)
and
'does not belong to this FunctionGraph'
not
in
str
(
e
)):
...
...
@@ -219,6 +221,8 @@ class ReplaceValidate(History, Validator):
except
Exception
,
e
:
fgraph
.
revert
(
chk
)
raise
if
verbose
:
print
reason
,
r
,
new_r
return
chk
def
replace_all_validate_remove
(
self
,
fgraph
,
replacements
,
...
...
@@ -267,7 +271,7 @@ class NodeFinder(dict, Bookkeeper):
del
fgraph
.
get_nodes
Bookkeeper
.
on_detach
(
self
,
fgraph
)
def
on_import
(
self
,
fgraph
,
node
):
def
on_import
(
self
,
fgraph
,
node
,
reason
):
try
:
self
.
setdefault
(
node
.
op
,
[])
.
append
(
node
)
except
TypeError
:
# node.op is unhashable
...
...
@@ -280,7 +284,7 @@ class NodeFinder(dict, Bookkeeper):
print
>>
sys
.
stderr
,
'OFFENDING node not hashable'
raise
e
def
on_prune
(
self
,
fgraph
,
node
):
def
on_prune
(
self
,
fgraph
,
node
,
reason
):
try
:
nodes
=
self
[
node
.
op
]
except
TypeError
:
# node.op is unhashable
...
...
@@ -312,13 +316,13 @@ class PrintListener(Feature):
if
self
.
active
:
print
"-- detaching from: "
,
fgraph
def
on_import
(
self
,
fgraph
,
node
):
def
on_import
(
self
,
fgraph
,
node
,
reason
):
if
self
.
active
:
print
"-- importing:
%
s
"
%
node
print
"-- importing:
%
s
, reason:
%
s"
%
(
node
,
reason
)
def
on_prune
(
self
,
fgraph
,
node
):
def
on_prune
(
self
,
fgraph
,
node
,
reason
):
if
self
.
active
:
print
"-- pruning:
%
s
"
%
node
print
"-- pruning:
%
s
, reason:
%
s"
%
(
node
,
reason
)
def
on_change_input
(
self
,
fgraph
,
node
,
i
,
r
,
new_r
,
reason
=
None
):
if
self
.
active
:
...
...
theano/sandbox/cuda/basic_ops.py
浏览文件 @
a351e3db
...
...
@@ -2953,7 +2953,6 @@ class GpuJoin(tensor.Join, GpuOp):
axis
=
inputs
[
0
]
n_cndas
=
len
(
inputs
[
1
:])
input_1
=
inputs
[
1
]
axis
=
inputs
[
0
]
fail
=
sub
[
'fail'
]
out
=
out_
[
0
]
...
...
theano/sandbox/linalg/ops.py
浏览文件 @
a351e3db
...
...
@@ -137,9 +137,9 @@ class HintsFeature(object):
# Variable -> tuple(scalars) or None (All tensor vars map to tuple)
self
.
hints
=
{}
for
node
in
fgraph
.
toposort
():
self
.
on_import
(
fgraph
,
node
)
self
.
on_import
(
fgraph
,
node
,
"on_attach"
)
def
on_import
(
self
,
fgraph
,
node
):
def
on_import
(
self
,
fgraph
,
node
,
reason
):
if
node
.
outputs
[
0
]
in
self
.
hints
:
# this is a revert, not really an import
for
r
in
node
.
outputs
+
node
.
inputs
:
...
...
theano/sandbox/scan_module/scan_utils.py
浏览文件 @
a351e3db
...
...
@@ -338,7 +338,7 @@ def infer_shape(outs, inputs, input_shapes):
# shape_feature.on_import does not actually use an fgraph
# It will call infer_shape and set_shape appropriately
dummy_fgraph
=
None
shape_feature
.
on_import
(
dummy_fgraph
,
out
.
owner
)
shape_feature
.
on_import
(
dummy_fgraph
,
out
.
owner
,
reason
=
"dummy"
)
ret
=
[]
for
o
in
outs
:
...
...
theano/scalar/basic.py
浏览文件 @
a351e3db
...
...
@@ -183,6 +183,24 @@ class Scalar(Type):
def
dtype_specs
(
self
):
try
:
# To help debug dtype/typenum problem, here is code to get
# the list of numpy typenum. This list change between 32
# and 64 bit platform and probably also also between
# Windows and Linux.
# NOTE: equivalent type on a platform can have different typenum.
# This is the source of all dtype/typenum problem found up to
# now, as Theano always expect the exact typenum that
# correspond to our supported dtype.
"""
for dtype in ['int8', 'uint8', 'short', 'ushort', 'intc', 'uintc',
'longlong', 'ulonglong', 'single', 'double',
'longdouble', 'csingle', 'cdouble', 'clongdouble',
'float32', 'float64', 'int8', 'int16', 'int32',
'int64', 'uint8', 'uint16', 'uint32', 'uint64',
'complex64', 'complex128', 'float', 'double',
'int', 'uint']:
print dtype, np.zeros(1, dtype=dtype).dtype.num
"""
return
{
# dtype: (py_type, c_type, cls_name)
'float32'
:
(
numpy
.
float32
,
'npy_float32'
,
'Float32'
),
'float64'
:
(
numpy
.
float64
,
'npy_float64'
,
'Float64'
),
...
...
theano/scan_module/scan.py
浏览文件 @
a351e3db
...
...
@@ -101,7 +101,7 @@ def scan(fn,
The order of the sequences is the same as the one in the list
`sequences` given to scan. The order of the outputs is the same
as the order of ``output_info``. For any sequence or output the
as the order of ``output
s
_info``. For any sequence or output the
order of the time slices is the same as the one in which they have
been given as taps. For example if one writes the following :
...
...
@@ -262,7 +262,7 @@ def scan(fn,
outputs will have *0 rows*. If the value is negative, ``scan``
will run backwards in time. If the ``go_backwards`` flag is already
set and also ``n_steps`` is negative, ``scan`` will run forward
in time. If n
stpe
s is not provided, ``scan`` will figure
in time. If n
_step
s is not provided, ``scan`` will figure
out the amount of steps it should run given its input sequences.
...
...
@@ -817,7 +817,7 @@ def scan(fn,
if
as_while
:
tmp_dummy_f_outs
-=
1
if
not
(
tmp_dummy_f_outs
==
n_outs
or
outs_info
==
[]):
raise
ValueError
(
'Please provide None as output_info for '
raise
ValueError
(
'Please provide None as output
s
_info for '
'any output that does not feed back into '
'scan (i.e. it behaves like a map) '
)
...
...
theano/scan_module/scan_op.py
浏览文件 @
a351e3db
...
...
@@ -1581,8 +1581,30 @@ class Scan(PureOp):
if
not
isinstance
(
x
.
type
,
DisconnectedType
):
outer_inp_seqs
.
append
(
x
[::
-
1
])
outer_inp_seqs
+=
[
x
[::
-
1
]
for
x
in
self
.
outer_mitsot_outs
(
outs
)]
outer_inp_seqs
+=
[
x
[::
-
1
]
for
x
in
self
.
outer_sitsot_outs
(
outs
)]
if
hasattr
(
inputs
[
0
]
.
tag
,
'test_value'
):
# Here we tests that the new scan input sequence all have
# the same shape[0]. This is a properties that the scan()
# fct add and we want to keep it for all Scan op. This is
# used in T_Scan.test_grad_multiple_outs_taps to test
# that.
for
taps
,
x
in
zip
(
self
.
mitsot_taps
(),
self
.
outer_mitsot_outs
(
outs
)):
mintap
=
numpy
.
min
(
taps
)
if
hasattr
(
x
[::
-
1
][:
mintap
],
'test_value'
):
assert
(
x
[::
-
1
][:
mintap
]
.
tag
.
test_value
.
shape
[
0
]
==
inputs
[
0
]
.
tag
.
test_value
)
for
x
in
self
.
outer_sitsot_outs
(
outs
):
if
hasattr
(
x
[::
-
1
][:
-
1
]
.
tag
,
'test_value'
):
assert
(
x
[::
-
1
][:
-
1
]
.
tag
.
test_value
.
shape
[
0
]
==
inputs
[
0
]
.
tag
.
test_value
)
for
x
in
self
.
outer_nitsot_outs
(
outs
):
if
hasattr
(
x
[::
-
1
]
.
tag
,
'test_value'
):
assert
(
x
[::
-
1
]
.
tag
.
test_value
.
shape
[
0
]
==
inputs
[
0
]
.
tag
.
test_value
)
outer_inp_seqs
+=
[
x
[::
-
1
][:
numpy
.
min
(
taps
)]
for
taps
,
x
in
zip
(
self
.
mitsot_taps
(),
self
.
outer_mitsot_outs
(
outs
))]
outer_inp_seqs
+=
[
x
[::
-
1
][:
-
1
]
for
x
in
self
.
outer_sitsot_outs
(
outs
)]
outer_inp_seqs
+=
[
x
[::
-
1
]
for
x
in
self
.
outer_nitsot_outs
(
outs
)]
inner_inp_seqs
=
self
.
inner_seqs
(
self_inputs
)
...
...
theano/scan_module/scan_opt.py
浏览文件 @
a351e3db
...
...
@@ -66,7 +66,7 @@ def remove_constants_and_unused_inputs_scan(node):
# We only need to take care of sequences and other arguments
st
=
op
.
n_seqs
st
+=
int
(
numpy
.
sum
([
len
(
x
)
for
x
in
op
.
tap_array
[:(
op
.
n_mit_mot
+
op
.
n_mit_sot
)]]))
op
.
tap_array
[:(
op
.
n_mit_mot
+
op
.
n_mit_sot
)]]))
st
+=
op
.
n_sit_sot
st
+=
op
.
n_shared_outs
op_ins
,
op_outs
=
scan_utils
.
reconstruct_graph
(
op
.
inputs
,
op
.
outputs
)
...
...
@@ -105,8 +105,8 @@ def remove_constants_and_unused_inputs_scan(node):
elif
op_ins
[
idx
]
in
all_ins
:
# Check for identical other sequence
identical_seqs
=
[
x
for
x
in
nw_outer
if
scan_utils
.
equal_computations
(
[
x
],
[
node
.
inputs
[
idx
+
1
]])]
if
scan_utils
.
equal_computations
(
[
x
],
[
node
.
inputs
[
idx
+
1
]])]
if
identical_seqs
:
index
=
node
.
inputs
.
index
(
identical_seqs
[
0
])
-
1
givens
[
op_ins
[
idx
]]
=
op_ins
[
index
]
...
...
@@ -144,7 +144,7 @@ def remove_constants_and_unused_inputs_scan(node):
nw_info
[
'n_seqs'
]
=
nw_n_seqs
# DEBUG CHECK
nwScan
=
scan_op
.
Scan
(
nw_inner
,
op_outs
,
nw_info
)
nw_outs
=
nwScan
.
make_node
(
*
nw_outer
)
.
outputs
nw_outs
=
nwScan
(
*
nw_outer
,
**
dict
(
return_list
=
True
))
return
nw_outs
else
:
return
False
...
...
@@ -162,7 +162,7 @@ class PushOutNonSeqScan(gof.Optimizer):
def
apply
(
self
,
fgraph
):
nodelist
=
[
x
for
x
in
fgraph
.
toposort
()
if
isinstance
(
x
.
op
,
scan_op
.
Scan
)]
scan_op
.
Scan
)]
for
node
in
nodelist
:
self
.
process_node
(
fgraph
,
node
)
...
...
@@ -170,7 +170,7 @@ class PushOutNonSeqScan(gof.Optimizer):
# this flag tells if there was any change during the last iterations
changed
=
True
clean_inputs
,
clean_outputs
=
scan_utils
.
reconstruct_graph
(
node
.
op
.
inputs
,
node
.
op
.
outputs
)
node
.
op
.
inputs
,
node
.
op
.
outputs
)
local_fgraph
=
gof
.
FunctionGraph
(
clean_inputs
,
clean_outputs
)
max_iterations
=
2
*
len
(
local_fgraph
.
toposort
())
+
3
...
...
@@ -196,7 +196,7 @@ class PushOutNonSeqScan(gof.Optimizer):
if
(
numpy
.
all
([(
x
in
inner_non_seqs
)
or
(
x
.
owner
in
to_remove
)
or
isinstance
(
x
,
tensor
.
Constant
)
for
x
in
nd
.
inputs
])
and
for
x
in
nd
.
inputs
])
and
# we can do this because the assumption is that a
# viewOp or deepCopyOp will be just at the end of the
# function and not somewhere in the middle ..
...
...
@@ -227,7 +227,11 @@ class PushOutNonSeqScan(gof.Optimizer):
'this on theano-users list'
),
x
)
outside_ins
=
[
x
.
type
.
filter_variable
(
y
)
for
x
,
y
in
zip
(
nd
.
inputs
,
outside_ins
)]
nw_outer_node
=
nd
.
op
.
make_node
(
*
outside_ins
)
# Do not call make_node for test_value
nw_outer_node
=
nd
.
op
(
*
outside_ins
,
**
dict
(
return_list
=
True
))[
0
]
.
owner
# Step 2. Create variables for replacements
for
idx
,
y
in
enumerate
(
nd
.
outputs
):
...
...
@@ -250,7 +254,7 @@ class PushOutNonSeqScan(gof.Optimizer):
clean_replace_with_in
=
[]
clean_replace_with_out
=
[]
existent_nodes
=
[
nd
for
nd
in
local_fgraph
.
toposort
()
if
nd
not
in
to_remove
]
if
nd
not
in
to_remove
]
to_keep
=
[]
for
nd
in
existent_nodes
:
to_keep
+=
nd
.
inputs
...
...
@@ -270,8 +274,8 @@ class PushOutNonSeqScan(gof.Optimizer):
nw_outer
=
[]
nw_inner
=
[]
for
to_repl
,
repl_in
,
repl_out
in
zip
(
clean_to_replace
,
clean_replace_with_in
,
clean_replace_with_out
):
clean_replace_with_in
,
clean_replace_with_out
):
if
isinstance
(
repl_out
,
theano
.
Constant
):
repl_in
=
repl_out
.
clone
()
else
:
...
...
@@ -285,11 +289,15 @@ class PushOutNonSeqScan(gof.Optimizer):
op_ins
,
op_outs
=
scan_utils
.
reconstruct_graph
(
_op_ins
,
_op_outs
)
# Reconstruct node
nwScan
=
scan_op
.
Scan
(
op_ins
,
op_outs
,
op
.
info
)
nw_node
=
nwScan
.
make_node
(
*
(
node
.
inputs
+
nw_outer
))
# Do not call make_node for test_value
nw_node
=
nwScan
(
*
(
node
.
inputs
+
nw_outer
),
**
dict
(
return_list
=
True
))[
0
]
.
owner
fgraph
.
replace_all_validate_remove
(
zip
(
node
.
outputs
,
nw_node
.
outputs
),
remove
=
[
node
],
reason
=
'scan
_push_computation_out
'
)
reason
=
'scan
Op_pushout_nonseqs_ops
'
)
return
True
elif
to_keep
==
[]:
# Nothing in the inner graph should be kept
...
...
@@ -310,7 +318,7 @@ class PushOutNonSeqScan(gof.Optimizer):
fgraph
.
replace_all_validate_remove
(
replace_with
.
items
(),
remove
=
[
node
],
reason
=
'scan
_push_computation_out
'
)
reason
=
'scan
Op_pushout_nonseqs_ops
'
)
else
:
return
False
...
...
@@ -327,8 +335,8 @@ class PushOutSeqScan(gof.Optimizer):
fgraph
.
attach_feature
(
gof
.
toolbox
.
ReplaceValidate
())
def
apply
(
self
,
fgraph
):
nodelist
=
[
x
for
x
in
fgraph
.
toposort
()
if
isinstance
(
x
.
op
,
scan_op
.
Scan
)]
nodelist
=
[
x
for
x
in
fgraph
.
toposort
()
if
isinstance
(
x
.
op
,
scan_op
.
Scan
)]
for
node
in
nodelist
:
self
.
process_node
(
fgraph
,
node
)
...
...
@@ -336,7 +344,7 @@ class PushOutSeqScan(gof.Optimizer):
# this flag tells if there was any change during the last iterations
changed
=
True
clean_inputs
,
clean_outputs
=
scan_utils
.
reconstruct_graph
(
node
.
op
.
inputs
,
node
.
op
.
outputs
)
node
.
op
.
inputs
,
node
.
op
.
outputs
)
local_fgraph
=
gof
.
FunctionGraph
(
clean_inputs
,
clean_outputs
)
max_iterations
=
2
*
len
(
local_fgraph
.
toposort
())
+
3
...
...
@@ -361,12 +369,12 @@ class PushOutSeqScan(gof.Optimizer):
for
nd
in
local_fgraph
.
toposort
():
if
(
isinstance
(
nd
.
op
,
theano
.
tensor
.
Elemwise
)
and
numpy
.
all
([(
x
in
inner_non_seqs
)
or
(
x
.
owner
in
to_remove
)
or
isinstance
(
x
,
tensor
.
Constant
)
or
(
x
in
inner_seqs
)
for
x
in
nd
.
inputs
])
and
not
nd
in
to_remove
):
numpy
.
all
([(
x
in
inner_non_seqs
)
or
(
x
.
owner
in
to_remove
)
or
isinstance
(
x
,
tensor
.
Constant
)
or
(
x
in
inner_seqs
)
for
x
in
nd
.
inputs
])
and
not
nd
in
to_remove
):
to_remove
.
append
(
nd
)
outside_ins
=
[]
for
x
in
nd
.
inputs
:
...
...
@@ -376,18 +384,21 @@ class PushOutSeqScan(gof.Optimizer):
elif
x
in
inner_seqs
:
outside_ins
+=
[
outer_seqs
[
inner_seqs
.
index
(
x
)]]
elif
x
in
to_replace
:
outside_ins
+=
[
replace_with_out
[
\
to_replace
.
index
(
x
)]]
outside_ins
+=
[
replace_with_out
[
to_replace
.
index
(
x
)]]
elif
isinstance
(
x
,
theano
.
Constant
):
outside_ins
+=
[
x
.
clone
()]
else
:
raise
Exception
(
(
'Error in the `scan_pushout_
non_
seq_'
(
'Error in the `scan_pushout_seq_'
'operations`. The optimization tries '
'to move some computation fron scan '
'which is not allowed to move. Report '
'this on theano-users list'
),
x
)
nw_outer_node
=
nd
.
op
.
make_node
(
*
outside_ins
)
# Do not call make_node for test_value
nw_outer_node
=
nd
.
op
(
*
outside_ins
,
**
dict
(
return_list
=
True
))[
0
]
.
owner
# Step 2. Create variables for replacements
for
idx
,
y
in
enumerate
(
nd
.
outputs
):
...
...
@@ -420,10 +431,15 @@ class PushOutSeqScan(gof.Optimizer):
to_replace
+=
[
y
]
replace_with_in
+=
[
y_place_holder
]
replace_with_out
+=
[
new_outer
]
if
hasattr
(
new_outer
.
tag
,
"test_value"
):
new_sh
=
new_outer
.
tag
.
test_value
.
shape
ref_sh
=
(
outside_ins
.
tag
.
test_value
.
shape
[
0
],)
ref_sh
+=
nd
.
outputs
[
0
]
.
tag
.
test_value
.
shape
assert
new_sh
==
ref_sh
changed
=
True
if
counts
>=
max_iterations
:
raise
Exception
(
'Error in the `scan_pushout_
non_
seq_operations`.'
raise
Exception
(
'Error in the `scan_pushout_seq_operations`.'
' The optimization exhausted the maximal number '
'of iterations allowed!'
)
# We need to check all candidate replacements and choose those that
...
...
@@ -436,7 +452,7 @@ class PushOutSeqScan(gof.Optimizer):
clean_replace_with_out
=
[]
existent_nodes
=
[
nd
for
nd
in
local_fgraph
.
toposort
()
if
nd
not
in
to_remove
]
if
nd
not
in
to_remove
]
to_keep
=
[]
for
nd
in
existent_nodes
:
to_keep
+=
nd
.
inputs
...
...
@@ -456,8 +472,8 @@ class PushOutSeqScan(gof.Optimizer):
nw_outer
=
[]
nw_inner
=
[]
for
to_repl
,
repl_in
,
repl_out
in
zip
(
clean_to_replace
,
clean_replace_with_in
,
clean_replace_with_out
):
clean_replace_with_in
,
clean_replace_with_out
):
if
isinstance
(
repl_out
,
theano
.
Constant
):
repl_in
=
repl_out
.
clone
()
else
:
...
...
@@ -473,12 +489,14 @@ class PushOutSeqScan(gof.Optimizer):
nw_info
=
op
.
info
.
copy
()
nw_info
[
'n_seqs'
]
+=
len
(
nw_inner
)
nwScan
=
scan_op
.
Scan
(
op_ins
,
op_outs
,
nw_info
)
nw_node
=
nwScan
.
make_node
(
*
(
node
.
inputs
[:
1
]
+
nw_outer
+
node
.
inputs
[
1
:]))
# Do not call make_node for test_value
nw_node
=
nwScan
(
*
(
node
.
inputs
[:
1
]
+
nw_outer
+
node
.
inputs
[
1
:]),
**
dict
(
return_list
=
True
))[
0
]
.
owner
fgraph
.
replace_all_validate_remove
(
zip
(
node
.
outputs
,
nw_node
.
outputs
),
remove
=
[
node
],
reason
=
'scan
_push_computation_out
'
)
reason
=
'scan
Op_pushout_seqs_ops
'
)
return
True
elif
(
to_keep
==
[]
and
not
op
.
as_while
and
...
...
@@ -510,8 +528,8 @@ class PushOutSeqScan(gof.Optimizer):
fgraph
.
replace_all_validate_remove
(
replace_with
.
items
(),
remove
=
[
node
],
reason
=
'scan
_push_seq_computation_out
'
)
reason
=
'scan
Op_pushout_seqs_ops
'
)
return
True
else
:
return
False
...
...
@@ -532,7 +550,7 @@ class ScanInplaceOptimizer(Optimizer):
nodes
=
fgraph
.
toposort
()
scan_nodes
=
[
x
for
x
in
nodes
if
(
isinstance
(
x
.
op
,
scan_op
.
Scan
)
and
x
.
op
.
info
[
'gpu'
]
==
self
.
gpu_flag
)]
x
.
op
.
info
[
'gpu'
]
==
self
.
gpu_flag
)]
for
scan_idx
in
xrange
(
len
(
scan_nodes
)):
node
=
scan_nodes
[
scan_idx
]
op
=
node
.
op
...
...
@@ -563,12 +581,13 @@ class ScanInplaceOptimizer(Optimizer):
info
,
typeConstructor
=
self
.
typeConstructor
)
new_outs
=
new_op
.
make_node
(
*
inputs
)
.
outputs
# Do not call make_node for test_value
new_outs
=
new_op
(
*
inputs
,
**
dict
(
return_list
=
True
))
try
:
fgraph
.
replace_all_validate_remove
(
zip
(
node
.
outputs
,
new_outs
),
remove
=
[
node
],
reason
=
self
.
__class__
.
__name__
)
reason
=
'scanOp_make_inplace'
)
op
=
new_op
node
=
new_outs
[
0
]
.
owner
except
InconsistencyError
,
e
:
...
...
@@ -720,7 +739,7 @@ class ScanSaveMem(gof.Optimizer):
except
KeyError
:
length
=
out
.
shape
[
0
]
cf_slice
=
tensor
.
get_canonical_form_slice
(
this_slice
[
0
],
length
)
this_slice
[
0
],
length
)
slices
[
i
]
+=
[(
cf_slice
,
this_slice
)]
if
(
isinstance
(
this_slice
[
0
],
slice
)
and
...
...
@@ -847,9 +866,8 @@ class ScanSaveMem(gof.Optimizer):
nw_inputs
[
0
]
=
nw_steps
# 3.2 check orphane outputs to see if we can eliminate any
required
,
not_required
=
\
scan_utils
.
scan_can_remove_outs
(
node
.
op
,
orphane_outs
)
required
,
not_required
=
scan_utils
.
scan_can_remove_outs
(
node
.
op
,
orphane_outs
)
# 3.3. compose replace pairs for those nodes that need not
# to store everything in memory ( or ar orphane and required
# by the inner function .. )
...
...
@@ -947,9 +965,10 @@ class ScanSaveMem(gof.Optimizer):
# I need to make sure I'm not reapplying the same optimization
# twice since bad things usually happen if I do that
info
[
'_scan_savemem_visited'
]
=
True
new_outs
=
scan_op
.
Scan
(
inps
,
outs
,
info
)
.
make_node
(
*
node_ins
)
.
outputs
# Do not call make_node for test_value
new_outs
=
scan_op
.
Scan
(
inps
,
outs
,
info
)(
*
node_ins
,
**
dict
(
return_list
=
True
))
old_new
=
[]
# 3.7 Get replace pairs for those outputs that do not change
...
...
@@ -978,9 +997,8 @@ class ScanSaveMem(gof.Optimizer):
sl_ins
=
tensor
.
Subtensor
.
collapse
(
nw_slice
,
lambda
entry
:
isinstance
(
entry
,
tensor
.
Variable
))
new_o
=
subtens
.
make_node
(
new_outs
[
nw_pos
],
*
sl_ins
)
.
outputs
[
0
]
tensor
.
Variable
))
new_o
=
subtens
(
new_outs
[
nw_pos
],
*
sl_ins
)
if
new_o
.
ndim
>
0
:
new_o
=
new_o
[::
cnf_slice
[
1
]]
replaced_outs
.
append
(
idx
)
...
...
@@ -1009,18 +1027,16 @@ class ScanSaveMem(gof.Optimizer):
else
:
position
=
(
cnf_slice
[
0
]
-
nw_steps
-
init_l
[
pos
]
+
store_steps
[
pos
])
nw_slice
=
(
sanitize
(
position
),)
+
\
tuple
(
old_slices
[
1
:])
init_l
[
pos
]
+
store_steps
[
pos
])
nw_slice
=
(
sanitize
(
position
),)
+
tuple
(
old_slices
[
1
:])
subtens
=
tensor
.
Subtensor
(
nw_slice
)
sl_ins
=
tensor
.
Subtensor
.
collapse
(
nw_slice
,
lambda
entry
:
isinstance
(
entry
,
tensor
.
Variable
))
new_o
=
subtens
.
make_node
(
new_outs
[
nw_pos
],
*
sl_ins
)
.
outputs
[
0
]
new_o
=
subtens
(
new_outs
[
nw_pos
],
*
sl_ins
)
if
new_o
.
ndim
>
0
:
new_o
=
new_o
[::
cnf_slice
[
1
]]
old_new
+=
[(
old
,
new_o
)]
...
...
@@ -1042,12 +1058,12 @@ class ScanSaveMem(gof.Optimizer):
remove
.
append
(
node
)
fgraph
.
replace_all_validate_remove
(
old_new
,
remove
,
reason
=
'scan_save_mem'
)
reason
=
'scan
Op
_save_mem'
)
def
apply
(
self
,
fgraph
):
nodelist
=
[
x
for
x
in
fgraph
.
toposort
()
if
isinstance
(
x
.
op
,
scan_op
.
Scan
)]
scan_op
.
Scan
)]
for
node
in
nodelist
:
if
not
hasattr
(
node
.
op
,
'_scan_savemem_visited'
):
self
.
process_node
(
fgraph
,
node
)
...
...
@@ -1230,7 +1246,7 @@ class ScanMerge(gof.Optimizer):
proposal
=
self
.
merge
(
subset
)
fgraph
.
replace_all_validate_remove
(
proposal
,
remove
=
subset
,
reason
=
'scan_merge'
)
reason
=
'scan
Op
_merge'
)
def
has_duplicates
(
l
):
...
...
@@ -1389,13 +1405,13 @@ def scan_merge_inouts(node):
# items scan is supposed to store for this nit_sot sequence
shapes
.
append
(
x
)
tmp
=
[
map_nitsot_out
(
i
,
o
,
sh
,
seen
)
for
i
,
o
,
sh
in
zip
(
na
.
inner_out_nit_sot
,
na
.
outer_out_nit_sot
,
shapes
)]
for
i
,
o
,
sh
in
zip
(
na
.
inner_out_nit_sot
,
na
.
outer_out_nit_sot
,
shapes
)]
na
.
outer_out_nit_sot
=
[
map_nitsot_out
(
i
,
o
,
sh
,
seen
)
for
i
,
o
,
sh
in
zip
(
na
.
inner_out_nit_sot
,
na
.
outer_out_nit_sot
,
shapes
)]
na
.
outer_out_nit_sot
,
shapes
)]
seen
=
[]
na
.
outer_out_sit_sot
=
[
map_out
(
i
,
o
,
seen
)
...
...
@@ -1592,10 +1608,8 @@ class PushOutDot1(gof.Optimizer):
old
=
node
.
outputs
[
pos
]
.
clients
[
0
][
0
]
.
outputs
[
0
]
old_new
.
append
((
old
,
new_out
))
old_new
+=
zip
(
node
.
outputs
[
pos
+
1
:],
new_outs
[
pos
:])
fgraph
.
replace_all_validate_remove
(
old_new
,
remove
=
[
node
],
reason
=
'PushOutDot1'
)
fgraph
.
replace_all_validate_remove
(
old_new
,
remove
=
[
node
],
reason
=
'scan_pushout_dot1'
)
# I've added an equilibrium because later scan optimization in the sequence
...
...
@@ -1612,7 +1626,7 @@ optdb.register('scan_eqopt1', scan_eqopt1, .1, 'fast_run', 'scan')
optdb
.
register
(
'scan_eqopt2'
,
scan_eqopt2
,
1.6
,
'fast_run'
,
'scan'
)
optdb
.
register
(
'scanOp_make_inplace'
,
ScanInplaceOptimizer
(
typeConstructor
=
None
,
gpu_flag
=
False
),
gpu_flag
=
False
),
75
,
'fast_run'
,
'inplace'
,
...
...
@@ -1628,6 +1642,7 @@ scan_seqopt1.register('scanOp_remove_constants_and_unused_inputs0',
opt
.
in2out
(
remove_constants_and_unused_inputs_scan
,
ignore_newtrees
=
True
),
1
,
'remove_constants_and_unused_inputs_scan'
,
'fast_run'
,
'scan'
)
...
...
@@ -1662,10 +1677,11 @@ scan_seqopt2.register('constant_folding_for_scan2',
'scan'
)
scan_seqopt2
.
register
(
'scanOp_remove_constants_and_unused_inputs
0
'
,
scan_seqopt2
.
register
(
'scanOp_remove_constants_and_unused_inputs
1
'
,
opt
.
in2out
(
remove_constants_and_unused_inputs_scan
,
ignore_newtrees
=
True
),
2
,
'remove_constants_and_unused_inputs_scan'
,
'fast_run'
,
'scan'
)
...
...
@@ -1684,12 +1700,14 @@ scan_seqopt2.register('scanop_remove_constants_and_unused_inputs2',
opt
.
in2out
(
remove_constants_and_unused_inputs_scan
,
ignore_newtrees
=
True
),
5
,
'remove_constants_and_unused_inputs_scan'
,
'fast_run'
,
'scan'
)
scan_seqopt2
.
register
(
'scanOp_merge_inouts'
,
opt
.
in2out
(
scan_merge_inouts
,
ignore_newtrees
=
True
),
6
,
'scan_merge_inouts'
,
'fast_run'
,
'scan'
)
...
...
@@ -1707,5 +1725,6 @@ scan_seqopt2.register('scanOp_remove_constants_and_unused_inputs3',
opt
.
in2out
(
remove_constants_and_unused_inputs_scan
,
ignore_newtrees
=
True
),
8
,
'remove_constants_and_unused_inputs_scan'
,
'fast_run'
,
'scan'
)
theano/scan_module/scan_utils.py
浏览文件 @
a351e3db
...
...
@@ -500,7 +500,7 @@ def infer_shape(outs, inputs, input_shapes):
# shape_feature.on_import does not actually use an fgraph
# It will call infer_shape and set_shape appropriately
dummy_fgraph
=
None
shape_feature
.
on_import
(
dummy_fgraph
,
out
.
owner
)
shape_feature
.
on_import
(
dummy_fgraph
,
out
.
owner
,
reason
=
"dummy"
)
ret
=
[]
for
o
in
outs
:
...
...
theano/scan_module/tests/test_scan.py
浏览文件 @
a351e3db
...
...
@@ -1141,7 +1141,7 @@ class T_Scan(unittest.TestCase):
go_backwards
=
False
)
gX
,
gY
=
tensor
.
grad
(
values
[
1
]
.
sum
(),
[
x
,
y
])
f
=
theano
.
function
([
c
,
x
,
y
],
[
gX
,
gY
],
allow_input_downcast
=
True
)
allow_input_downcast
=
True
)
# Check for runtime errors
f
(
numpy
.
int32
(
0
),
numpy
.
float32
(
1.
),
numpy
.
float32
(
.
5
))
...
...
@@ -1545,6 +1545,12 @@ class T_Scan(unittest.TestCase):
x0
=
theano
.
tensor
.
vector
(
'x0'
)
y0
=
theano
.
tensor
.
vector
(
'y0'
)
W_in1
.
tag
.
test_value
=
vW_in1
u1
.
tag
.
test_value
=
v_u1
u2
.
tag
.
test_value
=
v_u2
x0
.
tag
.
test_value
=
v_x0
y0
.
tag
.
test_value
=
v_y0
def
f_rnn_cmpl
(
u1_t
,
u2_tm1
,
u2_t
,
...
...
@@ -1553,33 +1559,46 @@ class T_Scan(unittest.TestCase):
y_tm1
,
y_tm3
,
W_in1
):
return
[
theano
.
dot
(
u1_t
,
W_in1
)
+
\
(
u2_t
+
u2_tm1
*
u2_tp1
)
*
W_in2
+
\
theano
.
dot
(
x_tm1
,
W
),
return
[
theano
.
dot
(
u1_t
,
W_in1
)
+
(
u2_t
+
u2_tm1
*
u2_tp1
)
*
W_in2
+
theano
.
dot
(
x_tm1
,
W
),
(
y_tm1
+
y_tm3
)
*
theano
.
dot
(
x_tm1
,
W_out
),
theano
.
dot
(
u1_t
,
W_in1
)]
cost
,
updates
=
scan_project_sum
(
f_rnn_cmpl
,
[
u1
,
dict
(
input
=
u2
,
taps
=
[
-
1
,
0
,
1
])],
[
x0
,
dict
(
initial
=
y0
,
taps
=
[
-
1
,
-
3
]),
None
],
W_in1
,
n_steps
=
None
,
truncate_gradient
=-
1
,
go_backwards
=
False
)
vparams
=
[
v_u1
,
v_u2
,
v_x0
,
v_y0
,
vW_in1
]
params
=
[
u1
,
u2
,
x0
,
y0
,
W_in1
]
gparams
=
theano
.
tensor
.
grad
(
cost
,
params
)
grad_fn
=
theano
.
function
([
u1
,
u2
,
x0
,
y0
,
W_in1
],
gparams
,
updates
=
updates
,
no_default_updates
=
True
,
allow_input_downcast
=
True
)
cost_fn
=
theano
.
function
([
u1
,
u2
,
x0
,
y0
,
W_in1
],
cost
,
updates
=
updates
,
no_default_updates
=
True
,
allow_input_downcast
=
True
)
# We change the compute_test_value[_opt] flag to run the
# assert in Scan.grad() of the new scan input sequence related
# to outer_mitsot_outs, outer_sitsot_outs and
# outer_nitsot_outs. This allow to test an old Scan bug.
old1
=
theano
.
config
.
compute_test_value
old2
=
theano
.
config
.
compute_test_value_opt
theano
.
config
.
compute_test_value
=
'raise'
theano
.
config
.
compute_test_value_opt
=
'raise'
try
:
cost
,
updates
=
scan_project_sum
(
f_rnn_cmpl
,
[
u1
,
dict
(
input
=
u2
,
taps
=
[
-
1
,
0
,
1
])],
[
x0
,
dict
(
initial
=
y0
,
taps
=
[
-
1
,
-
3
]),
None
],
W_in1
,
n_steps
=
None
,
truncate_gradient
=-
1
,
go_backwards
=
False
)
vparams
=
[
v_u1
,
v_u2
,
v_x0
,
v_y0
,
vW_in1
]
params
=
[
u1
,
u2
,
x0
,
y0
,
W_in1
]
gparams
=
theano
.
tensor
.
grad
(
cost
,
params
)
grad_fn
=
theano
.
function
([
u1
,
u2
,
x0
,
y0
,
W_in1
],
gparams
,
updates
=
updates
,
no_default_updates
=
True
,
allow_input_downcast
=
True
)
cost_fn
=
theano
.
function
([
u1
,
u2
,
x0
,
y0
,
W_in1
],
cost
,
updates
=
updates
,
no_default_updates
=
True
,
allow_input_downcast
=
True
)
finally
:
theano
.
config
.
compute_test_value
=
old1
theano
.
config
.
compute_test_value_opt
=
old2
num_grad
=
multiple_outputs_numeric_grad
(
cost_fn
,
[
v_u1
,
...
...
theano/tensor/basic.py
浏览文件 @
a351e3db
...
...
@@ -2543,7 +2543,7 @@ class Alloc(gof.Op):
#change.
return
[
gx
]
+
[
DisconnectedType
()()
for
i
in
inputs
[
1
:]]
def
__call__
(
self
,
val
,
*
shapes
):
def
__call__
(
self
,
val
,
*
shapes
,
**
kwargs
):
"""
If the alloc would be useless, this function returns val.
...
...
@@ -2554,7 +2554,7 @@ class Alloc(gof.Op):
If you always want an Alloc node, call make_node.
"""
ret
=
super
(
Alloc
,
self
)
.
__call__
(
val
,
*
shapes
)
ret
=
super
(
Alloc
,
self
)
.
__call__
(
val
,
*
shapes
,
**
kwargs
)
try
:
# It makes optimization difficult when useless allocs are thrown
# into the graph at every stage of optimization. This little logic
...
...
theano/tensor/opt.py
浏览文件 @
a351e3db
...
...
@@ -49,14 +49,24 @@ theano.configparser.AddConfigVar('on_shape_error',
def
out2in
(
*
local_opts
):
"""WRITEME """
return
opt
.
TopoOptimizer
(
opt
.
LocalOptGroup
(
*
local_opts
),
if
len
(
local_opts
)
>
1
:
# Don't wrap it uselessly if their is only 1 optimization.
local_opts
=
opt
.
LocalOptGroup
(
*
local_opts
),
else
:
local_opts
,
=
local_opts
return
opt
.
TopoOptimizer
(
local_opts
,
order
=
'out_to_in'
,
failure_callback
=
TopoOptimizer
.
warn_inplace
)
def
in2out
(
*
local_opts
,
**
kwargs
):
"""WRITEME """
return
opt
.
TopoOptimizer
(
opt
.
LocalOptGroup
(
*
local_opts
),
if
len
(
local_opts
)
>
1
:
# Don't wrap it uselessly if their is only 1 optimization.
local_opts
=
opt
.
LocalOptGroup
(
*
local_opts
),
else
:
local_opts
,
=
local_opts
return
opt
.
TopoOptimizer
(
local_opts
,
order
=
'in_to_out'
,
failure_callback
=
TopoOptimizer
.
warn_inplace
,
**
kwargs
)
...
...
@@ -384,10 +394,12 @@ def local_dimshuffle_lift(node):
input
=
node
.
inputs
[
0
]
inode
=
input
.
owner
if
inode
and
isinstance
(
inode
.
op
,
Elemwise
)
and
(
len
(
input
.
clients
)
==
1
):
return
inode
.
op
.
make_node
(
*
[
DimShuffle
(
input
.
type
.
broadcastable
,
op
.
new_order
,
op
.
inplace
)(
input
)
for
input
in
inode
.
inputs
])
.
outputs
# Don't use make_node to have tag.test_value set.
ret
=
inode
.
op
(
*
[
DimShuffle
(
input
.
type
.
broadcastable
,
op
.
new_order
,
op
.
inplace
)(
input
)
for
input
in
inode
.
inputs
],
**
dict
(
return_list
=
True
))
return
ret
if
inode
and
isinstance
(
inode
.
op
,
DimShuffle
):
new_order
=
[
x
==
'x'
and
'x'
or
inode
.
op
.
new_order
[
x
]
for
x
in
op
.
new_order
]
...
...
@@ -397,8 +409,9 @@ def local_dimshuffle_lift(node):
iinput
.
type
.
ndim
):
return
[
iinput
]
else
:
return
DimShuffle
(
iinput
.
type
.
broadcastable
,
new_order
,
inplace
)
.
make_node
(
iinput
)
.
outputs
ret
=
DimShuffle
(
iinput
.
type
.
broadcastable
,
new_order
,
inplace
)(
iinput
,
**
dict
(
return_list
=
True
))
return
ret
@register_canonicalize
...
...
@@ -437,8 +450,10 @@ def dimshuffle_as_view(node):
#Step 60 is the inplace optimization stage.
compile
.
optdb
.
register
(
'dimshuffle_as_view'
,
TopoOptimizer
(
dimshuffle_as_view
,
failure_callback
=
TopoOptimizer
.
warn_inplace
),
60
,
TopoOptimizer
(
dimshuffle_as_view
,
failure_callback
=
TopoOptimizer
.
warn_inplace
),
60
,
'fast_run'
,
'inplace'
)
register_canonicalize
(
local_dimshuffle_lift
)
register_specialize
(
local_dimshuffle_lift
)
...
...
@@ -771,7 +786,8 @@ class ShapeFeature(object):
if
hasattr
(
r
.
type
,
"broadcastable"
)
and
r
.
type
.
broadcastable
[
i
]:
return
self
.
lscalar_one
else
:
return
Shape_i
(
i
)
.
make_node
(
r
)
.
outputs
[
0
]
# Do not call make_node for test_value
return
Shape_i
(
i
)(
r
)
def
shape_tuple
(
self
,
r
):
"""Return a tuple of symbolic shape vars for tensor variable r"""
...
...
@@ -970,9 +986,9 @@ class ShapeFeature(object):
# shape var -> graph v
for
node
in
fgraph
.
toposort
():
self
.
on_import
(
fgraph
,
node
)
self
.
on_import
(
fgraph
,
node
,
reason
=
'on_attach'
)
def
on_import
(
self
,
fgraph
,
node
):
def
on_import
(
self
,
fgraph
,
node
,
reason
):
if
node
.
outputs
[
0
]
in
self
.
shape_of
:
# this is a revert, not really an import
for
r
in
node
.
outputs
+
node
.
inputs
:
...
...
@@ -1933,7 +1949,8 @@ def local_subtensor_merge(node):
sl_ins
=
Subtensor
.
collapse
(
merged_slices
,
lambda
x
:
isinstance
(
x
,
T
.
Variable
))
out
=
subtens
.
make_node
(
x
,
*
sl_ins
)
.
outputs
[
0
]
# Do not call make_node for test_value
out
=
subtens
(
x
,
*
sl_ins
)
return
[
out
]
...
...
@@ -4583,8 +4600,12 @@ def local_elemwise_fusion_op(OP, max_input_fct=lambda node: 1024):
elif
ii
in
tmp_input
:
tmp_s_input
.
append
(
tmp_scalar
[
tmp_input
.
index
(
ii
)])
else
:
tmp_s_input
.
append
(
scalar
.
Scalar
(
ii
.
dtype
)
.
make_variable
())
tmp
=
scalar
.
Scalar
(
ii
.
dtype
)
.
make_variable
()
try
:
tmp
.
tag
.
test_value
=
gof
.
op
.
get_test_value
(
ii
)
.
flatten
()[
0
]
except
AttributeError
:
pass
tmp_s_input
.
append
(
tmp
)
tmp_input
.
append
(
ii
)
tmp_scalar
.
append
(
tmp_s_input
[
-
1
])
s_op
=
i
.
owner
.
op
.
scalar_op
(
*
tmp_s_input
)
...
...
@@ -4634,6 +4655,13 @@ def local_elemwise_fusion_op(OP, max_input_fct=lambda node: 1024):
s
=
s_inputs
[
inputs
.
index
(
i
)]
else
:
s
=
scalar
.
Scalar
(
i
.
dtype
)
.
make_variable
()
try
:
v
=
gof
.
op
.
get_test_value
(
i
)
if
v
.
size
>
0
:
s
.
tag
.
test_value
=
gof
.
op
.
get_test_value
(
i
)
.
flatten
()[
0
]
except
AttributeError
:
pass
inputs
.
append
(
i
)
s_inputs
.
append
(
s
)
s_g
.
append
(
s
)
...
...
@@ -4667,7 +4695,8 @@ your code will run correctly, but may be slower.""")
C
=
scalar
.
Composite
(
s_inputs
,
[
s_new_out
])
#create the new node.
n
=
OP
(
C
)
.
make_node
(
*
inputs
)
#Do not call make_node to have test_value
n
=
OP
(
C
)(
*
inputs
)
.
owner
assert
len
(
n
.
outputs
)
==
1
assert
node
.
outputs
[
0
]
.
dtype
==
n
.
outputs
[
0
]
.
dtype
...
...
@@ -4728,9 +4757,11 @@ if config.tensor.local_elemwise_fusion:
_logger
.
debug
(
"enabling optimization fusion elemwise in fast_run"
)
compile
.
optdb
.
register
(
'elemwise_fusion'
,
FusionOptimizer
(
local_elemwise_fusion
),
71.00
,
'fast_run'
,
'fusion'
,
'local_elemwise_fusion'
)
'fast_run'
,
'fusion'
,
'local_elemwise_fusion'
,
'FusionOptimizer'
)
else
:
_logger
.
debug
(
"not enabling optimization fusion elemwise in fast_run"
)
compile
.
optdb
.
register
(
'elemwise_fusion'
,
FusionOptimizer
(
local_elemwise_fusion
),
71.00
,
'fusion'
,
'local_elemwise_fusion'
)
'fusion'
,
'local_elemwise_fusion'
,
'FusionOptimizer'
)
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论