Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
a351e3db
提交
a351e3db
authored
9月 03, 2013
作者:
Pascal Lamblin
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #1482 from nouiz/rnade
Scan crash fix
上级
e523802f
66a9e220
隐藏空白字符变更
内嵌
并排
正在显示
20 个修改的文件
包含
403 行增加
和
239 行删除
+403
-239
config.txt
doc/library/config.txt
+14
-0
debugmode.py
theano/compile/debugmode.py
+9
-5
function_module.py
theano/compile/function_module.py
+1
-1
configdefaults.py
theano/configdefaults.py
+13
-1
destroyhandler.py
theano/gof/destroyhandler.py
+4
-4
fg.py
theano/gof/fg.py
+94
-79
op.py
theano/gof/op.py
+9
-2
opt.py
theano/gof/opt.py
+8
-8
toolbox.py
theano/gof/toolbox.py
+18
-14
basic_ops.py
theano/sandbox/cuda/basic_ops.py
+0
-1
ops.py
theano/sandbox/linalg/ops.py
+2
-2
scan_utils.py
theano/sandbox/scan_module/scan_utils.py
+1
-1
basic.py
theano/scalar/basic.py
+18
-0
scan.py
theano/scan_module/scan.py
+3
-3
scan_op.py
theano/scan_module/scan_op.py
+24
-2
scan_opt.py
theano/scan_module/scan_opt.py
+88
-69
scan_utils.py
theano/scan_module/scan_utils.py
+1
-1
test_scan.py
theano/scan_module/tests/test_scan.py
+44
-25
basic.py
theano/tensor/basic.py
+2
-2
opt.py
theano/tensor/opt.py
+50
-19
没有找到文件。
doc/library/config.txt
浏览文件 @
a351e3db
...
@@ -482,6 +482,14 @@ import theano and print the config variable, as in:
...
@@ -482,6 +482,14 @@ import theano and print the config variable, as in:
This flag's value cannot be modified during the program execution.
This flag's value cannot be modified during the program execution.
.. attribute:: optimizer_verbose
Bool value: either True or False
Default: False
When True, we print on the stdout the optimization applied.
.. attribute:: nocleanup
.. attribute:: nocleanup
Bool value: either True or False
Bool value: either True or False
...
@@ -630,6 +638,12 @@ import theano and print the config variable, as in:
...
@@ -630,6 +638,12 @@ import theano and print the config variable, as in:
this Op
this Op
- ``'raise'`` will raise an Exception
- ``'raise'`` will raise an Exception
.. attribute:: config.compute_test_value_opt
As ``compute_test_value``, but it is the value used during Theano
optimization phase. Theano user's do not need to use this. This is
to help debug shape error in Theano optimization.
.. attribute:: config.exception_verbosity
.. attribute:: config.exception_verbosity
String Value: ``'low'``, ``'high'``.
String Value: ``'low'``, ``'high'``.
...
...
theano/compile/debugmode.py
浏览文件 @
a351e3db
...
@@ -1428,21 +1428,25 @@ class _VariableEquivalenceTracker(object):
...
@@ -1428,21 +1428,25 @@ class _VariableEquivalenceTracker(object):
self
.
reasons
=
{}
self
.
reasons
=
{}
self
.
replaced_by
=
{}
self
.
replaced_by
=
{}
self
.
event_list
=
[]
self
.
event_list
=
[]
for
node
in
fgraph
.
toposort
():
self
.
on_import
(
fgraph
,
node
,
"on_attach"
)
def
on_detach
(
self
,
fgraph
):
def
on_detach
(
self
,
fgraph
):
assert
fgraph
is
self
.
fgraph
assert
fgraph
is
self
.
fgraph
self
.
fgraph
=
None
self
.
fgraph
=
None
def
on_prune
(
self
,
fgraph
,
node
):
def
on_prune
(
self
,
fgraph
,
node
,
reason
):
self
.
event_list
.
append
(
_FunctionGraphEvent
(
'prune'
,
node
))
self
.
event_list
.
append
(
_FunctionGraphEvent
(
'prune'
,
node
,
reason
=
reason
))
#print 'PRUNING NODE', node, id(node)
#print 'PRUNING NODE', node, id(node)
assert
node
in
self
.
active_nodes
assert
node
in
self
.
active_nodes
assert
node
not
in
self
.
inactive_nodes
assert
node
not
in
self
.
inactive_nodes
self
.
active_nodes
.
remove
(
node
)
self
.
active_nodes
.
remove
(
node
)
self
.
inactive_nodes
.
add
(
node
)
self
.
inactive_nodes
.
add
(
node
)
def
on_import
(
self
,
fgraph
,
node
):
def
on_import
(
self
,
fgraph
,
node
,
reason
):
self
.
event_list
.
append
(
_FunctionGraphEvent
(
'import'
,
node
))
self
.
event_list
.
append
(
_FunctionGraphEvent
(
'import'
,
node
,
reason
=
reason
))
#print 'NEW NODE', node, id(node)
#print 'NEW NODE', node, id(node)
assert
node
not
in
self
.
active_nodes
assert
node
not
in
self
.
active_nodes
...
@@ -2114,7 +2118,7 @@ class _Maker(FunctionMaker): # inheritance buys a few helper functions
...
@@ -2114,7 +2118,7 @@ class _Maker(FunctionMaker): # inheritance buys a few helper functions
# optimize the fgraph
# optimize the fgraph
compute_test_value_orig
=
theano
.
config
.
compute_test_value
compute_test_value_orig
=
theano
.
config
.
compute_test_value
try
:
try
:
theano
.
config
.
compute_test_value
=
"off"
theano
.
config
.
compute_test_value
=
theano
.
config
.
compute_test_value_opt
optimizer
(
fgraph
)
optimizer
(
fgraph
)
theano
.
compile
.
function_module
.
insert_deepcopy
(
fgraph
,
inputs
,
theano
.
compile
.
function_module
.
insert_deepcopy
(
fgraph
,
inputs
,
...
...
theano/compile/function_module.py
浏览文件 @
a351e3db
...
@@ -1018,7 +1018,7 @@ class FunctionMaker(object):
...
@@ -1018,7 +1018,7 @@ class FunctionMaker(object):
compute_test_value_orig
=
theano
.
config
.
compute_test_value
compute_test_value_orig
=
theano
.
config
.
compute_test_value
add_stack_trace_on_call
=
gof
.
Op
.
add_stack_trace_on_call
add_stack_trace_on_call
=
gof
.
Op
.
add_stack_trace_on_call
try
:
try
:
theano
.
config
.
compute_test_value
=
"off"
theano
.
config
.
compute_test_value
=
theano
.
config
.
compute_test_value_opt
gof
.
Op
.
add_stack_trace_on_call
=
False
gof
.
Op
.
add_stack_trace_on_call
=
False
start_optimizer
=
time
.
time
()
start_optimizer
=
time
.
time
()
optimizer_profile
=
optimizer
(
fgraph
)
optimizer_profile
=
optimizer
(
fgraph
)
...
...
theano/configdefaults.py
浏览文件 @
a351e3db
...
@@ -157,6 +157,11 @@ AddConfigVar('optimizer',
...
@@ -157,6 +157,11 @@ AddConfigVar('optimizer',
EnumStr
(
'fast_run'
,
'merge'
,
'fast_compile'
,
'None'
),
EnumStr
(
'fast_run'
,
'merge'
,
'fast_compile'
,
'None'
),
in_c_key
=
False
)
in_c_key
=
False
)
AddConfigVar
(
'optimizer_verbose'
,
"If True, we print all optimization being applied"
,
BoolParam
(
False
),
in_c_key
=
False
)
AddConfigVar
(
'on_opt_error'
,
AddConfigVar
(
'on_opt_error'
,
(
"What to do when an optimization crashes: warn and skip it, raise "
(
"What to do when an optimization crashes: warn and skip it, raise "
"the exception, or fall into the pdb debugger."
),
"the exception, or fall into the pdb debugger."
),
...
@@ -379,10 +384,17 @@ AddConfigVar('compute_test_value',
...
@@ -379,10 +384,17 @@ AddConfigVar('compute_test_value',
"Constants, SharedVariables and the tag 'test_value' as inputs "
"Constants, SharedVariables and the tag 'test_value' as inputs "
"to the function. This helps the user track down problems in the "
"to the function. This helps the user track down problems in the "
"graph before it gets optimized."
),
"graph before it gets optimized."
),
EnumStr
(
'off'
,
'ignore'
,
'warn'
,
'raise'
),
EnumStr
(
'off'
,
'ignore'
,
'warn'
,
'raise'
,
'pdb'
),
in_c_key
=
False
)
in_c_key
=
False
)
AddConfigVar
(
'compute_test_value_opt'
,
(
"For debugging Theano optimization only."
" Same as compute_test_value, but is used"
" during Theano optimization"
),
EnumStr
(
'off'
,
'ignore'
,
'warn'
,
'raise'
,
'pdb'
),
in_c_key
=
False
)
"""Note to developers:
"""Note to developers:
Generally your exceptions should use an apply node's __str__
Generally your exceptions should use an apply node's __str__
method when exception_verbosity == 'low'. When exception_verbosity
method when exception_verbosity == 'low'. When exception_verbosity
...
...
theano/gof/destroyhandler.py
浏览文件 @
a351e3db
...
@@ -380,7 +380,7 @@ if 0:
...
@@ -380,7 +380,7 @@ if 0:
delattr
(
self
.
fgraph
,
'destroy_handler'
)
delattr
(
self
.
fgraph
,
'destroy_handler'
)
self
.
fgraph
=
None
self
.
fgraph
=
None
def
on_import
(
self
,
fgraph
,
app
):
def
on_import
(
self
,
fgraph
,
app
,
reason
):
"""Add Apply instance to set which must be computed"""
"""Add Apply instance to set which must be computed"""
#if app in self.debug_all_apps: raise ProtocolError("double import")
#if app in self.debug_all_apps: raise ProtocolError("double import")
...
@@ -410,7 +410,7 @@ if 0:
...
@@ -410,7 +410,7 @@ if 0:
self
.
stale_droot
=
True
self
.
stale_droot
=
True
def
on_prune
(
self
,
fgraph
,
app
):
def
on_prune
(
self
,
fgraph
,
app
,
reason
):
"""Remove Apply instance from set which must be computed"""
"""Remove Apply instance from set which must be computed"""
#if app not in self.debug_all_apps: raise ProtocolError("prune without import")
#if app not in self.debug_all_apps: raise ProtocolError("prune without import")
#self.debug_all_apps.remove(app)
#self.debug_all_apps.remove(app)
...
@@ -765,7 +765,7 @@ class DestroyHandler(toolbox.Bookkeeper):
...
@@ -765,7 +765,7 @@ class DestroyHandler(toolbox.Bookkeeper):
delattr
(
self
.
fgraph
,
'destroy_handler'
)
delattr
(
self
.
fgraph
,
'destroy_handler'
)
self
.
fgraph
=
None
self
.
fgraph
=
None
def
on_import
(
self
,
fgraph
,
app
):
def
on_import
(
self
,
fgraph
,
app
,
reason
):
"""Add Apply instance to set which must be computed"""
"""Add Apply instance to set which must be computed"""
if
app
in
self
.
debug_all_apps
:
raise
ProtocolError
(
"double import"
)
if
app
in
self
.
debug_all_apps
:
raise
ProtocolError
(
"double import"
)
...
@@ -795,7 +795,7 @@ class DestroyHandler(toolbox.Bookkeeper):
...
@@ -795,7 +795,7 @@ class DestroyHandler(toolbox.Bookkeeper):
self
.
stale_droot
=
True
self
.
stale_droot
=
True
def
on_prune
(
self
,
fgraph
,
app
):
def
on_prune
(
self
,
fgraph
,
app
,
reason
):
"""Remove Apply instance from set which must be computed"""
"""Remove Apply instance from set which must be computed"""
if
app
not
in
self
.
debug_all_apps
:
raise
ProtocolError
(
"prune without import"
)
if
app
not
in
self
.
debug_all_apps
:
raise
ProtocolError
(
"prune without import"
)
self
.
debug_all_apps
.
remove
(
app
)
self
.
debug_all_apps
.
remove
(
app
)
...
...
theano/gof/fg.py
浏览文件 @
a351e3db
...
@@ -5,6 +5,7 @@ Contains the FunctionGraph class and exception
...
@@ -5,6 +5,7 @@ Contains the FunctionGraph class and exception
types that it can raise
types that it can raise
"""
"""
import
sys
import
sys
from
theano.gof
import
graph
from
theano.gof
import
graph
from
theano.gof
import
utils
from
theano.gof
import
utils
from
theano.gof
import
toolbox
from
theano.gof
import
toolbox
...
@@ -16,6 +17,7 @@ NullType = None
...
@@ -16,6 +17,7 @@ NullType = None
from
theano.gof.python25
import
OrderedDict
from
theano.gof.python25
import
OrderedDict
from
theano.misc.ordered_set
import
OrderedSet
from
theano.misc.ordered_set
import
OrderedSet
class
InconsistencyError
(
Exception
):
class
InconsistencyError
(
Exception
):
"""
"""
This exception should be thrown by listeners to FunctionGraph when the
This exception should be thrown by listeners to FunctionGraph when the
...
@@ -82,7 +84,8 @@ class FunctionGraph(utils.object2):
...
@@ -82,7 +84,8 @@ class FunctionGraph(utils.object2):
# so I probably am) this should be a set.
# so I probably am) this should be a set.
self
.
_features
=
[]
self
.
_features
=
[]
# All apply nodes in the subgraph defined by inputs and outputs are cached in this field
# All apply nodes in the subgraph defined by inputs and
# outputs are cached in this field
self
.
apply_nodes
=
set
()
self
.
apply_nodes
=
set
()
# Ditto for variable nodes
# Ditto for variable nodes
...
@@ -104,7 +107,7 @@ class FunctionGraph(utils.object2):
...
@@ -104,7 +107,7 @@ class FunctionGraph(utils.object2):
self
.
__setup_r__
(
input
)
self
.
__setup_r__
(
input
)
self
.
variables
.
add
(
input
)
self
.
variables
.
add
(
input
)
self
.
__import_r__
(
outputs
)
self
.
__import_r__
(
outputs
,
reason
=
"init"
)
for
i
,
output
in
enumerate
(
outputs
):
for
i
,
output
in
enumerate
(
outputs
):
output
.
clients
.
append
((
'output'
,
i
))
output
.
clients
.
append
((
'output'
,
i
))
...
@@ -112,12 +115,12 @@ class FunctionGraph(utils.object2):
...
@@ -112,12 +115,12 @@ class FunctionGraph(utils.object2):
self
.
variable_locks
=
{}
self
.
variable_locks
=
{}
self
.
profile
=
None
self
.
profile
=
None
### Setup a Variable ###
### Setup a Variable ###
def
__setup_r__
(
self
,
r
):
def
__setup_r__
(
self
,
r
):
# sets up r so it belongs to this fgraph
# sets up r so it belongs to this fgraph
if
hasattr
(
r
,
'fgraph'
)
and
r
.
fgraph
is
not
None
and
r
.
fgraph
is
not
self
:
if
(
hasattr
(
r
,
'fgraph'
)
and
r
.
fgraph
is
not
None
and
r
.
fgraph
is
not
self
):
raise
Exception
(
"
%
s is already owned by another fgraph"
%
r
)
raise
Exception
(
"
%
s is already owned by another fgraph"
%
r
)
r
.
fgraph
=
self
r
.
fgraph
=
self
r
.
clients
=
[]
r
.
clients
=
[]
...
@@ -165,13 +168,13 @@ class FunctionGraph(utils.object2):
...
@@ -165,13 +168,13 @@ class FunctionGraph(utils.object2):
self
.
inputs
=
None
self
.
inputs
=
None
self
.
outputs
=
None
self
.
outputs
=
None
### clients ###
### clients ###
def
clients
(
self
,
r
):
def
clients
(
self
,
r
):
"""
"""
Set of all the (node, i) pairs such that node.inputs[i] is r.
Set of all the (node, i) pairs such that node.inputs[i] is r.
Tell differently, a list of (node,i) such that each node have r as input at index i.
Tell differently, a list of (node,i) such that each node have
r as input at index i.
"""
"""
return
r
.
clients
return
r
.
clients
...
@@ -184,12 +187,15 @@ class FunctionGraph(utils.object2):
...
@@ -184,12 +187,15 @@ class FunctionGraph(utils.object2):
"""
"""
if
set
(
r
.
clients
)
.
intersection
(
set
(
new_clients
)):
if
set
(
r
.
clients
)
.
intersection
(
set
(
new_clients
)):
print
>>
sys
.
stderr
,
'ERROR: clients intersect!'
print
>>
sys
.
stderr
,
'ERROR: clients intersect!'
print
>>
sys
.
stderr
,
' RCLIENTS of'
,
r
,
[(
n
,
i
,
type
(
n
),
id
(
n
))
for
n
,
i
in
r
.
clients
]
print
>>
sys
.
stderr
,
' RCLIENTS of'
,
r
,
[(
n
,
i
,
type
(
n
),
id
(
n
))
print
>>
sys
.
stderr
,
' NCLIENTS of'
,
r
,
[(
n
,
i
,
type
(
n
),
id
(
n
))
for
n
,
i
in
new_clients
]
for
n
,
i
in
r
.
clients
]
print
>>
sys
.
stderr
,
' NCLIENTS of'
,
r
,
[(
n
,
i
,
type
(
n
),
id
(
n
))
for
n
,
i
in
new_clients
]
assert
not
set
(
r
.
clients
)
.
intersection
(
set
(
new_clients
))
assert
not
set
(
r
.
clients
)
.
intersection
(
set
(
new_clients
))
r
.
clients
+=
new_clients
r
.
clients
+=
new_clients
def
__remove_clients__
(
self
,
r
,
clients_to_remove
,
prune
=
True
):
def
__remove_clients__
(
self
,
r
,
clients_to_remove
,
prune
=
True
,
reason
=
None
):
""" WRITEME
""" WRITEME
r -> variable
r -> variable
clients_to_remove -> list of (op, i) pairs such that node.inputs[i] is not r anymore.
clients_to_remove -> list of (op, i) pairs such that node.inputs[i] is not r anymore.
...
@@ -202,18 +208,16 @@ class FunctionGraph(utils.object2):
...
@@ -202,18 +208,16 @@ class FunctionGraph(utils.object2):
print
>>
sys
.
stderr
,
'ERROR: DUPLICATE CLIENT ENTRY...'
print
>>
sys
.
stderr
,
'ERROR: DUPLICATE CLIENT ENTRY...'
print
>>
sys
.
stderr
,
' ENTRY'
,
repr
(
entry
),
type
(
entry
[
0
])
print
>>
sys
.
stderr
,
' ENTRY'
,
repr
(
entry
),
type
(
entry
[
0
])
print
>>
sys
.
stderr
,
' CLIENTS'
,
repr
(
r
.
clients
)
print
>>
sys
.
stderr
,
' CLIENTS'
,
repr
(
r
.
clients
)
assert
entry
not
in
r
.
clients
# an op,i pair should be unique
assert
entry
not
in
r
.
clients
# an op,i pair should be unique
if
not
r
.
clients
:
if
not
r
.
clients
:
if
prune
:
if
prune
:
self
.
__prune_r__
([
r
])
self
.
__prune_r__
([
r
]
,
reason
)
return
False
return
False
return
True
return
True
return
False
return
False
### import ###
### import ###
def
__import_r__
(
self
,
variables
,
reason
):
def
__import_r__
(
self
,
variables
):
global
NullType
global
NullType
if
NullType
is
None
:
if
NullType
is
None
:
from
null_type
import
NullType
from
null_type
import
NullType
...
@@ -222,17 +226,18 @@ class FunctionGraph(utils.object2):
...
@@ -222,17 +226,18 @@ class FunctionGraph(utils.object2):
for
apply_node
in
[
r
.
owner
for
r
in
variables
if
r
.
owner
is
not
None
]:
for
apply_node
in
[
r
.
owner
for
r
in
variables
if
r
.
owner
is
not
None
]:
if
apply_node
not
in
r_owner_done
:
if
apply_node
not
in
r_owner_done
:
r_owner_done
.
add
(
apply_node
)
r_owner_done
.
add
(
apply_node
)
self
.
__import__
(
apply_node
)
self
.
__import__
(
apply_node
,
reason
=
reason
)
for
r
in
variables
:
for
r
in
variables
:
if
r
.
owner
is
None
and
not
isinstance
(
r
,
graph
.
Constant
)
and
r
not
in
self
.
inputs
:
if
r
.
owner
is
None
and
not
isinstance
(
r
,
graph
.
Constant
)
and
r
not
in
self
.
inputs
:
if
isinstance
(
r
.
type
,
NullType
):
if
isinstance
(
r
.
type
,
NullType
):
raise
TypeError
(
"Computation graph contains a NaN. "
+
r
.
type
.
why_null
)
raise
TypeError
(
"Computation graph contains a NaN. "
+
r
.
type
.
why_null
)
raise
MissingInputError
(
"Undeclared input"
,
r
)
raise
MissingInputError
(
"Undeclared input"
,
r
)
if
not
getattr
(
r
,
'fgraph'
,
None
)
is
self
:
if
not
getattr
(
r
,
'fgraph'
,
None
)
is
self
:
self
.
__setup_r__
(
r
)
self
.
__setup_r__
(
r
)
self
.
variables
.
add
(
r
)
self
.
variables
.
add
(
r
)
def
__import__
(
self
,
apply_node
,
check
=
Tru
e
):
def
__import__
(
self
,
apply_node
,
check
=
True
,
reason
=
Non
e
):
node
=
apply_node
node
=
apply_node
# We import the nodes in topological order. We only are interested
# We import the nodes in topological order. We only are interested
...
@@ -248,7 +253,9 @@ class FunctionGraph(utils.object2):
...
@@ -248,7 +253,9 @@ class FunctionGraph(utils.object2):
for
r
in
node
.
inputs
:
for
r
in
node
.
inputs
:
if
hasattr
(
r
,
'fgraph'
)
and
r
.
fgraph
is
not
self
:
if
hasattr
(
r
,
'fgraph'
)
and
r
.
fgraph
is
not
self
:
raise
Exception
(
"
%
s is already owned by another fgraph"
%
r
)
raise
Exception
(
"
%
s is already owned by another fgraph"
%
r
)
if
r
.
owner
is
None
and
not
isinstance
(
r
,
graph
.
Constant
)
and
r
not
in
self
.
inputs
:
if
(
r
.
owner
is
None
and
not
isinstance
(
r
,
graph
.
Constant
)
and
r
not
in
self
.
inputs
):
#Verbose error message
#Verbose error message
#Show a complete chain of variables from the missing input to an output
#Show a complete chain of variables from the missing input to an output
...
@@ -328,20 +335,18 @@ class FunctionGraph(utils.object2):
...
@@ -328,20 +335,18 @@ class FunctionGraph(utils.object2):
self
.
variables
.
add
(
input
)
self
.
variables
.
add
(
input
)
self
.
__add_clients__
(
input
,
[(
node
,
i
)])
self
.
__add_clients__
(
input
,
[(
node
,
i
)])
assert
node
.
fgraph
is
self
assert
node
.
fgraph
is
self
self
.
execute_callbacks
(
'on_import'
,
node
)
self
.
execute_callbacks
(
'on_import'
,
node
,
reason
)
### prune ###
### prune ###
def
__prune_r__
(
self
,
variables
,
reason
=
None
):
def
__prune_r__
(
self
,
variables
):
# Prunes the owners of the variables.
# Prunes the owners of the variables.
for
node
in
set
(
r
.
owner
for
r
in
variables
if
r
.
owner
is
not
None
):
for
node
in
set
(
r
.
owner
for
r
in
variables
if
r
.
owner
is
not
None
):
self
.
__prune__
(
node
)
self
.
__prune__
(
node
,
reason
)
for
r
in
variables
:
for
r
in
variables
:
if
not
r
.
clients
and
r
in
self
.
variables
:
if
not
r
.
clients
and
r
in
self
.
variables
:
self
.
variables
.
remove
(
r
)
self
.
variables
.
remove
(
r
)
def
__prune__
(
self
,
apply_node
):
def
__prune__
(
self
,
apply_node
,
reason
=
None
):
node
=
apply_node
node
=
apply_node
if
node
not
in
self
.
apply_nodes
:
if
node
not
in
self
.
apply_nodes
:
raise
Exception
(
"
%
s does not belong to this FunctionGraph and cannot be pruned."
%
node
)
raise
Exception
(
"
%
s does not belong to this FunctionGraph and cannot be pruned."
%
node
)
...
@@ -356,16 +361,13 @@ class FunctionGraph(utils.object2):
...
@@ -356,16 +361,13 @@ class FunctionGraph(utils.object2):
return
return
self
.
apply_nodes
.
remove
(
node
)
self
.
apply_nodes
.
remove
(
node
)
self
.
variables
.
difference_update
(
node
.
outputs
)
self
.
variables
.
difference_update
(
node
.
outputs
)
self
.
execute_callbacks
(
'on_prune'
,
node
)
self
.
execute_callbacks
(
'on_prune'
,
node
,
reason
)
for
i
,
input
in
enumerate
(
node
.
inputs
):
for
i
,
input
in
enumerate
(
node
.
inputs
):
self
.
__remove_clients__
(
input
,
[(
node
,
i
)])
self
.
__remove_clients__
(
input
,
[(
node
,
i
)]
,
reason
=
reason
)
#self.__prune_r__(node.inputs)
#self.__prune_r__(node.inputs)
### change input ###
### change input ###
def
change_input
(
self
,
node
,
i
,
new_r
,
reason
=
None
):
def
change_input
(
self
,
node
,
i
,
new_r
,
reason
=
None
):
"""WRITEME
"""WRITEME
Changes node.inputs[i] to new_r.
Changes node.inputs[i] to new_r.
...
@@ -381,42 +383,45 @@ class FunctionGraph(utils.object2):
...
@@ -381,42 +383,45 @@ class FunctionGraph(utils.object2):
r
=
self
.
outputs
[
i
]
r
=
self
.
outputs
[
i
]
if
not
r
.
type
==
new_r
.
type
:
if
not
r
.
type
==
new_r
.
type
:
raise
TypeError
(
"The type of the replacement must be the"
raise
TypeError
(
"The type of the replacement must be the"
" same as the type of the original Variable."
,
" same as the type of the original Variable."
,
r
,
new_r
)
r
,
new_r
)
self
.
outputs
[
i
]
=
new_r
self
.
outputs
[
i
]
=
new_r
else
:
else
:
if
node
.
fgraph
is
not
self
:
if
node
.
fgraph
is
not
self
:
raise
Exception
(
"Cannot operate on
%
s because it does not"
raise
Exception
(
"Cannot operate on
%
s because it does not"
" belong to this FunctionGraph"
%
node
)
" belong to this FunctionGraph"
%
node
)
r
=
node
.
inputs
[
i
]
r
=
node
.
inputs
[
i
]
if
not
r
.
type
==
new_r
.
type
:
if
not
r
.
type
==
new_r
.
type
:
raise
TypeError
(
"The type of the replacement must be the"
raise
TypeError
(
"The type of the replacement must be the"
" same as the type of the original Variable."
,
" same as the type of the original Variable."
,
r
,
new_r
)
r
,
new_r
)
node
.
inputs
[
i
]
=
new_r
node
.
inputs
[
i
]
=
new_r
if
r
is
new_r
:
if
r
is
new_r
:
return
return
self
.
__import_r__
([
new_r
])
self
.
__import_r__
([
new_r
]
,
reason
=
reason
)
self
.
__add_clients__
(
new_r
,
[(
node
,
i
)])
self
.
__add_clients__
(
new_r
,
[(
node
,
i
)])
prune
=
self
.
__remove_clients__
(
r
,
[(
node
,
i
)],
False
)
prune
=
self
.
__remove_clients__
(
r
,
[(
node
,
i
)],
False
)
# Precondition: the substitution is semantically valid
# Precondition: the substitution is semantically valid
# However it may introduce cycles to the graph, in which case the
# However it may introduce cycles to the graph, in which case the
# transaction will be reverted later.
# transaction will be reverted later.
self
.
execute_callbacks
(
'on_change_input'
,
node
,
i
,
r
,
new_r
,
reason
=
reason
)
self
.
execute_callbacks
(
'on_change_input'
,
node
,
i
,
r
,
new_r
,
reason
=
reason
)
if
prune
:
if
prune
:
self
.
__prune_r__
([
r
])
self
.
__prune_r__
([
r
],
reason
=
reason
)
### replace ###
### replace ###
def
replace
(
self
,
r
,
new_r
,
reason
=
None
,
verbose
=
None
):
def
replace
(
self
,
r
,
new_r
,
reason
=
None
):
""" WRITEME
""" WRITEME
This is the main interface to manipulate the subgraph in FunctionGraph.
This is the main interface to manipulate the subgraph in FunctionGraph.
For every node that uses r as input, makes it use new_r instead.
For every node that uses r as input, makes it use new_r instead.
"""
"""
if
verbose
is
None
:
verbose
=
config
.
optimizer_verbose
if
verbose
:
print
reason
,
r
,
new_r
if
r
.
fgraph
is
not
self
:
if
r
.
fgraph
is
not
self
:
raise
Exception
(
"Cannot replace
%
s because it does not belong to this FunctionGraph"
%
r
,
str
(
reason
))
raise
Exception
(
"Cannot replace
%
s because it does not belong to this FunctionGraph"
%
r
,
str
(
reason
))
if
not
r
.
type
==
new_r
.
type
:
if
not
r
.
type
==
new_r
.
type
:
...
@@ -426,7 +431,7 @@ class FunctionGraph(utils.object2):
...
@@ -426,7 +431,7 @@ class FunctionGraph(utils.object2):
# because it makes it easier to implement some optimizations for multiple-output ops
# because it makes it easier to implement some optimizations for multiple-output ops
return
return
for
node
,
i
in
list
(
r
.
clients
):
# copy the client list for iteration
for
node
,
i
in
list
(
r
.
clients
):
# copy the client list for iteration
assert
(
node
==
'output'
and
self
.
outputs
[
i
]
is
r
)
or
(
node
.
inputs
[
i
]
is
r
)
assert
(
node
==
'output'
and
self
.
outputs
[
i
]
is
r
)
or
(
node
.
inputs
[
i
]
is
r
)
self
.
change_input
(
node
,
i
,
new_r
,
reason
=
reason
)
self
.
change_input
(
node
,
i
,
new_r
,
reason
=
reason
)
...
@@ -440,11 +445,9 @@ class FunctionGraph(utils.object2):
...
@@ -440,11 +445,9 @@ class FunctionGraph(utils.object2):
for
r
,
new_r
in
pairs
:
for
r
,
new_r
in
pairs
:
self
.
replace
(
r
,
new_r
,
reason
=
reason
)
self
.
replace
(
r
,
new_r
,
reason
=
reason
)
def
extend
(
self
,
feature
):
def
extend
(
self
,
feature
):
warnings
.
warn
(
"FunctionGraph.extend is deprecatd. It has been "
warnings
.
warn
(
"FunctionGraph.extend is deprecatd. It has been "
"renamed to FunctionGraph.attach_feature"
)
"renamed to FunctionGraph.attach_feature"
)
return
self
.
attach_feature
(
feature
)
return
self
.
attach_feature
(
feature
)
def
attach_feature
(
self
,
feature
):
def
attach_feature
(
self
,
feature
):
...
@@ -455,7 +458,7 @@ class FunctionGraph(utils.object2):
...
@@ -455,7 +458,7 @@ class FunctionGraph(utils.object2):
# Filter out literally identical features
# Filter out literally identical features
if
feature
in
self
.
_features
:
if
feature
in
self
.
_features
:
return
# the feature is already present
return
# the feature is already present
# Filter out functionally identical features.
# Filter out functionally identical features.
# Features may use their on_attach method to raise
# Features may use their on_attach method to raise
...
@@ -481,7 +484,9 @@ class FunctionGraph(utils.object2):
...
@@ -481,7 +484,9 @@ class FunctionGraph(utils.object2):
"""WRITEME
"""WRITEME
Removes the feature from the graph.
Removes the feature from the graph.
Calls feature.on_detach(function_graph) if an on_detach method is defined.
Calls feature.on_detach(function_graph) if an on_detach method
is defined.
"""
"""
try
:
try
:
self
.
_features
.
remove
(
feature
)
self
.
_features
.
remove
(
feature
)
...
@@ -491,9 +496,7 @@ class FunctionGraph(utils.object2):
...
@@ -491,9 +496,7 @@ class FunctionGraph(utils.object2):
if
detach
is
not
None
:
if
detach
is
not
None
:
detach
(
self
)
detach
(
self
)
### callback utils ###
### callback utils ###
def
execute_callbacks
(
self
,
name
,
*
args
,
**
kwargs
):
def
execute_callbacks
(
self
,
name
,
*
args
,
**
kwargs
):
"""WRITEME
"""WRITEME
Calls
Calls
...
@@ -518,7 +521,6 @@ class FunctionGraph(utils.object2):
...
@@ -518,7 +521,6 @@ class FunctionGraph(utils.object2):
else
:
else
:
raise
raise
def
collect_callbacks
(
self
,
name
,
*
args
):
def
collect_callbacks
(
self
,
name
,
*
args
):
"""WRITEME
"""WRITEME
Returns a dictionary d such that:
Returns a dictionary d such that:
...
@@ -534,9 +536,7 @@ class FunctionGraph(utils.object2):
...
@@ -534,9 +536,7 @@ class FunctionGraph(utils.object2):
d
[
feature
]
=
fn
(
*
args
)
d
[
feature
]
=
fn
(
*
args
)
return
d
return
d
### misc ###
### misc ###
def
toposort
(
self
):
def
toposort
(
self
):
"""WRITEME
"""WRITEME
Returns an ordering of the graph's Apply nodes such that:
Returns an ordering of the graph's Apply nodes such that:
...
@@ -552,8 +552,8 @@ class FunctionGraph(utils.object2):
...
@@ -552,8 +552,8 @@ class FunctionGraph(utils.object2):
if
len
(
self
.
apply_nodes
)
<
2
:
if
len
(
self
.
apply_nodes
)
<
2
:
# optimization
# optimization
# when there are 0 or 1 nodes, no sorting is necessary
# when there are 0 or 1 nodes, no sorting is necessary
# This special case happens a lot because the OpWiseCLinker
produces
# This special case happens a lot because the OpWiseCLinker
# 1-element graphs.
#
produces
1-element graphs.
return
list
(
self
.
apply_nodes
)
return
list
(
self
.
apply_nodes
)
fg
=
self
fg
=
self
...
@@ -568,30 +568,33 @@ class FunctionGraph(utils.object2):
...
@@ -568,30 +568,33 @@ class FunctionGraph(utils.object2):
Return dict d s.t. d[node] is a list of nodes that must be evaluated
Return dict d s.t. d[node] is a list of nodes that must be evaluated
before node itself can be evaluated.
before node itself can be evaluated.
This is used primarily by the destroy_handler feature to ensure that all
This is used primarily by the destroy_handler feature to ensure that
clients of any destroyed inputs have already computed their outputs.
all clients of any destroyed inputs have already computed their
outputs.
:note: This only calls the orderings() fct on all features. It does not
:note: This only calls the orderings() fct on all features. It does not
take care of computing dependencies by itself.
take care of computing dependencies by itself.
"""
"""
ords
=
OrderedDict
()
ords
=
OrderedDict
()
assert
isinstance
(
self
.
_features
,
list
)
assert
isinstance
(
self
.
_features
,
list
)
for
feature
in
self
.
_features
:
for
feature
in
self
.
_features
:
if
hasattr
(
feature
,
'orderings'
):
if
hasattr
(
feature
,
'orderings'
):
orderings
=
feature
.
orderings
(
self
)
orderings
=
feature
.
orderings
(
self
)
if
not
isinstance
(
orderings
,
OrderedDict
):
if
not
isinstance
(
orderings
,
OrderedDict
):
raise
TypeError
(
"Non-deterministic return value from "
\
raise
TypeError
(
"Non-deterministic return value from "
+
+
str
(
feature
.
orderings
)
\
str
(
feature
.
orderings
)
+
+
". Nondeterministic object is "
+
str
(
orderings
))
". Nondeterministic object is "
+
str
(
orderings
))
for
node
,
prereqs
in
orderings
.
items
():
for
node
,
prereqs
in
orderings
.
items
():
if
not
isinstance
(
prereqs
,
(
list
,
OrderedSet
)):
if
not
isinstance
(
prereqs
,
(
list
,
OrderedSet
)):
raise
TypeError
(
"prereqs must be a type with a "
raise
TypeError
(
"deterministic iteration order, or toposort "
"prereqs must be a type with a "
" will be non-deterministic."
)
"deterministic iteration order, or toposort "
" will be non-deterministic."
)
ords
.
setdefault
(
node
,
[])
.
extend
(
prereqs
)
ords
.
setdefault
(
node
,
[])
.
extend
(
prereqs
)
# eliminate duplicate prereqs
# eliminate duplicate prereqs
for
(
node
,
prereqs
)
in
ords
.
items
():
for
(
node
,
prereqs
)
in
ords
.
items
():
ords
[
node
]
=
list
(
OrderedSet
(
prereqs
))
ords
[
node
]
=
list
(
OrderedSet
(
prereqs
))
return
ords
return
ords
...
@@ -624,34 +627,48 @@ class FunctionGraph(utils.object2):
...
@@ -624,34 +627,48 @@ class FunctionGraph(utils.object2):
if
self
.
apply_nodes
!=
nodes
:
if
self
.
apply_nodes
!=
nodes
:
missing
=
nodes
.
difference
(
self
.
apply_nodes
)
missing
=
nodes
.
difference
(
self
.
apply_nodes
)
excess
=
self
.
apply_nodes
.
difference
(
nodes
)
excess
=
self
.
apply_nodes
.
difference
(
nodes
)
raise
Exception
(
"The nodes are inappropriately cached. missing, in excess: "
,
missing
,
excess
)
raise
Exception
(
"The nodes are inappropriately cached. missing, in excess: "
,
missing
,
excess
)
for
node
in
nodes
:
for
node
in
nodes
:
if
node
.
fgraph
is
not
self
:
if
node
.
fgraph
is
not
self
:
raise
Exception
(
"Node should belong to the FunctionGraph."
,
node
)
raise
Exception
(
"Node should belong to the FunctionGraph."
,
node
)
for
i
,
variable
in
enumerate
(
node
.
inputs
):
for
i
,
variable
in
enumerate
(
node
.
inputs
):
if
variable
.
fgraph
is
not
self
:
if
variable
.
fgraph
is
not
self
:
raise
Exception
(
"Input of node should belong to the FunctionGraph."
,
variable
,
(
node
,
i
))
raise
Exception
(
"Input of node should belong to the FunctionGraph."
,
variable
,
(
node
,
i
))
if
(
node
,
i
)
not
in
variable
.
clients
:
if
(
node
,
i
)
not
in
variable
.
clients
:
raise
Exception
(
"Inconsistent clients list."
,
(
node
,
i
),
variable
.
clients
)
raise
Exception
(
"Inconsistent clients list."
,
(
node
,
i
),
variable
.
clients
)
variables
=
set
(
graph
.
variables
(
self
.
inputs
,
self
.
outputs
))
variables
=
set
(
graph
.
variables
(
self
.
inputs
,
self
.
outputs
))
if
set
(
self
.
variables
)
!=
variables
:
if
set
(
self
.
variables
)
!=
variables
:
missing
=
variables
.
difference
(
self
.
variables
)
missing
=
variables
.
difference
(
self
.
variables
)
excess
=
self
.
variables
.
difference
(
variables
)
excess
=
self
.
variables
.
difference
(
variables
)
raise
Exception
(
"The variables are inappropriately cached. missing, in excess: "
,
missing
,
excess
)
raise
Exception
(
"The variables are inappropriately cached. missing, in excess: "
,
missing
,
excess
)
for
variable
in
variables
:
for
variable
in
variables
:
if
variable
.
owner
is
None
and
variable
not
in
self
.
inputs
and
not
isinstance
(
variable
,
graph
.
Constant
):
if
(
variable
.
owner
is
None
and
variable
not
in
self
.
inputs
and
not
isinstance
(
variable
,
graph
.
Constant
)):
raise
Exception
(
"Undeclared input."
,
variable
)
raise
Exception
(
"Undeclared input."
,
variable
)
if
variable
.
fgraph
is
not
self
:
if
variable
.
fgraph
is
not
self
:
raise
Exception
(
"Variable should belong to the FunctionGraph."
,
variable
)
raise
Exception
(
"Variable should belong to the FunctionGraph."
,
variable
)
for
node
,
i
in
variable
.
clients
:
for
node
,
i
in
variable
.
clients
:
if
node
==
'output'
:
if
node
==
'output'
:
if
self
.
outputs
[
i
]
is
not
variable
:
if
self
.
outputs
[
i
]
is
not
variable
:
raise
Exception
(
"Inconsistent clients list."
,
variable
,
self
.
outputs
[
i
])
raise
Exception
(
"Inconsistent clients list."
,
variable
,
self
.
outputs
[
i
])
continue
continue
if
node
not
in
nodes
:
if
node
not
in
nodes
:
raise
Exception
(
"Client not in FunctionGraph."
,
variable
,
(
node
,
i
))
raise
Exception
(
"Client not in FunctionGraph."
,
variable
,
(
node
,
i
))
if
node
.
inputs
[
i
]
is
not
variable
:
if
node
.
inputs
[
i
]
is
not
variable
:
raise
Exception
(
"Inconsistent clients list."
,
variable
,
node
.
inputs
[
i
])
raise
Exception
(
"Inconsistent clients list."
,
variable
,
node
.
inputs
[
i
])
def
__str__
(
self
):
def
__str__
(
self
):
return
"[
%
s]"
%
", "
.
join
(
graph
.
as_string
(
self
.
inputs
,
self
.
outputs
))
return
"[
%
s]"
%
", "
.
join
(
graph
.
as_string
(
self
.
inputs
,
self
.
outputs
))
...
@@ -659,9 +676,7 @@ class FunctionGraph(utils.object2):
...
@@ -659,9 +676,7 @@ class FunctionGraph(utils.object2):
def
__repr__
(
self
):
def
__repr__
(
self
):
return
self
.
__str__
()
return
self
.
__str__
()
### clone ###
### clone ###
def
clone
(
self
):
def
clone
(
self
):
"""WRITEME"""
"""WRITEME"""
return
self
.
clone_get_equiv
()[
0
]
return
self
.
clone_get_equiv
()[
0
]
...
@@ -671,7 +686,7 @@ class FunctionGraph(utils.object2):
...
@@ -671,7 +686,7 @@ class FunctionGraph(utils.object2):
equiv
=
graph
.
clone_get_equiv
(
self
.
inputs
,
self
.
outputs
)
equiv
=
graph
.
clone_get_equiv
(
self
.
inputs
,
self
.
outputs
)
self
.
check_integrity
()
self
.
check_integrity
()
e
=
FunctionGraph
([
equiv
[
i
]
for
i
in
self
.
inputs
],
e
=
FunctionGraph
([
equiv
[
i
]
for
i
in
self
.
inputs
],
[
equiv
[
o
]
for
o
in
self
.
outputs
])
[
equiv
[
o
]
for
o
in
self
.
outputs
])
e
.
check_integrity
()
e
.
check_integrity
()
for
feature
in
self
.
_features
:
for
feature
in
self
.
_features
:
e
.
attach_feature
(
feature
)
e
.
attach_feature
(
feature
)
...
...
theano/gof/op.py
浏览文件 @
a351e3db
...
@@ -13,6 +13,7 @@ __contact__ = "theano-dev <theano-dev@googlegroups.com>"
...
@@ -13,6 +13,7 @@ __contact__ = "theano-dev <theano-dev@googlegroups.com>"
__docformat__
=
"restructuredtext en"
__docformat__
=
"restructuredtext en"
import
logging
import
logging
import
sys
import
warnings
import
warnings
import
theano
import
theano
...
@@ -408,6 +409,9 @@ class PureOp(object):
...
@@ -408,6 +409,9 @@ class PureOp(object):
elif
config
.
compute_test_value
==
'ignore'
:
elif
config
.
compute_test_value
==
'ignore'
:
# silently skip test
# silently skip test
run_perform
=
False
run_perform
=
False
elif
config
.
compute_test_value
==
'pdb'
:
import
pdb
pdb
.
post_mortem
(
sys
.
exc_info
()[
2
])
else
:
else
:
raise
ValueError
(
'
%
s is invalid for option config.compute_Test_value'
%
config
.
compute_test_value
)
raise
ValueError
(
'
%
s is invalid for option config.compute_Test_value'
%
config
.
compute_test_value
)
...
@@ -638,8 +642,11 @@ def get_test_value(v):
...
@@ -638,8 +642,11 @@ def get_test_value(v):
For a Shared variable, it is the internal value.
For a Shared variable, it is the internal value.
For another Variable, it is the content of v.tag.test_value.
For another Variable, it is the content of v.tag.test_value.
"""
"""
v_tensor
=
theano
.
tensor
.
as_tensor_variable
(
v
)
if
not
isinstance
(
v
,
graph
.
Variable
):
return
PureOp
.
_get_test_value
(
v_tensor
)
v_var
=
theano
.
tensor
.
as_tensor_variable
(
v
)
else
:
v_var
=
v
return
PureOp
.
_get_test_value
(
v_var
)
def
missing_test_message
(
msg
):
def
missing_test_message
(
msg
):
...
...
theano/gof/opt.py
浏览文件 @
a351e3db
...
@@ -421,7 +421,7 @@ class MergeFeature(object):
...
@@ -421,7 +421,7 @@ class MergeFeature(object):
self
.
blacklist
=
[]
self
.
blacklist
=
[]
for
node
in
fgraph
.
toposort
():
for
node
in
fgraph
.
toposort
():
self
.
on_import
(
fgraph
,
node
)
self
.
on_import
(
fgraph
,
node
,
"on_attach"
)
def
on_change_input
(
self
,
fgraph
,
node
,
i
,
r
,
new_r
):
def
on_change_input
(
self
,
fgraph
,
node
,
i
,
r
,
new_r
):
# If inputs to node change, it is not guaranteed that it is distinct
# If inputs to node change, it is not guaranteed that it is distinct
...
@@ -433,14 +433,14 @@ class MergeFeature(object):
...
@@ -433,14 +433,14 @@ class MergeFeature(object):
if
isinstance
(
new_r
,
graph
.
Constant
):
if
isinstance
(
new_r
,
graph
.
Constant
):
self
.
process_constant
(
fgraph
,
new_r
)
self
.
process_constant
(
fgraph
,
new_r
)
def
on_import
(
self
,
fgraph
,
node
):
def
on_import
(
self
,
fgraph
,
node
,
reason
):
for
c
in
node
.
inputs
:
for
c
in
node
.
inputs
:
if
isinstance
(
c
,
graph
.
Constant
):
if
isinstance
(
c
,
graph
.
Constant
):
self
.
process_constant
(
fgraph
,
c
)
self
.
process_constant
(
fgraph
,
c
)
self
.
process_node
(
fgraph
,
node
)
self
.
process_node
(
fgraph
,
node
)
def
on_prune
(
self
,
fgraph
,
node
):
def
on_prune
(
self
,
fgraph
,
node
,
reason
):
self
.
nodes_seen
.
discard
(
node
)
self
.
nodes_seen
.
discard
(
node
)
for
c
in
node
.
inputs
:
for
c
in
node
.
inputs
:
if
isinstance
(
c
,
graph
.
Constant
)
and
(
len
(
c
.
clients
)
<=
1
):
if
isinstance
(
c
,
graph
.
Constant
)
and
(
len
(
c
.
clients
)
<=
1
):
...
@@ -548,7 +548,7 @@ class MergeOptimizer(Optimizer):
...
@@ -548,7 +548,7 @@ class MergeOptimizer(Optimizer):
except
InconsistencyError
:
except
InconsistencyError
:
success
=
False
success
=
False
fgraph
.
merge_feature
.
blacklist
.
append
(
fgraph
.
merge_feature
.
blacklist
.
append
(
(
pairs
[
0
][
0
]
.
owner
,
pairs
[
0
][
1
]
.
owner
))
(
pairs
[
0
][
0
]
.
owner
,
pairs
[
0
][
1
]
.
owner
))
if
success
:
if
success
:
break
break
...
@@ -1027,7 +1027,7 @@ class PatternSub(LocalOptimizer):
...
@@ -1027,7 +1027,7 @@ class PatternSub(LocalOptimizer):
else
:
else
:
return
pattern
.
clone
()
return
pattern
.
clone
()
u
=
match
(
self
.
in_pattern
,
node
.
out
,
unify
.
Unification
(),
True
,
u
=
match
(
self
.
in_pattern
,
node
.
out
,
unify
.
Unification
(),
True
,
self
.
pdb
)
self
.
pdb
)
if
u
:
if
u
:
p
=
self
.
out_pattern
p
=
self
.
out_pattern
new
=
build
(
p
,
u
)
new
=
build
(
p
,
u
)
...
@@ -1165,10 +1165,10 @@ class NavigatorOptimizer(Optimizer):
...
@@ -1165,10 +1165,10 @@ class NavigatorOptimizer(Optimizer):
class
Updater
:
class
Updater
:
if
importer
is
not
None
:
if
importer
is
not
None
:
def
on_import
(
self
,
fgraph
,
node
):
def
on_import
(
self
,
fgraph
,
node
,
reason
):
importer
(
node
)
importer
(
node
)
if
pruner
is
not
None
:
if
pruner
is
not
None
:
def
on_prune
(
self
,
fgraph
,
node
):
def
on_prune
(
self
,
fgraph
,
node
,
reason
):
pruner
(
node
)
pruner
(
node
)
if
chin
is
not
None
:
if
chin
is
not
None
:
def
on_change_input
(
self
,
fgraph
,
node
,
i
,
r
,
new_r
):
def
on_change_input
(
self
,
fgraph
,
node
,
i
,
r
,
new_r
):
...
@@ -1357,7 +1357,7 @@ class ChangeTracker:
...
@@ -1357,7 +1357,7 @@ class ChangeTracker:
def
__init__
(
self
):
def
__init__
(
self
):
self
.
changed
=
False
self
.
changed
=
False
def
on_import
(
self
,
fgraph
,
node
):
def
on_import
(
self
,
fgraph
,
node
,
reason
):
self
.
changed
=
True
self
.
changed
=
True
def
on_change_input
(
self
,
fgraph
,
node
,
i
,
r
,
new_r
):
def
on_change_input
(
self
,
fgraph
,
node
,
i
,
r
,
new_r
):
...
...
theano/gof/toolbox.py
浏览文件 @
a351e3db
import
sys
import
sys
import
time
import
time
from
theano
import
config
from
theano.gof.python25
import
partial
from
theano.gof.python25
import
partial
from
theano.gof.python25
import
OrderedDict
from
theano.gof.python25
import
OrderedDict
from
theano.gof
import
graph
from
theano.gof
import
graph
class
AlreadyThere
(
Exception
):
class
AlreadyThere
(
Exception
):
"""Raised by a Feature's on_attach callback method if the FunctionGraph
"""Raised by a Feature's on_attach callback method if the FunctionGraph
attempting to attach the feature already has a functionally identical
attempting to attach the feature already has a functionally identical
...
@@ -57,7 +56,7 @@ class Feature(object):
...
@@ -57,7 +56,7 @@ class Feature(object):
functionality that it installed into the function_graph.
functionality that it installed into the function_graph.
"""
"""
def
on_import
(
self
,
function_graph
,
node
):
def
on_import
(
self
,
function_graph
,
node
,
reason
):
"""
"""
Called whenever a node is imported into function_graph, which is
Called whenever a node is imported into function_graph, which is
just before the node is actually connected to the graph.
just before the node is actually connected to the graph.
...
@@ -66,7 +65,7 @@ class Feature(object):
...
@@ -66,7 +65,7 @@ class Feature(object):
you should do this by implementing on_attach.
you should do this by implementing on_attach.
"""
"""
def
on_prune
(
self
,
function_graph
,
node
):
def
on_prune
(
self
,
function_graph
,
node
,
reason
):
"""
"""
Called whenever a node is pruned (removed) from the function_graph,
Called whenever a node is pruned (removed) from the function_graph,
after it is disconnected from the graph.
after it is disconnected from the graph.
...
@@ -98,11 +97,11 @@ class Bookkeeper(Feature):
...
@@ -98,11 +97,11 @@ class Bookkeeper(Feature):
def
on_attach
(
self
,
fgraph
):
def
on_attach
(
self
,
fgraph
):
for
node
in
graph
.
io_toposort
(
fgraph
.
inputs
,
fgraph
.
outputs
):
for
node
in
graph
.
io_toposort
(
fgraph
.
inputs
,
fgraph
.
outputs
):
self
.
on_import
(
fgraph
,
node
)
self
.
on_import
(
fgraph
,
node
,
"on_attach"
)
def
on_detach
(
self
,
fgraph
):
def
on_detach
(
self
,
fgraph
):
for
node
in
graph
.
io_toposort
(
fgraph
.
inputs
,
fgraph
.
outputs
):
for
node
in
graph
.
io_toposort
(
fgraph
.
inputs
,
fgraph
.
outputs
):
self
.
on_prune
(
fgraph
,
node
)
self
.
on_prune
(
fgraph
,
node
,
'Bookkeeper.detach'
)
class
History
(
Feature
):
class
History
(
Feature
):
...
@@ -199,11 +198,14 @@ class ReplaceValidate(History, Validator):
...
@@ -199,11 +198,14 @@ class ReplaceValidate(History, Validator):
def
replace_validate
(
self
,
fgraph
,
r
,
new_r
,
reason
=
None
):
def
replace_validate
(
self
,
fgraph
,
r
,
new_r
,
reason
=
None
):
self
.
replace_all_validate
(
fgraph
,
[(
r
,
new_r
)],
reason
=
reason
)
self
.
replace_all_validate
(
fgraph
,
[(
r
,
new_r
)],
reason
=
reason
)
def
replace_all_validate
(
self
,
fgraph
,
replacements
,
reason
=
None
):
def
replace_all_validate
(
self
,
fgraph
,
replacements
,
reason
=
None
,
verbose
=
None
):
chk
=
fgraph
.
checkpoint
()
chk
=
fgraph
.
checkpoint
()
if
verbose
is
None
:
verbose
=
config
.
optimizer_verbose
for
r
,
new_r
in
replacements
:
for
r
,
new_r
in
replacements
:
try
:
try
:
fgraph
.
replace
(
r
,
new_r
,
reason
=
reason
)
fgraph
.
replace
(
r
,
new_r
,
reason
=
reason
,
verbose
=
False
)
except
Exception
,
e
:
except
Exception
,
e
:
if
(
'The type of the replacement must be the same'
not
in
if
(
'The type of the replacement must be the same'
not
in
str
(
e
)
and
'does not belong to this FunctionGraph'
not
in
str
(
e
)):
str
(
e
)
and
'does not belong to this FunctionGraph'
not
in
str
(
e
)):
...
@@ -219,6 +221,8 @@ class ReplaceValidate(History, Validator):
...
@@ -219,6 +221,8 @@ class ReplaceValidate(History, Validator):
except
Exception
,
e
:
except
Exception
,
e
:
fgraph
.
revert
(
chk
)
fgraph
.
revert
(
chk
)
raise
raise
if
verbose
:
print
reason
,
r
,
new_r
return
chk
return
chk
def
replace_all_validate_remove
(
self
,
fgraph
,
replacements
,
def
replace_all_validate_remove
(
self
,
fgraph
,
replacements
,
...
@@ -267,7 +271,7 @@ class NodeFinder(dict, Bookkeeper):
...
@@ -267,7 +271,7 @@ class NodeFinder(dict, Bookkeeper):
del
fgraph
.
get_nodes
del
fgraph
.
get_nodes
Bookkeeper
.
on_detach
(
self
,
fgraph
)
Bookkeeper
.
on_detach
(
self
,
fgraph
)
def
on_import
(
self
,
fgraph
,
node
):
def
on_import
(
self
,
fgraph
,
node
,
reason
):
try
:
try
:
self
.
setdefault
(
node
.
op
,
[])
.
append
(
node
)
self
.
setdefault
(
node
.
op
,
[])
.
append
(
node
)
except
TypeError
:
# node.op is unhashable
except
TypeError
:
# node.op is unhashable
...
@@ -280,7 +284,7 @@ class NodeFinder(dict, Bookkeeper):
...
@@ -280,7 +284,7 @@ class NodeFinder(dict, Bookkeeper):
print
>>
sys
.
stderr
,
'OFFENDING node not hashable'
print
>>
sys
.
stderr
,
'OFFENDING node not hashable'
raise
e
raise
e
def
on_prune
(
self
,
fgraph
,
node
):
def
on_prune
(
self
,
fgraph
,
node
,
reason
):
try
:
try
:
nodes
=
self
[
node
.
op
]
nodes
=
self
[
node
.
op
]
except
TypeError
:
# node.op is unhashable
except
TypeError
:
# node.op is unhashable
...
@@ -312,13 +316,13 @@ class PrintListener(Feature):
...
@@ -312,13 +316,13 @@ class PrintListener(Feature):
if
self
.
active
:
if
self
.
active
:
print
"-- detaching from: "
,
fgraph
print
"-- detaching from: "
,
fgraph
def
on_import
(
self
,
fgraph
,
node
):
def
on_import
(
self
,
fgraph
,
node
,
reason
):
if
self
.
active
:
if
self
.
active
:
print
"-- importing:
%
s
"
%
node
print
"-- importing:
%
s
, reason:
%
s"
%
(
node
,
reason
)
def
on_prune
(
self
,
fgraph
,
node
):
def
on_prune
(
self
,
fgraph
,
node
,
reason
):
if
self
.
active
:
if
self
.
active
:
print
"-- pruning:
%
s
"
%
node
print
"-- pruning:
%
s
, reason:
%
s"
%
(
node
,
reason
)
def
on_change_input
(
self
,
fgraph
,
node
,
i
,
r
,
new_r
,
reason
=
None
):
def
on_change_input
(
self
,
fgraph
,
node
,
i
,
r
,
new_r
,
reason
=
None
):
if
self
.
active
:
if
self
.
active
:
...
...
theano/sandbox/cuda/basic_ops.py
浏览文件 @
a351e3db
...
@@ -2953,7 +2953,6 @@ class GpuJoin(tensor.Join, GpuOp):
...
@@ -2953,7 +2953,6 @@ class GpuJoin(tensor.Join, GpuOp):
axis
=
inputs
[
0
]
axis
=
inputs
[
0
]
n_cndas
=
len
(
inputs
[
1
:])
n_cndas
=
len
(
inputs
[
1
:])
input_1
=
inputs
[
1
]
input_1
=
inputs
[
1
]
axis
=
inputs
[
0
]
fail
=
sub
[
'fail'
]
fail
=
sub
[
'fail'
]
out
=
out_
[
0
]
out
=
out_
[
0
]
...
...
theano/sandbox/linalg/ops.py
浏览文件 @
a351e3db
...
@@ -137,9 +137,9 @@ class HintsFeature(object):
...
@@ -137,9 +137,9 @@ class HintsFeature(object):
# Variable -> tuple(scalars) or None (All tensor vars map to tuple)
# Variable -> tuple(scalars) or None (All tensor vars map to tuple)
self
.
hints
=
{}
self
.
hints
=
{}
for
node
in
fgraph
.
toposort
():
for
node
in
fgraph
.
toposort
():
self
.
on_import
(
fgraph
,
node
)
self
.
on_import
(
fgraph
,
node
,
"on_attach"
)
def
on_import
(
self
,
fgraph
,
node
):
def
on_import
(
self
,
fgraph
,
node
,
reason
):
if
node
.
outputs
[
0
]
in
self
.
hints
:
if
node
.
outputs
[
0
]
in
self
.
hints
:
# this is a revert, not really an import
# this is a revert, not really an import
for
r
in
node
.
outputs
+
node
.
inputs
:
for
r
in
node
.
outputs
+
node
.
inputs
:
...
...
theano/sandbox/scan_module/scan_utils.py
浏览文件 @
a351e3db
...
@@ -338,7 +338,7 @@ def infer_shape(outs, inputs, input_shapes):
...
@@ -338,7 +338,7 @@ def infer_shape(outs, inputs, input_shapes):
# shape_feature.on_import does not actually use an fgraph
# shape_feature.on_import does not actually use an fgraph
# It will call infer_shape and set_shape appropriately
# It will call infer_shape and set_shape appropriately
dummy_fgraph
=
None
dummy_fgraph
=
None
shape_feature
.
on_import
(
dummy_fgraph
,
out
.
owner
)
shape_feature
.
on_import
(
dummy_fgraph
,
out
.
owner
,
reason
=
"dummy"
)
ret
=
[]
ret
=
[]
for
o
in
outs
:
for
o
in
outs
:
...
...
theano/scalar/basic.py
浏览文件 @
a351e3db
...
@@ -183,6 +183,24 @@ class Scalar(Type):
...
@@ -183,6 +183,24 @@ class Scalar(Type):
def
dtype_specs
(
self
):
def
dtype_specs
(
self
):
try
:
try
:
# To help debug dtype/typenum problem, here is code to get
# the list of numpy typenum. This list change between 32
# and 64 bit platform and probably also also between
# Windows and Linux.
# NOTE: equivalent type on a platform can have different typenum.
# This is the source of all dtype/typenum problem found up to
# now, as Theano always expect the exact typenum that
# correspond to our supported dtype.
"""
for dtype in ['int8', 'uint8', 'short', 'ushort', 'intc', 'uintc',
'longlong', 'ulonglong', 'single', 'double',
'longdouble', 'csingle', 'cdouble', 'clongdouble',
'float32', 'float64', 'int8', 'int16', 'int32',
'int64', 'uint8', 'uint16', 'uint32', 'uint64',
'complex64', 'complex128', 'float', 'double',
'int', 'uint']:
print dtype, np.zeros(1, dtype=dtype).dtype.num
"""
return
{
# dtype: (py_type, c_type, cls_name)
return
{
# dtype: (py_type, c_type, cls_name)
'float32'
:
(
numpy
.
float32
,
'npy_float32'
,
'Float32'
),
'float32'
:
(
numpy
.
float32
,
'npy_float32'
,
'Float32'
),
'float64'
:
(
numpy
.
float64
,
'npy_float64'
,
'Float64'
),
'float64'
:
(
numpy
.
float64
,
'npy_float64'
,
'Float64'
),
...
...
theano/scan_module/scan.py
浏览文件 @
a351e3db
...
@@ -101,7 +101,7 @@ def scan(fn,
...
@@ -101,7 +101,7 @@ def scan(fn,
The order of the sequences is the same as the one in the list
The order of the sequences is the same as the one in the list
`sequences` given to scan. The order of the outputs is the same
`sequences` given to scan. The order of the outputs is the same
as the order of ``output_info``. For any sequence or output the
as the order of ``output
s
_info``. For any sequence or output the
order of the time slices is the same as the one in which they have
order of the time slices is the same as the one in which they have
been given as taps. For example if one writes the following :
been given as taps. For example if one writes the following :
...
@@ -262,7 +262,7 @@ def scan(fn,
...
@@ -262,7 +262,7 @@ def scan(fn,
outputs will have *0 rows*. If the value is negative, ``scan``
outputs will have *0 rows*. If the value is negative, ``scan``
will run backwards in time. If the ``go_backwards`` flag is already
will run backwards in time. If the ``go_backwards`` flag is already
set and also ``n_steps`` is negative, ``scan`` will run forward
set and also ``n_steps`` is negative, ``scan`` will run forward
in time. If n
stpe
s is not provided, ``scan`` will figure
in time. If n
_step
s is not provided, ``scan`` will figure
out the amount of steps it should run given its input sequences.
out the amount of steps it should run given its input sequences.
...
@@ -817,7 +817,7 @@ def scan(fn,
...
@@ -817,7 +817,7 @@ def scan(fn,
if
as_while
:
if
as_while
:
tmp_dummy_f_outs
-=
1
tmp_dummy_f_outs
-=
1
if
not
(
tmp_dummy_f_outs
==
n_outs
or
outs_info
==
[]):
if
not
(
tmp_dummy_f_outs
==
n_outs
or
outs_info
==
[]):
raise
ValueError
(
'Please provide None as output_info for '
raise
ValueError
(
'Please provide None as output
s
_info for '
'any output that does not feed back into '
'any output that does not feed back into '
'scan (i.e. it behaves like a map) '
)
'scan (i.e. it behaves like a map) '
)
...
...
theano/scan_module/scan_op.py
浏览文件 @
a351e3db
...
@@ -1581,8 +1581,30 @@ class Scan(PureOp):
...
@@ -1581,8 +1581,30 @@ class Scan(PureOp):
if
not
isinstance
(
x
.
type
,
DisconnectedType
):
if
not
isinstance
(
x
.
type
,
DisconnectedType
):
outer_inp_seqs
.
append
(
x
[::
-
1
])
outer_inp_seqs
.
append
(
x
[::
-
1
])
outer_inp_seqs
+=
[
x
[::
-
1
]
for
x
in
self
.
outer_mitsot_outs
(
outs
)]
if
hasattr
(
inputs
[
0
]
.
tag
,
'test_value'
):
outer_inp_seqs
+=
[
x
[::
-
1
]
for
x
in
self
.
outer_sitsot_outs
(
outs
)]
# Here we tests that the new scan input sequence all have
# the same shape[0]. This is a properties that the scan()
# fct add and we want to keep it for all Scan op. This is
# used in T_Scan.test_grad_multiple_outs_taps to test
# that.
for
taps
,
x
in
zip
(
self
.
mitsot_taps
(),
self
.
outer_mitsot_outs
(
outs
)):
mintap
=
numpy
.
min
(
taps
)
if
hasattr
(
x
[::
-
1
][:
mintap
],
'test_value'
):
assert
(
x
[::
-
1
][:
mintap
]
.
tag
.
test_value
.
shape
[
0
]
==
inputs
[
0
]
.
tag
.
test_value
)
for
x
in
self
.
outer_sitsot_outs
(
outs
):
if
hasattr
(
x
[::
-
1
][:
-
1
]
.
tag
,
'test_value'
):
assert
(
x
[::
-
1
][:
-
1
]
.
tag
.
test_value
.
shape
[
0
]
==
inputs
[
0
]
.
tag
.
test_value
)
for
x
in
self
.
outer_nitsot_outs
(
outs
):
if
hasattr
(
x
[::
-
1
]
.
tag
,
'test_value'
):
assert
(
x
[::
-
1
]
.
tag
.
test_value
.
shape
[
0
]
==
inputs
[
0
]
.
tag
.
test_value
)
outer_inp_seqs
+=
[
x
[::
-
1
][:
numpy
.
min
(
taps
)]
for
taps
,
x
in
zip
(
self
.
mitsot_taps
(),
self
.
outer_mitsot_outs
(
outs
))]
outer_inp_seqs
+=
[
x
[::
-
1
][:
-
1
]
for
x
in
self
.
outer_sitsot_outs
(
outs
)]
outer_inp_seqs
+=
[
x
[::
-
1
]
for
x
in
self
.
outer_nitsot_outs
(
outs
)]
outer_inp_seqs
+=
[
x
[::
-
1
]
for
x
in
self
.
outer_nitsot_outs
(
outs
)]
inner_inp_seqs
=
self
.
inner_seqs
(
self_inputs
)
inner_inp_seqs
=
self
.
inner_seqs
(
self_inputs
)
...
...
theano/scan_module/scan_opt.py
浏览文件 @
a351e3db
...
@@ -66,7 +66,7 @@ def remove_constants_and_unused_inputs_scan(node):
...
@@ -66,7 +66,7 @@ def remove_constants_and_unused_inputs_scan(node):
# We only need to take care of sequences and other arguments
# We only need to take care of sequences and other arguments
st
=
op
.
n_seqs
st
=
op
.
n_seqs
st
+=
int
(
numpy
.
sum
([
len
(
x
)
for
x
in
st
+=
int
(
numpy
.
sum
([
len
(
x
)
for
x
in
op
.
tap_array
[:(
op
.
n_mit_mot
+
op
.
n_mit_sot
)]]))
op
.
tap_array
[:(
op
.
n_mit_mot
+
op
.
n_mit_sot
)]]))
st
+=
op
.
n_sit_sot
st
+=
op
.
n_sit_sot
st
+=
op
.
n_shared_outs
st
+=
op
.
n_shared_outs
op_ins
,
op_outs
=
scan_utils
.
reconstruct_graph
(
op
.
inputs
,
op
.
outputs
)
op_ins
,
op_outs
=
scan_utils
.
reconstruct_graph
(
op
.
inputs
,
op
.
outputs
)
...
@@ -105,8 +105,8 @@ def remove_constants_and_unused_inputs_scan(node):
...
@@ -105,8 +105,8 @@ def remove_constants_and_unused_inputs_scan(node):
elif
op_ins
[
idx
]
in
all_ins
:
elif
op_ins
[
idx
]
in
all_ins
:
# Check for identical other sequence
# Check for identical other sequence
identical_seqs
=
[
x
for
x
in
nw_outer
identical_seqs
=
[
x
for
x
in
nw_outer
if
scan_utils
.
equal_computations
(
if
scan_utils
.
equal_computations
(
[
x
],
[
node
.
inputs
[
idx
+
1
]])]
[
x
],
[
node
.
inputs
[
idx
+
1
]])]
if
identical_seqs
:
if
identical_seqs
:
index
=
node
.
inputs
.
index
(
identical_seqs
[
0
])
-
1
index
=
node
.
inputs
.
index
(
identical_seqs
[
0
])
-
1
givens
[
op_ins
[
idx
]]
=
op_ins
[
index
]
givens
[
op_ins
[
idx
]]
=
op_ins
[
index
]
...
@@ -144,7 +144,7 @@ def remove_constants_and_unused_inputs_scan(node):
...
@@ -144,7 +144,7 @@ def remove_constants_and_unused_inputs_scan(node):
nw_info
[
'n_seqs'
]
=
nw_n_seqs
nw_info
[
'n_seqs'
]
=
nw_n_seqs
# DEBUG CHECK
# DEBUG CHECK
nwScan
=
scan_op
.
Scan
(
nw_inner
,
op_outs
,
nw_info
)
nwScan
=
scan_op
.
Scan
(
nw_inner
,
op_outs
,
nw_info
)
nw_outs
=
nwScan
.
make_node
(
*
nw_outer
)
.
outputs
nw_outs
=
nwScan
(
*
nw_outer
,
**
dict
(
return_list
=
True
))
return
nw_outs
return
nw_outs
else
:
else
:
return
False
return
False
...
@@ -162,7 +162,7 @@ class PushOutNonSeqScan(gof.Optimizer):
...
@@ -162,7 +162,7 @@ class PushOutNonSeqScan(gof.Optimizer):
def
apply
(
self
,
fgraph
):
def
apply
(
self
,
fgraph
):
nodelist
=
[
x
for
x
in
fgraph
.
toposort
()
if
isinstance
(
x
.
op
,
nodelist
=
[
x
for
x
in
fgraph
.
toposort
()
if
isinstance
(
x
.
op
,
scan_op
.
Scan
)]
scan_op
.
Scan
)]
for
node
in
nodelist
:
for
node
in
nodelist
:
self
.
process_node
(
fgraph
,
node
)
self
.
process_node
(
fgraph
,
node
)
...
@@ -170,7 +170,7 @@ class PushOutNonSeqScan(gof.Optimizer):
...
@@ -170,7 +170,7 @@ class PushOutNonSeqScan(gof.Optimizer):
# this flag tells if there was any change during the last iterations
# this flag tells if there was any change during the last iterations
changed
=
True
changed
=
True
clean_inputs
,
clean_outputs
=
scan_utils
.
reconstruct_graph
(
clean_inputs
,
clean_outputs
=
scan_utils
.
reconstruct_graph
(
node
.
op
.
inputs
,
node
.
op
.
outputs
)
node
.
op
.
inputs
,
node
.
op
.
outputs
)
local_fgraph
=
gof
.
FunctionGraph
(
clean_inputs
,
clean_outputs
)
local_fgraph
=
gof
.
FunctionGraph
(
clean_inputs
,
clean_outputs
)
max_iterations
=
2
*
len
(
local_fgraph
.
toposort
())
+
3
max_iterations
=
2
*
len
(
local_fgraph
.
toposort
())
+
3
...
@@ -196,7 +196,7 @@ class PushOutNonSeqScan(gof.Optimizer):
...
@@ -196,7 +196,7 @@ class PushOutNonSeqScan(gof.Optimizer):
if
(
numpy
.
all
([(
x
in
inner_non_seqs
)
or
if
(
numpy
.
all
([(
x
in
inner_non_seqs
)
or
(
x
.
owner
in
to_remove
)
or
(
x
.
owner
in
to_remove
)
or
isinstance
(
x
,
tensor
.
Constant
)
isinstance
(
x
,
tensor
.
Constant
)
for
x
in
nd
.
inputs
])
and
for
x
in
nd
.
inputs
])
and
# we can do this because the assumption is that a
# we can do this because the assumption is that a
# viewOp or deepCopyOp will be just at the end of the
# viewOp or deepCopyOp will be just at the end of the
# function and not somewhere in the middle ..
# function and not somewhere in the middle ..
...
@@ -227,7 +227,11 @@ class PushOutNonSeqScan(gof.Optimizer):
...
@@ -227,7 +227,11 @@ class PushOutNonSeqScan(gof.Optimizer):
'this on theano-users list'
),
x
)
'this on theano-users list'
),
x
)
outside_ins
=
[
x
.
type
.
filter_variable
(
y
)
for
x
,
y
in
outside_ins
=
[
x
.
type
.
filter_variable
(
y
)
for
x
,
y
in
zip
(
nd
.
inputs
,
outside_ins
)]
zip
(
nd
.
inputs
,
outside_ins
)]
nw_outer_node
=
nd
.
op
.
make_node
(
*
outside_ins
)
# Do not call make_node for test_value
nw_outer_node
=
nd
.
op
(
*
outside_ins
,
**
dict
(
return_list
=
True
))[
0
]
.
owner
# Step 2. Create variables for replacements
# Step 2. Create variables for replacements
for
idx
,
y
in
enumerate
(
nd
.
outputs
):
for
idx
,
y
in
enumerate
(
nd
.
outputs
):
...
@@ -250,7 +254,7 @@ class PushOutNonSeqScan(gof.Optimizer):
...
@@ -250,7 +254,7 @@ class PushOutNonSeqScan(gof.Optimizer):
clean_replace_with_in
=
[]
clean_replace_with_in
=
[]
clean_replace_with_out
=
[]
clean_replace_with_out
=
[]
existent_nodes
=
[
nd
for
nd
in
local_fgraph
.
toposort
()
existent_nodes
=
[
nd
for
nd
in
local_fgraph
.
toposort
()
if
nd
not
in
to_remove
]
if
nd
not
in
to_remove
]
to_keep
=
[]
to_keep
=
[]
for
nd
in
existent_nodes
:
for
nd
in
existent_nodes
:
to_keep
+=
nd
.
inputs
to_keep
+=
nd
.
inputs
...
@@ -270,8 +274,8 @@ class PushOutNonSeqScan(gof.Optimizer):
...
@@ -270,8 +274,8 @@ class PushOutNonSeqScan(gof.Optimizer):
nw_outer
=
[]
nw_outer
=
[]
nw_inner
=
[]
nw_inner
=
[]
for
to_repl
,
repl_in
,
repl_out
in
zip
(
clean_to_replace
,
for
to_repl
,
repl_in
,
repl_out
in
zip
(
clean_to_replace
,
clean_replace_with_in
,
clean_replace_with_in
,
clean_replace_with_out
):
clean_replace_with_out
):
if
isinstance
(
repl_out
,
theano
.
Constant
):
if
isinstance
(
repl_out
,
theano
.
Constant
):
repl_in
=
repl_out
.
clone
()
repl_in
=
repl_out
.
clone
()
else
:
else
:
...
@@ -285,11 +289,15 @@ class PushOutNonSeqScan(gof.Optimizer):
...
@@ -285,11 +289,15 @@ class PushOutNonSeqScan(gof.Optimizer):
op_ins
,
op_outs
=
scan_utils
.
reconstruct_graph
(
_op_ins
,
_op_outs
)
op_ins
,
op_outs
=
scan_utils
.
reconstruct_graph
(
_op_ins
,
_op_outs
)
# Reconstruct node
# Reconstruct node
nwScan
=
scan_op
.
Scan
(
op_ins
,
op_outs
,
op
.
info
)
nwScan
=
scan_op
.
Scan
(
op_ins
,
op_outs
,
op
.
info
)
nw_node
=
nwScan
.
make_node
(
*
(
node
.
inputs
+
nw_outer
))
# Do not call make_node for test_value
nw_node
=
nwScan
(
*
(
node
.
inputs
+
nw_outer
),
**
dict
(
return_list
=
True
))[
0
]
.
owner
fgraph
.
replace_all_validate_remove
(
fgraph
.
replace_all_validate_remove
(
zip
(
node
.
outputs
,
nw_node
.
outputs
),
zip
(
node
.
outputs
,
nw_node
.
outputs
),
remove
=
[
node
],
remove
=
[
node
],
reason
=
'scan
_push_computation_out
'
)
reason
=
'scan
Op_pushout_nonseqs_ops
'
)
return
True
return
True
elif
to_keep
==
[]:
elif
to_keep
==
[]:
# Nothing in the inner graph should be kept
# Nothing in the inner graph should be kept
...
@@ -310,7 +318,7 @@ class PushOutNonSeqScan(gof.Optimizer):
...
@@ -310,7 +318,7 @@ class PushOutNonSeqScan(gof.Optimizer):
fgraph
.
replace_all_validate_remove
(
fgraph
.
replace_all_validate_remove
(
replace_with
.
items
(),
replace_with
.
items
(),
remove
=
[
node
],
remove
=
[
node
],
reason
=
'scan
_push_computation_out
'
)
reason
=
'scan
Op_pushout_nonseqs_ops
'
)
else
:
else
:
return
False
return
False
...
@@ -327,8 +335,8 @@ class PushOutSeqScan(gof.Optimizer):
...
@@ -327,8 +335,8 @@ class PushOutSeqScan(gof.Optimizer):
fgraph
.
attach_feature
(
gof
.
toolbox
.
ReplaceValidate
())
fgraph
.
attach_feature
(
gof
.
toolbox
.
ReplaceValidate
())
def
apply
(
self
,
fgraph
):
def
apply
(
self
,
fgraph
):
nodelist
=
[
x
for
x
in
fgraph
.
toposort
()
if
isinstance
(
x
.
op
,
nodelist
=
[
x
for
x
in
fgraph
.
toposort
()
scan_op
.
Scan
)]
if
isinstance
(
x
.
op
,
scan_op
.
Scan
)]
for
node
in
nodelist
:
for
node
in
nodelist
:
self
.
process_node
(
fgraph
,
node
)
self
.
process_node
(
fgraph
,
node
)
...
@@ -336,7 +344,7 @@ class PushOutSeqScan(gof.Optimizer):
...
@@ -336,7 +344,7 @@ class PushOutSeqScan(gof.Optimizer):
# this flag tells if there was any change during the last iterations
# this flag tells if there was any change during the last iterations
changed
=
True
changed
=
True
clean_inputs
,
clean_outputs
=
scan_utils
.
reconstruct_graph
(
clean_inputs
,
clean_outputs
=
scan_utils
.
reconstruct_graph
(
node
.
op
.
inputs
,
node
.
op
.
outputs
)
node
.
op
.
inputs
,
node
.
op
.
outputs
)
local_fgraph
=
gof
.
FunctionGraph
(
clean_inputs
,
clean_outputs
)
local_fgraph
=
gof
.
FunctionGraph
(
clean_inputs
,
clean_outputs
)
max_iterations
=
2
*
len
(
local_fgraph
.
toposort
())
+
3
max_iterations
=
2
*
len
(
local_fgraph
.
toposort
())
+
3
...
@@ -361,12 +369,12 @@ class PushOutSeqScan(gof.Optimizer):
...
@@ -361,12 +369,12 @@ class PushOutSeqScan(gof.Optimizer):
for
nd
in
local_fgraph
.
toposort
():
for
nd
in
local_fgraph
.
toposort
():
if
(
isinstance
(
nd
.
op
,
theano
.
tensor
.
Elemwise
)
and
if
(
isinstance
(
nd
.
op
,
theano
.
tensor
.
Elemwise
)
and
numpy
.
all
([(
x
in
inner_non_seqs
)
or
numpy
.
all
([(
x
in
inner_non_seqs
)
or
(
x
.
owner
in
to_remove
)
or
(
x
.
owner
in
to_remove
)
or
isinstance
(
x
,
tensor
.
Constant
)
or
isinstance
(
x
,
tensor
.
Constant
)
or
(
x
in
inner_seqs
)
(
x
in
inner_seqs
)
for
x
in
nd
.
inputs
])
and
for
x
in
nd
.
inputs
])
and
not
nd
in
to_remove
):
not
nd
in
to_remove
):
to_remove
.
append
(
nd
)
to_remove
.
append
(
nd
)
outside_ins
=
[]
outside_ins
=
[]
for
x
in
nd
.
inputs
:
for
x
in
nd
.
inputs
:
...
@@ -376,18 +384,21 @@ class PushOutSeqScan(gof.Optimizer):
...
@@ -376,18 +384,21 @@ class PushOutSeqScan(gof.Optimizer):
elif
x
in
inner_seqs
:
elif
x
in
inner_seqs
:
outside_ins
+=
[
outer_seqs
[
inner_seqs
.
index
(
x
)]]
outside_ins
+=
[
outer_seqs
[
inner_seqs
.
index
(
x
)]]
elif
x
in
to_replace
:
elif
x
in
to_replace
:
outside_ins
+=
[
replace_with_out
[
\
outside_ins
+=
[
replace_with_out
[
to_replace
.
index
(
x
)]]
to_replace
.
index
(
x
)]]
elif
isinstance
(
x
,
theano
.
Constant
):
elif
isinstance
(
x
,
theano
.
Constant
):
outside_ins
+=
[
x
.
clone
()]
outside_ins
+=
[
x
.
clone
()]
else
:
else
:
raise
Exception
(
raise
Exception
(
(
'Error in the `scan_pushout_
non_
seq_'
(
'Error in the `scan_pushout_seq_'
'operations`. The optimization tries '
'operations`. The optimization tries '
'to move some computation fron scan '
'to move some computation fron scan '
'which is not allowed to move. Report '
'which is not allowed to move. Report '
'this on theano-users list'
),
x
)
'this on theano-users list'
),
x
)
nw_outer_node
=
nd
.
op
.
make_node
(
*
outside_ins
)
# Do not call make_node for test_value
nw_outer_node
=
nd
.
op
(
*
outside_ins
,
**
dict
(
return_list
=
True
))[
0
]
.
owner
# Step 2. Create variables for replacements
# Step 2. Create variables for replacements
for
idx
,
y
in
enumerate
(
nd
.
outputs
):
for
idx
,
y
in
enumerate
(
nd
.
outputs
):
...
@@ -420,10 +431,15 @@ class PushOutSeqScan(gof.Optimizer):
...
@@ -420,10 +431,15 @@ class PushOutSeqScan(gof.Optimizer):
to_replace
+=
[
y
]
to_replace
+=
[
y
]
replace_with_in
+=
[
y_place_holder
]
replace_with_in
+=
[
y_place_holder
]
replace_with_out
+=
[
new_outer
]
replace_with_out
+=
[
new_outer
]
if
hasattr
(
new_outer
.
tag
,
"test_value"
):
new_sh
=
new_outer
.
tag
.
test_value
.
shape
ref_sh
=
(
outside_ins
.
tag
.
test_value
.
shape
[
0
],)
ref_sh
+=
nd
.
outputs
[
0
]
.
tag
.
test_value
.
shape
assert
new_sh
==
ref_sh
changed
=
True
changed
=
True
if
counts
>=
max_iterations
:
if
counts
>=
max_iterations
:
raise
Exception
(
'Error in the `scan_pushout_
non_
seq_operations`.'
raise
Exception
(
'Error in the `scan_pushout_seq_operations`.'
' The optimization exhausted the maximal number '
' The optimization exhausted the maximal number '
'of iterations allowed!'
)
'of iterations allowed!'
)
# We need to check all candidate replacements and choose those that
# We need to check all candidate replacements and choose those that
...
@@ -436,7 +452,7 @@ class PushOutSeqScan(gof.Optimizer):
...
@@ -436,7 +452,7 @@ class PushOutSeqScan(gof.Optimizer):
clean_replace_with_out
=
[]
clean_replace_with_out
=
[]
existent_nodes
=
[
nd
for
nd
in
local_fgraph
.
toposort
()
existent_nodes
=
[
nd
for
nd
in
local_fgraph
.
toposort
()
if
nd
not
in
to_remove
]
if
nd
not
in
to_remove
]
to_keep
=
[]
to_keep
=
[]
for
nd
in
existent_nodes
:
for
nd
in
existent_nodes
:
to_keep
+=
nd
.
inputs
to_keep
+=
nd
.
inputs
...
@@ -456,8 +472,8 @@ class PushOutSeqScan(gof.Optimizer):
...
@@ -456,8 +472,8 @@ class PushOutSeqScan(gof.Optimizer):
nw_outer
=
[]
nw_outer
=
[]
nw_inner
=
[]
nw_inner
=
[]
for
to_repl
,
repl_in
,
repl_out
in
zip
(
clean_to_replace
,
for
to_repl
,
repl_in
,
repl_out
in
zip
(
clean_to_replace
,
clean_replace_with_in
,
clean_replace_with_in
,
clean_replace_with_out
):
clean_replace_with_out
):
if
isinstance
(
repl_out
,
theano
.
Constant
):
if
isinstance
(
repl_out
,
theano
.
Constant
):
repl_in
=
repl_out
.
clone
()
repl_in
=
repl_out
.
clone
()
else
:
else
:
...
@@ -473,12 +489,14 @@ class PushOutSeqScan(gof.Optimizer):
...
@@ -473,12 +489,14 @@ class PushOutSeqScan(gof.Optimizer):
nw_info
=
op
.
info
.
copy
()
nw_info
=
op
.
info
.
copy
()
nw_info
[
'n_seqs'
]
+=
len
(
nw_inner
)
nw_info
[
'n_seqs'
]
+=
len
(
nw_inner
)
nwScan
=
scan_op
.
Scan
(
op_ins
,
op_outs
,
nw_info
)
nwScan
=
scan_op
.
Scan
(
op_ins
,
op_outs
,
nw_info
)
nw_node
=
nwScan
.
make_node
(
*
(
node
.
inputs
[:
1
]
+
nw_outer
+
# Do not call make_node for test_value
node
.
inputs
[
1
:]))
nw_node
=
nwScan
(
*
(
node
.
inputs
[:
1
]
+
nw_outer
+
node
.
inputs
[
1
:]),
**
dict
(
return_list
=
True
))[
0
]
.
owner
fgraph
.
replace_all_validate_remove
(
fgraph
.
replace_all_validate_remove
(
zip
(
node
.
outputs
,
nw_node
.
outputs
),
zip
(
node
.
outputs
,
nw_node
.
outputs
),
remove
=
[
node
],
remove
=
[
node
],
reason
=
'scan
_push_computation_out
'
)
reason
=
'scan
Op_pushout_seqs_ops
'
)
return
True
return
True
elif
(
to_keep
==
[]
and
elif
(
to_keep
==
[]
and
not
op
.
as_while
and
not
op
.
as_while
and
...
@@ -510,8 +528,8 @@ class PushOutSeqScan(gof.Optimizer):
...
@@ -510,8 +528,8 @@ class PushOutSeqScan(gof.Optimizer):
fgraph
.
replace_all_validate_remove
(
fgraph
.
replace_all_validate_remove
(
replace_with
.
items
(),
replace_with
.
items
(),
remove
=
[
node
],
remove
=
[
node
],
reason
=
'scan
_push_seq_computation_out
'
)
reason
=
'scan
Op_pushout_seqs_ops
'
)
return
True
else
:
else
:
return
False
return
False
...
@@ -532,7 +550,7 @@ class ScanInplaceOptimizer(Optimizer):
...
@@ -532,7 +550,7 @@ class ScanInplaceOptimizer(Optimizer):
nodes
=
fgraph
.
toposort
()
nodes
=
fgraph
.
toposort
()
scan_nodes
=
[
x
for
x
in
nodes
scan_nodes
=
[
x
for
x
in
nodes
if
(
isinstance
(
x
.
op
,
scan_op
.
Scan
)
and
if
(
isinstance
(
x
.
op
,
scan_op
.
Scan
)
and
x
.
op
.
info
[
'gpu'
]
==
self
.
gpu_flag
)]
x
.
op
.
info
[
'gpu'
]
==
self
.
gpu_flag
)]
for
scan_idx
in
xrange
(
len
(
scan_nodes
)):
for
scan_idx
in
xrange
(
len
(
scan_nodes
)):
node
=
scan_nodes
[
scan_idx
]
node
=
scan_nodes
[
scan_idx
]
op
=
node
.
op
op
=
node
.
op
...
@@ -563,12 +581,13 @@ class ScanInplaceOptimizer(Optimizer):
...
@@ -563,12 +581,13 @@ class ScanInplaceOptimizer(Optimizer):
info
,
info
,
typeConstructor
=
self
.
typeConstructor
)
typeConstructor
=
self
.
typeConstructor
)
new_outs
=
new_op
.
make_node
(
*
inputs
)
.
outputs
# Do not call make_node for test_value
new_outs
=
new_op
(
*
inputs
,
**
dict
(
return_list
=
True
))
try
:
try
:
fgraph
.
replace_all_validate_remove
(
fgraph
.
replace_all_validate_remove
(
zip
(
node
.
outputs
,
new_outs
),
zip
(
node
.
outputs
,
new_outs
),
remove
=
[
node
],
remove
=
[
node
],
reason
=
self
.
__class__
.
__name__
)
reason
=
'scanOp_make_inplace'
)
op
=
new_op
op
=
new_op
node
=
new_outs
[
0
]
.
owner
node
=
new_outs
[
0
]
.
owner
except
InconsistencyError
,
e
:
except
InconsistencyError
,
e
:
...
@@ -720,7 +739,7 @@ class ScanSaveMem(gof.Optimizer):
...
@@ -720,7 +739,7 @@ class ScanSaveMem(gof.Optimizer):
except
KeyError
:
except
KeyError
:
length
=
out
.
shape
[
0
]
length
=
out
.
shape
[
0
]
cf_slice
=
tensor
.
get_canonical_form_slice
(
cf_slice
=
tensor
.
get_canonical_form_slice
(
this_slice
[
0
],
length
)
this_slice
[
0
],
length
)
slices
[
i
]
+=
[(
cf_slice
,
this_slice
)]
slices
[
i
]
+=
[(
cf_slice
,
this_slice
)]
if
(
isinstance
(
this_slice
[
0
],
slice
)
and
if
(
isinstance
(
this_slice
[
0
],
slice
)
and
...
@@ -847,9 +866,8 @@ class ScanSaveMem(gof.Optimizer):
...
@@ -847,9 +866,8 @@ class ScanSaveMem(gof.Optimizer):
nw_inputs
[
0
]
=
nw_steps
nw_inputs
[
0
]
=
nw_steps
# 3.2 check orphane outputs to see if we can eliminate any
# 3.2 check orphane outputs to see if we can eliminate any
required
,
not_required
=
\
required
,
not_required
=
scan_utils
.
scan_can_remove_outs
(
scan_utils
.
scan_can_remove_outs
(
node
.
op
,
node
.
op
,
orphane_outs
)
orphane_outs
)
# 3.3. compose replace pairs for those nodes that need not
# 3.3. compose replace pairs for those nodes that need not
# to store everything in memory ( or ar orphane and required
# to store everything in memory ( or ar orphane and required
# by the inner function .. )
# by the inner function .. )
...
@@ -947,9 +965,10 @@ class ScanSaveMem(gof.Optimizer):
...
@@ -947,9 +965,10 @@ class ScanSaveMem(gof.Optimizer):
# I need to make sure I'm not reapplying the same optimization
# I need to make sure I'm not reapplying the same optimization
# twice since bad things usually happen if I do that
# twice since bad things usually happen if I do that
info
[
'_scan_savemem_visited'
]
=
True
info
[
'_scan_savemem_visited'
]
=
True
new_outs
=
scan_op
.
Scan
(
inps
,
outs
,
# Do not call make_node for test_value
info
)
.
make_node
(
*
node_ins
)
.
outputs
new_outs
=
scan_op
.
Scan
(
inps
,
outs
,
info
)(
*
node_ins
,
**
dict
(
return_list
=
True
))
old_new
=
[]
old_new
=
[]
# 3.7 Get replace pairs for those outputs that do not change
# 3.7 Get replace pairs for those outputs that do not change
...
@@ -978,9 +997,8 @@ class ScanSaveMem(gof.Optimizer):
...
@@ -978,9 +997,8 @@ class ScanSaveMem(gof.Optimizer):
sl_ins
=
tensor
.
Subtensor
.
collapse
(
sl_ins
=
tensor
.
Subtensor
.
collapse
(
nw_slice
,
nw_slice
,
lambda
entry
:
isinstance
(
entry
,
lambda
entry
:
isinstance
(
entry
,
tensor
.
Variable
))
tensor
.
Variable
))
new_o
=
subtens
.
make_node
(
new_outs
[
nw_pos
],
new_o
=
subtens
(
new_outs
[
nw_pos
],
*
sl_ins
)
*
sl_ins
)
.
outputs
[
0
]
if
new_o
.
ndim
>
0
:
if
new_o
.
ndim
>
0
:
new_o
=
new_o
[::
cnf_slice
[
1
]]
new_o
=
new_o
[::
cnf_slice
[
1
]]
replaced_outs
.
append
(
idx
)
replaced_outs
.
append
(
idx
)
...
@@ -1009,18 +1027,16 @@ class ScanSaveMem(gof.Optimizer):
...
@@ -1009,18 +1027,16 @@ class ScanSaveMem(gof.Optimizer):
else
:
else
:
position
=
(
cnf_slice
[
0
]
-
nw_steps
-
position
=
(
cnf_slice
[
0
]
-
nw_steps
-
init_l
[
pos
]
+
store_steps
[
pos
])
init_l
[
pos
]
+
store_steps
[
pos
])
nw_slice
=
(
sanitize
(
position
),)
+
\
tuple
(
old_slices
[
1
:])
nw_slice
=
(
sanitize
(
position
),)
+
tuple
(
old_slices
[
1
:])
subtens
=
tensor
.
Subtensor
(
nw_slice
)
subtens
=
tensor
.
Subtensor
(
nw_slice
)
sl_ins
=
tensor
.
Subtensor
.
collapse
(
sl_ins
=
tensor
.
Subtensor
.
collapse
(
nw_slice
,
nw_slice
,
lambda
entry
:
isinstance
(
entry
,
lambda
entry
:
isinstance
(
entry
,
tensor
.
Variable
))
tensor
.
Variable
))
new_o
=
subtens
.
make_node
(
new_outs
[
nw_pos
],
new_o
=
subtens
(
new_outs
[
nw_pos
],
*
sl_ins
)
*
sl_ins
)
.
outputs
[
0
]
if
new_o
.
ndim
>
0
:
if
new_o
.
ndim
>
0
:
new_o
=
new_o
[::
cnf_slice
[
1
]]
new_o
=
new_o
[::
cnf_slice
[
1
]]
old_new
+=
[(
old
,
new_o
)]
old_new
+=
[(
old
,
new_o
)]
...
@@ -1042,12 +1058,12 @@ class ScanSaveMem(gof.Optimizer):
...
@@ -1042,12 +1058,12 @@ class ScanSaveMem(gof.Optimizer):
remove
.
append
(
node
)
remove
.
append
(
node
)
fgraph
.
replace_all_validate_remove
(
old_new
,
fgraph
.
replace_all_validate_remove
(
old_new
,
remove
,
remove
,
reason
=
'scan_save_mem'
)
reason
=
'scan
Op
_save_mem'
)
def
apply
(
self
,
fgraph
):
def
apply
(
self
,
fgraph
):
nodelist
=
[
x
for
x
in
fgraph
.
toposort
()
if
isinstance
(
x
.
op
,
nodelist
=
[
x
for
x
in
fgraph
.
toposort
()
if
isinstance
(
x
.
op
,
scan_op
.
Scan
)]
scan_op
.
Scan
)]
for
node
in
nodelist
:
for
node
in
nodelist
:
if
not
hasattr
(
node
.
op
,
'_scan_savemem_visited'
):
if
not
hasattr
(
node
.
op
,
'_scan_savemem_visited'
):
self
.
process_node
(
fgraph
,
node
)
self
.
process_node
(
fgraph
,
node
)
...
@@ -1230,7 +1246,7 @@ class ScanMerge(gof.Optimizer):
...
@@ -1230,7 +1246,7 @@ class ScanMerge(gof.Optimizer):
proposal
=
self
.
merge
(
subset
)
proposal
=
self
.
merge
(
subset
)
fgraph
.
replace_all_validate_remove
(
proposal
,
fgraph
.
replace_all_validate_remove
(
proposal
,
remove
=
subset
,
remove
=
subset
,
reason
=
'scan_merge'
)
reason
=
'scan
Op
_merge'
)
def
has_duplicates
(
l
):
def
has_duplicates
(
l
):
...
@@ -1389,13 +1405,13 @@ def scan_merge_inouts(node):
...
@@ -1389,13 +1405,13 @@ def scan_merge_inouts(node):
# items scan is supposed to store for this nit_sot sequence
# items scan is supposed to store for this nit_sot sequence
shapes
.
append
(
x
)
shapes
.
append
(
x
)
tmp
=
[
map_nitsot_out
(
i
,
o
,
sh
,
seen
)
tmp
=
[
map_nitsot_out
(
i
,
o
,
sh
,
seen
)
for
i
,
o
,
sh
in
zip
(
na
.
inner_out_nit_sot
,
for
i
,
o
,
sh
in
zip
(
na
.
inner_out_nit_sot
,
na
.
outer_out_nit_sot
,
na
.
outer_out_nit_sot
,
shapes
)]
shapes
)]
na
.
outer_out_nit_sot
=
[
map_nitsot_out
(
i
,
o
,
sh
,
seen
)
na
.
outer_out_nit_sot
=
[
map_nitsot_out
(
i
,
o
,
sh
,
seen
)
for
i
,
o
,
sh
in
zip
(
na
.
inner_out_nit_sot
,
for
i
,
o
,
sh
in
zip
(
na
.
inner_out_nit_sot
,
na
.
outer_out_nit_sot
,
na
.
outer_out_nit_sot
,
shapes
)]
shapes
)]
seen
=
[]
seen
=
[]
na
.
outer_out_sit_sot
=
[
map_out
(
i
,
o
,
seen
)
na
.
outer_out_sit_sot
=
[
map_out
(
i
,
o
,
seen
)
...
@@ -1592,10 +1608,8 @@ class PushOutDot1(gof.Optimizer):
...
@@ -1592,10 +1608,8 @@ class PushOutDot1(gof.Optimizer):
old
=
node
.
outputs
[
pos
]
.
clients
[
0
][
0
]
.
outputs
[
0
]
old
=
node
.
outputs
[
pos
]
.
clients
[
0
][
0
]
.
outputs
[
0
]
old_new
.
append
((
old
,
new_out
))
old_new
.
append
((
old
,
new_out
))
old_new
+=
zip
(
node
.
outputs
[
pos
+
1
:],
new_outs
[
pos
:])
old_new
+=
zip
(
node
.
outputs
[
pos
+
1
:],
new_outs
[
pos
:])
fgraph
.
replace_all_validate_remove
(
old_new
,
fgraph
.
replace_all_validate_remove
(
remove
=
[
node
],
old_new
,
remove
=
[
node
],
reason
=
'scan_pushout_dot1'
)
reason
=
'PushOutDot1'
)
# I've added an equilibrium because later scan optimization in the sequence
# I've added an equilibrium because later scan optimization in the sequence
...
@@ -1612,7 +1626,7 @@ optdb.register('scan_eqopt1', scan_eqopt1, .1, 'fast_run', 'scan')
...
@@ -1612,7 +1626,7 @@ optdb.register('scan_eqopt1', scan_eqopt1, .1, 'fast_run', 'scan')
optdb
.
register
(
'scan_eqopt2'
,
scan_eqopt2
,
1.6
,
'fast_run'
,
'scan'
)
optdb
.
register
(
'scan_eqopt2'
,
scan_eqopt2
,
1.6
,
'fast_run'
,
'scan'
)
optdb
.
register
(
'scanOp_make_inplace'
,
optdb
.
register
(
'scanOp_make_inplace'
,
ScanInplaceOptimizer
(
typeConstructor
=
None
,
ScanInplaceOptimizer
(
typeConstructor
=
None
,
gpu_flag
=
False
),
gpu_flag
=
False
),
75
,
75
,
'fast_run'
,
'fast_run'
,
'inplace'
,
'inplace'
,
...
@@ -1628,6 +1642,7 @@ scan_seqopt1.register('scanOp_remove_constants_and_unused_inputs0',
...
@@ -1628,6 +1642,7 @@ scan_seqopt1.register('scanOp_remove_constants_and_unused_inputs0',
opt
.
in2out
(
remove_constants_and_unused_inputs_scan
,
opt
.
in2out
(
remove_constants_and_unused_inputs_scan
,
ignore_newtrees
=
True
),
ignore_newtrees
=
True
),
1
,
1
,
'remove_constants_and_unused_inputs_scan'
,
'fast_run'
,
'fast_run'
,
'scan'
)
'scan'
)
...
@@ -1662,10 +1677,11 @@ scan_seqopt2.register('constant_folding_for_scan2',
...
@@ -1662,10 +1677,11 @@ scan_seqopt2.register('constant_folding_for_scan2',
'scan'
)
'scan'
)
scan_seqopt2
.
register
(
'scanOp_remove_constants_and_unused_inputs
0
'
,
scan_seqopt2
.
register
(
'scanOp_remove_constants_and_unused_inputs
1
'
,
opt
.
in2out
(
remove_constants_and_unused_inputs_scan
,
opt
.
in2out
(
remove_constants_and_unused_inputs_scan
,
ignore_newtrees
=
True
),
ignore_newtrees
=
True
),
2
,
2
,
'remove_constants_and_unused_inputs_scan'
,
'fast_run'
,
'fast_run'
,
'scan'
)
'scan'
)
...
@@ -1684,12 +1700,14 @@ scan_seqopt2.register('scanop_remove_constants_and_unused_inputs2',
...
@@ -1684,12 +1700,14 @@ scan_seqopt2.register('scanop_remove_constants_and_unused_inputs2',
opt
.
in2out
(
remove_constants_and_unused_inputs_scan
,
opt
.
in2out
(
remove_constants_and_unused_inputs_scan
,
ignore_newtrees
=
True
),
ignore_newtrees
=
True
),
5
,
5
,
'remove_constants_and_unused_inputs_scan'
,
'fast_run'
,
'fast_run'
,
'scan'
)
'scan'
)
scan_seqopt2
.
register
(
'scanOp_merge_inouts'
,
scan_seqopt2
.
register
(
'scanOp_merge_inouts'
,
opt
.
in2out
(
scan_merge_inouts
,
ignore_newtrees
=
True
),
opt
.
in2out
(
scan_merge_inouts
,
ignore_newtrees
=
True
),
6
,
6
,
'scan_merge_inouts'
,
'fast_run'
,
'fast_run'
,
'scan'
)
'scan'
)
...
@@ -1707,5 +1725,6 @@ scan_seqopt2.register('scanOp_remove_constants_and_unused_inputs3',
...
@@ -1707,5 +1725,6 @@ scan_seqopt2.register('scanOp_remove_constants_and_unused_inputs3',
opt
.
in2out
(
remove_constants_and_unused_inputs_scan
,
opt
.
in2out
(
remove_constants_and_unused_inputs_scan
,
ignore_newtrees
=
True
),
ignore_newtrees
=
True
),
8
,
8
,
'remove_constants_and_unused_inputs_scan'
,
'fast_run'
,
'fast_run'
,
'scan'
)
'scan'
)
theano/scan_module/scan_utils.py
浏览文件 @
a351e3db
...
@@ -500,7 +500,7 @@ def infer_shape(outs, inputs, input_shapes):
...
@@ -500,7 +500,7 @@ def infer_shape(outs, inputs, input_shapes):
# shape_feature.on_import does not actually use an fgraph
# shape_feature.on_import does not actually use an fgraph
# It will call infer_shape and set_shape appropriately
# It will call infer_shape and set_shape appropriately
dummy_fgraph
=
None
dummy_fgraph
=
None
shape_feature
.
on_import
(
dummy_fgraph
,
out
.
owner
)
shape_feature
.
on_import
(
dummy_fgraph
,
out
.
owner
,
reason
=
"dummy"
)
ret
=
[]
ret
=
[]
for
o
in
outs
:
for
o
in
outs
:
...
...
theano/scan_module/tests/test_scan.py
浏览文件 @
a351e3db
...
@@ -1141,7 +1141,7 @@ class T_Scan(unittest.TestCase):
...
@@ -1141,7 +1141,7 @@ class T_Scan(unittest.TestCase):
go_backwards
=
False
)
go_backwards
=
False
)
gX
,
gY
=
tensor
.
grad
(
values
[
1
]
.
sum
(),
[
x
,
y
])
gX
,
gY
=
tensor
.
grad
(
values
[
1
]
.
sum
(),
[
x
,
y
])
f
=
theano
.
function
([
c
,
x
,
y
],
[
gX
,
gY
],
f
=
theano
.
function
([
c
,
x
,
y
],
[
gX
,
gY
],
allow_input_downcast
=
True
)
allow_input_downcast
=
True
)
# Check for runtime errors
# Check for runtime errors
f
(
numpy
.
int32
(
0
),
numpy
.
float32
(
1.
),
numpy
.
float32
(
.
5
))
f
(
numpy
.
int32
(
0
),
numpy
.
float32
(
1.
),
numpy
.
float32
(
.
5
))
...
@@ -1545,6 +1545,12 @@ class T_Scan(unittest.TestCase):
...
@@ -1545,6 +1545,12 @@ class T_Scan(unittest.TestCase):
x0
=
theano
.
tensor
.
vector
(
'x0'
)
x0
=
theano
.
tensor
.
vector
(
'x0'
)
y0
=
theano
.
tensor
.
vector
(
'y0'
)
y0
=
theano
.
tensor
.
vector
(
'y0'
)
W_in1
.
tag
.
test_value
=
vW_in1
u1
.
tag
.
test_value
=
v_u1
u2
.
tag
.
test_value
=
v_u2
x0
.
tag
.
test_value
=
v_x0
y0
.
tag
.
test_value
=
v_y0
def
f_rnn_cmpl
(
u1_t
,
def
f_rnn_cmpl
(
u1_t
,
u2_tm1
,
u2_tm1
,
u2_t
,
u2_t
,
...
@@ -1553,33 +1559,46 @@ class T_Scan(unittest.TestCase):
...
@@ -1553,33 +1559,46 @@ class T_Scan(unittest.TestCase):
y_tm1
,
y_tm1
,
y_tm3
,
y_tm3
,
W_in1
):
W_in1
):
return
[
theano
.
dot
(
u1_t
,
W_in1
)
+
\
return
[
theano
.
dot
(
u1_t
,
W_in1
)
+
(
u2_t
+
u2_tm1
*
u2_tp1
)
*
W_in2
+
\
(
u2_t
+
u2_tm1
*
u2_tp1
)
*
W_in2
+
theano
.
dot
(
x_tm1
,
W
),
theano
.
dot
(
x_tm1
,
W
),
(
y_tm1
+
y_tm3
)
*
theano
.
dot
(
x_tm1
,
W_out
),
(
y_tm1
+
y_tm3
)
*
theano
.
dot
(
x_tm1
,
W_out
),
theano
.
dot
(
u1_t
,
W_in1
)]
theano
.
dot
(
u1_t
,
W_in1
)]
cost
,
updates
=
scan_project_sum
(
f_rnn_cmpl
,
[
u1
,
dict
(
input
=
u2
,
taps
=
[
-
1
,
0
,
1
])],
[
x0
,
dict
(
initial
=
y0
,
taps
=
[
-
1
,
-
3
]),
None
],
W_in1
,
n_steps
=
None
,
truncate_gradient
=-
1
,
go_backwards
=
False
)
vparams
=
[
v_u1
,
v_u2
,
v_x0
,
v_y0
,
vW_in1
]
params
=
[
u1
,
u2
,
x0
,
y0
,
W_in1
]
gparams
=
theano
.
tensor
.
grad
(
cost
,
params
)
grad_fn
=
theano
.
function
([
u1
,
u2
,
x0
,
y0
,
W_in1
],
gparams
,
updates
=
updates
,
no_default_updates
=
True
,
allow_input_downcast
=
True
)
cost_fn
=
theano
.
function
([
u1
,
u2
,
x0
,
y0
,
W_in1
],
# We change the compute_test_value[_opt] flag to run the
cost
,
# assert in Scan.grad() of the new scan input sequence related
updates
=
updates
,
# to outer_mitsot_outs, outer_sitsot_outs and
no_default_updates
=
True
,
# outer_nitsot_outs. This allow to test an old Scan bug.
allow_input_downcast
=
True
)
old1
=
theano
.
config
.
compute_test_value
old2
=
theano
.
config
.
compute_test_value_opt
theano
.
config
.
compute_test_value
=
'raise'
theano
.
config
.
compute_test_value_opt
=
'raise'
try
:
cost
,
updates
=
scan_project_sum
(
f_rnn_cmpl
,
[
u1
,
dict
(
input
=
u2
,
taps
=
[
-
1
,
0
,
1
])],
[
x0
,
dict
(
initial
=
y0
,
taps
=
[
-
1
,
-
3
]),
None
],
W_in1
,
n_steps
=
None
,
truncate_gradient
=-
1
,
go_backwards
=
False
)
vparams
=
[
v_u1
,
v_u2
,
v_x0
,
v_y0
,
vW_in1
]
params
=
[
u1
,
u2
,
x0
,
y0
,
W_in1
]
gparams
=
theano
.
tensor
.
grad
(
cost
,
params
)
grad_fn
=
theano
.
function
([
u1
,
u2
,
x0
,
y0
,
W_in1
],
gparams
,
updates
=
updates
,
no_default_updates
=
True
,
allow_input_downcast
=
True
)
cost_fn
=
theano
.
function
([
u1
,
u2
,
x0
,
y0
,
W_in1
],
cost
,
updates
=
updates
,
no_default_updates
=
True
,
allow_input_downcast
=
True
)
finally
:
theano
.
config
.
compute_test_value
=
old1
theano
.
config
.
compute_test_value_opt
=
old2
num_grad
=
multiple_outputs_numeric_grad
(
cost_fn
,
num_grad
=
multiple_outputs_numeric_grad
(
cost_fn
,
[
v_u1
,
[
v_u1
,
...
...
theano/tensor/basic.py
浏览文件 @
a351e3db
...
@@ -2543,7 +2543,7 @@ class Alloc(gof.Op):
...
@@ -2543,7 +2543,7 @@ class Alloc(gof.Op):
#change.
#change.
return
[
gx
]
+
[
DisconnectedType
()()
for
i
in
inputs
[
1
:]]
return
[
gx
]
+
[
DisconnectedType
()()
for
i
in
inputs
[
1
:]]
def
__call__
(
self
,
val
,
*
shapes
):
def
__call__
(
self
,
val
,
*
shapes
,
**
kwargs
):
"""
"""
If the alloc would be useless, this function returns val.
If the alloc would be useless, this function returns val.
...
@@ -2554,7 +2554,7 @@ class Alloc(gof.Op):
...
@@ -2554,7 +2554,7 @@ class Alloc(gof.Op):
If you always want an Alloc node, call make_node.
If you always want an Alloc node, call make_node.
"""
"""
ret
=
super
(
Alloc
,
self
)
.
__call__
(
val
,
*
shapes
)
ret
=
super
(
Alloc
,
self
)
.
__call__
(
val
,
*
shapes
,
**
kwargs
)
try
:
try
:
# It makes optimization difficult when useless allocs are thrown
# It makes optimization difficult when useless allocs are thrown
# into the graph at every stage of optimization. This little logic
# into the graph at every stage of optimization. This little logic
...
...
theano/tensor/opt.py
浏览文件 @
a351e3db
...
@@ -49,14 +49,24 @@ theano.configparser.AddConfigVar('on_shape_error',
...
@@ -49,14 +49,24 @@ theano.configparser.AddConfigVar('on_shape_error',
def
out2in
(
*
local_opts
):
def
out2in
(
*
local_opts
):
"""WRITEME """
"""WRITEME """
return
opt
.
TopoOptimizer
(
opt
.
LocalOptGroup
(
*
local_opts
),
if
len
(
local_opts
)
>
1
:
# Don't wrap it uselessly if their is only 1 optimization.
local_opts
=
opt
.
LocalOptGroup
(
*
local_opts
),
else
:
local_opts
,
=
local_opts
return
opt
.
TopoOptimizer
(
local_opts
,
order
=
'out_to_in'
,
order
=
'out_to_in'
,
failure_callback
=
TopoOptimizer
.
warn_inplace
)
failure_callback
=
TopoOptimizer
.
warn_inplace
)
def
in2out
(
*
local_opts
,
**
kwargs
):
def
in2out
(
*
local_opts
,
**
kwargs
):
"""WRITEME """
"""WRITEME """
return
opt
.
TopoOptimizer
(
opt
.
LocalOptGroup
(
*
local_opts
),
if
len
(
local_opts
)
>
1
:
# Don't wrap it uselessly if their is only 1 optimization.
local_opts
=
opt
.
LocalOptGroup
(
*
local_opts
),
else
:
local_opts
,
=
local_opts
return
opt
.
TopoOptimizer
(
local_opts
,
order
=
'in_to_out'
,
order
=
'in_to_out'
,
failure_callback
=
TopoOptimizer
.
warn_inplace
,
failure_callback
=
TopoOptimizer
.
warn_inplace
,
**
kwargs
)
**
kwargs
)
...
@@ -384,10 +394,12 @@ def local_dimshuffle_lift(node):
...
@@ -384,10 +394,12 @@ def local_dimshuffle_lift(node):
input
=
node
.
inputs
[
0
]
input
=
node
.
inputs
[
0
]
inode
=
input
.
owner
inode
=
input
.
owner
if
inode
and
isinstance
(
inode
.
op
,
Elemwise
)
and
(
len
(
input
.
clients
)
==
1
):
if
inode
and
isinstance
(
inode
.
op
,
Elemwise
)
and
(
len
(
input
.
clients
)
==
1
):
return
inode
.
op
.
make_node
(
*
[
DimShuffle
(
input
.
type
.
broadcastable
,
# Don't use make_node to have tag.test_value set.
op
.
new_order
,
ret
=
inode
.
op
(
*
[
DimShuffle
(
input
.
type
.
broadcastable
,
op
.
inplace
)(
input
)
for
input
in
op
.
new_order
,
inode
.
inputs
])
.
outputs
op
.
inplace
)(
input
)
for
input
in
inode
.
inputs
],
**
dict
(
return_list
=
True
))
return
ret
if
inode
and
isinstance
(
inode
.
op
,
DimShuffle
):
if
inode
and
isinstance
(
inode
.
op
,
DimShuffle
):
new_order
=
[
x
==
'x'
and
'x'
or
inode
.
op
.
new_order
[
x
]
for
x
in
new_order
=
[
x
==
'x'
and
'x'
or
inode
.
op
.
new_order
[
x
]
for
x
in
op
.
new_order
]
op
.
new_order
]
...
@@ -397,8 +409,9 @@ def local_dimshuffle_lift(node):
...
@@ -397,8 +409,9 @@ def local_dimshuffle_lift(node):
iinput
.
type
.
ndim
):
iinput
.
type
.
ndim
):
return
[
iinput
]
return
[
iinput
]
else
:
else
:
return
DimShuffle
(
iinput
.
type
.
broadcastable
,
new_order
,
ret
=
DimShuffle
(
iinput
.
type
.
broadcastable
,
new_order
,
inplace
)
.
make_node
(
iinput
)
.
outputs
inplace
)(
iinput
,
**
dict
(
return_list
=
True
))
return
ret
@register_canonicalize
@register_canonicalize
...
@@ -437,8 +450,10 @@ def dimshuffle_as_view(node):
...
@@ -437,8 +450,10 @@ def dimshuffle_as_view(node):
#Step 60 is the inplace optimization stage.
#Step 60 is the inplace optimization stage.
compile
.
optdb
.
register
(
'dimshuffle_as_view'
,
compile
.
optdb
.
register
(
'dimshuffle_as_view'
,
TopoOptimizer
(
dimshuffle_as_view
,
TopoOptimizer
(
failure_callback
=
TopoOptimizer
.
warn_inplace
),
60
,
dimshuffle_as_view
,
failure_callback
=
TopoOptimizer
.
warn_inplace
),
60
,
'fast_run'
,
'inplace'
)
'fast_run'
,
'inplace'
)
register_canonicalize
(
local_dimshuffle_lift
)
register_canonicalize
(
local_dimshuffle_lift
)
register_specialize
(
local_dimshuffle_lift
)
register_specialize
(
local_dimshuffle_lift
)
...
@@ -771,7 +786,8 @@ class ShapeFeature(object):
...
@@ -771,7 +786,8 @@ class ShapeFeature(object):
if
hasattr
(
r
.
type
,
"broadcastable"
)
and
r
.
type
.
broadcastable
[
i
]:
if
hasattr
(
r
.
type
,
"broadcastable"
)
and
r
.
type
.
broadcastable
[
i
]:
return
self
.
lscalar_one
return
self
.
lscalar_one
else
:
else
:
return
Shape_i
(
i
)
.
make_node
(
r
)
.
outputs
[
0
]
# Do not call make_node for test_value
return
Shape_i
(
i
)(
r
)
def
shape_tuple
(
self
,
r
):
def
shape_tuple
(
self
,
r
):
"""Return a tuple of symbolic shape vars for tensor variable r"""
"""Return a tuple of symbolic shape vars for tensor variable r"""
...
@@ -970,9 +986,9 @@ class ShapeFeature(object):
...
@@ -970,9 +986,9 @@ class ShapeFeature(object):
# shape var -> graph v
# shape var -> graph v
for
node
in
fgraph
.
toposort
():
for
node
in
fgraph
.
toposort
():
self
.
on_import
(
fgraph
,
node
)
self
.
on_import
(
fgraph
,
node
,
reason
=
'on_attach'
)
def
on_import
(
self
,
fgraph
,
node
):
def
on_import
(
self
,
fgraph
,
node
,
reason
):
if
node
.
outputs
[
0
]
in
self
.
shape_of
:
if
node
.
outputs
[
0
]
in
self
.
shape_of
:
# this is a revert, not really an import
# this is a revert, not really an import
for
r
in
node
.
outputs
+
node
.
inputs
:
for
r
in
node
.
outputs
+
node
.
inputs
:
...
@@ -1933,7 +1949,8 @@ def local_subtensor_merge(node):
...
@@ -1933,7 +1949,8 @@ def local_subtensor_merge(node):
sl_ins
=
Subtensor
.
collapse
(
sl_ins
=
Subtensor
.
collapse
(
merged_slices
,
merged_slices
,
lambda
x
:
isinstance
(
x
,
T
.
Variable
))
lambda
x
:
isinstance
(
x
,
T
.
Variable
))
out
=
subtens
.
make_node
(
x
,
*
sl_ins
)
.
outputs
[
0
]
# Do not call make_node for test_value
out
=
subtens
(
x
,
*
sl_ins
)
return
[
out
]
return
[
out
]
...
@@ -4583,8 +4600,12 @@ def local_elemwise_fusion_op(OP, max_input_fct=lambda node: 1024):
...
@@ -4583,8 +4600,12 @@ def local_elemwise_fusion_op(OP, max_input_fct=lambda node: 1024):
elif
ii
in
tmp_input
:
elif
ii
in
tmp_input
:
tmp_s_input
.
append
(
tmp_scalar
[
tmp_input
.
index
(
ii
)])
tmp_s_input
.
append
(
tmp_scalar
[
tmp_input
.
index
(
ii
)])
else
:
else
:
tmp_s_input
.
append
(
scalar
.
Scalar
(
tmp
=
scalar
.
Scalar
(
ii
.
dtype
)
.
make_variable
()
ii
.
dtype
)
.
make_variable
())
try
:
tmp
.
tag
.
test_value
=
gof
.
op
.
get_test_value
(
ii
)
.
flatten
()[
0
]
except
AttributeError
:
pass
tmp_s_input
.
append
(
tmp
)
tmp_input
.
append
(
ii
)
tmp_input
.
append
(
ii
)
tmp_scalar
.
append
(
tmp_s_input
[
-
1
])
tmp_scalar
.
append
(
tmp_s_input
[
-
1
])
s_op
=
i
.
owner
.
op
.
scalar_op
(
*
tmp_s_input
)
s_op
=
i
.
owner
.
op
.
scalar_op
(
*
tmp_s_input
)
...
@@ -4634,6 +4655,13 @@ def local_elemwise_fusion_op(OP, max_input_fct=lambda node: 1024):
...
@@ -4634,6 +4655,13 @@ def local_elemwise_fusion_op(OP, max_input_fct=lambda node: 1024):
s
=
s_inputs
[
inputs
.
index
(
i
)]
s
=
s_inputs
[
inputs
.
index
(
i
)]
else
:
else
:
s
=
scalar
.
Scalar
(
i
.
dtype
)
.
make_variable
()
s
=
scalar
.
Scalar
(
i
.
dtype
)
.
make_variable
()
try
:
v
=
gof
.
op
.
get_test_value
(
i
)
if
v
.
size
>
0
:
s
.
tag
.
test_value
=
gof
.
op
.
get_test_value
(
i
)
.
flatten
()[
0
]
except
AttributeError
:
pass
inputs
.
append
(
i
)
inputs
.
append
(
i
)
s_inputs
.
append
(
s
)
s_inputs
.
append
(
s
)
s_g
.
append
(
s
)
s_g
.
append
(
s
)
...
@@ -4667,7 +4695,8 @@ your code will run correctly, but may be slower.""")
...
@@ -4667,7 +4695,8 @@ your code will run correctly, but may be slower.""")
C
=
scalar
.
Composite
(
s_inputs
,
[
s_new_out
])
C
=
scalar
.
Composite
(
s_inputs
,
[
s_new_out
])
#create the new node.
#create the new node.
n
=
OP
(
C
)
.
make_node
(
*
inputs
)
#Do not call make_node to have test_value
n
=
OP
(
C
)(
*
inputs
)
.
owner
assert
len
(
n
.
outputs
)
==
1
assert
len
(
n
.
outputs
)
==
1
assert
node
.
outputs
[
0
]
.
dtype
==
n
.
outputs
[
0
]
.
dtype
assert
node
.
outputs
[
0
]
.
dtype
==
n
.
outputs
[
0
]
.
dtype
...
@@ -4728,9 +4757,11 @@ if config.tensor.local_elemwise_fusion:
...
@@ -4728,9 +4757,11 @@ if config.tensor.local_elemwise_fusion:
_logger
.
debug
(
"enabling optimization fusion elemwise in fast_run"
)
_logger
.
debug
(
"enabling optimization fusion elemwise in fast_run"
)
compile
.
optdb
.
register
(
'elemwise_fusion'
,
compile
.
optdb
.
register
(
'elemwise_fusion'
,
FusionOptimizer
(
local_elemwise_fusion
),
71.00
,
FusionOptimizer
(
local_elemwise_fusion
),
71.00
,
'fast_run'
,
'fusion'
,
'local_elemwise_fusion'
)
'fast_run'
,
'fusion'
,
'local_elemwise_fusion'
,
'FusionOptimizer'
)
else
:
else
:
_logger
.
debug
(
"not enabling optimization fusion elemwise in fast_run"
)
_logger
.
debug
(
"not enabling optimization fusion elemwise in fast_run"
)
compile
.
optdb
.
register
(
'elemwise_fusion'
,
compile
.
optdb
.
register
(
'elemwise_fusion'
,
FusionOptimizer
(
local_elemwise_fusion
),
71.00
,
FusionOptimizer
(
local_elemwise_fusion
),
71.00
,
'fusion'
,
'local_elemwise_fusion'
)
'fusion'
,
'local_elemwise_fusion'
,
'FusionOptimizer'
)
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论