Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
d8a82f73
提交
d8a82f73
authored
11月 15, 2020
作者:
Brandon T. Willard
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Refactor FunctionGraph interface
This commit does the following: - changes `r` to `var`, - adds missing docstrings, - and removes unnecessary dunder method names.
上级
3c47f74a
显示空白字符变更
内嵌
并排
正在显示
1 个修改的文件
包含
240 行增加
和
183 行删除
+240
-183
fg.py
theano/gof/fg.py
+240
-183
没有找到文件。
theano/gof/fg.py
浏览文件 @
d8a82f73
...
@@ -15,9 +15,6 @@ from theano.gof.utils import TestValueError, get_variable_trace_string
...
@@ -15,9 +15,6 @@ from theano.gof.utils import TestValueError, get_variable_trace_string
from
theano.misc.ordered_set
import
OrderedSet
from
theano.misc.ordered_set
import
OrderedSet
NullType
=
None
class
InconsistencyError
(
Exception
):
class
InconsistencyError
(
Exception
):
"""
"""
This exception should be thrown by listeners to FunctionGraph when the
This exception should be thrown by listeners to FunctionGraph when the
...
@@ -105,14 +102,16 @@ class FunctionGraph(utils.object2):
...
@@ -105,14 +102,16 @@ class FunctionGraph(utils.object2):
Parameters
Parameters
----------
----------
inputs : list of
variables
inputs : list of
theano.gof.graph.Variable
Inputs nodes of the graph, usually declared by the user
Inputs nodes of the graph, usually declared by the user
outputs : list of
variables
outputs : list of
theano.gof.graph.Variable
Outputs nodes of the graph.
Outputs nodes of the graph.
clone : boolean
clone : boolean
If true, we will clone the graph. This is useful to remove the
If true, we will clone the graph. This is useful to remove the
constant cache problem.
constant cache problem.
update_mapping : dictionary
features : list of theano.gof.toolbox.Feature
A list of features to be added to the `FunctionGraph`.
update_mapping : dict
Mapping between the inputs with updates and the outputs
Mapping between the inputs with updates and the outputs
corresponding to their updates.
corresponding to their updates.
"""
"""
...
@@ -120,6 +119,12 @@ class FunctionGraph(utils.object2):
...
@@ -120,6 +119,12 @@ class FunctionGraph(utils.object2):
if
clone
:
if
clone
:
inputs
,
outputs
=
graph
.
clone
(
inputs
,
outputs
)
inputs
,
outputs
=
graph
.
clone
(
inputs
,
outputs
)
if
not
isinstance
(
inputs
,
list
):
raise
TypeError
(
"Argument `inputs` should be a list"
)
if
not
isinstance
(
outputs
,
list
):
raise
TypeError
(
"Argument `outputs` should be a list"
)
self
.
execute_callbacks_time
=
0
self
.
execute_callbacks_time
=
0
self
.
execute_callbacks_times
=
{}
self
.
execute_callbacks_times
=
{}
...
@@ -139,47 +144,71 @@ class FunctionGraph(utils.object2):
...
@@ -139,47 +144,71 @@ class FunctionGraph(utils.object2):
# outputs even if they aren't used in the graph.
# outputs even if they aren't used in the graph.
self
.
variables
=
set
()
self
.
variables
=
set
()
self
.
inputs
=
list
(
inputs
)
# TODO FIXME: We should *not* be using a list created elsewhere!
self
.
outputs
=
outputs
self
.
outputs
=
outputs
for
f
in
features
:
for
f
in
features
:
self
.
attach_feature
(
f
)
self
.
attach_feature
(
f
)
self
.
attach_feature
(
toolbox
.
ReplaceValidate
())
self
.
attach_feature
(
toolbox
.
ReplaceValidate
())
for
input
in
self
.
inputs
:
self
.
inputs
=
[]
if
input
.
owner
is
not
None
:
for
in_var
in
inputs
:
if
in_var
.
owner
is
not
None
:
raise
ValueError
(
raise
ValueError
(
"One of the provided inputs is the output of"
"One of the provided inputs is the output of
"
"an already existing node. "
"an already existing node. "
"If that is okay, either discard that "
"If that is okay, either discard that "
"input's owner or use graph.clone."
"input's owner or use graph.clone."
)
)
self
.
__setup_r__
(
input
)
self
.
variables
.
add
(
input
)
self
.
add_input
(
in_var
,
check
=
False
)
for
output
in
outputs
:
for
output
in
outputs
:
self
.
__import_r__
(
output
,
reason
=
"init"
)
self
.
import_var
(
output
,
reason
=
"init"
)
for
i
,
output
in
enumerate
(
outputs
):
for
i
,
output
in
enumerate
(
outputs
):
output
.
clients
.
append
((
"output"
,
i
))
output
.
clients
.
append
((
"output"
,
i
))
self
.
profile
=
None
self
.
profile
=
None
self
.
update_mapping
=
update_mapping
self
.
update_mapping
=
update_mapping
def
add_input
(
self
,
input
):
def
add_input
(
self
,
var
,
check
=
True
):
if
input
not
in
self
.
inputs
:
"""Add a new variable as an input to this `FunctionGraph`.
self
.
inputs
.
append
(
input
)
self
.
__setup_r__
(
input
)
Parameters
self
.
variables
.
add
(
input
)
----------
var : theano.gof.graph.Variable
"""
if
check
and
var
in
self
.
inputs
:
return
self
.
inputs
.
append
(
var
)
self
.
setup_var
(
var
)
self
.
variables
.
add
(
var
)
def
setup_var
(
self
,
var
):
"""Set up a variable so it belongs to this `FunctionGraph`.
Parameters
----------
var : theano.gof.graph.Variable
def
__setup_r__
(
self
,
r
):
"""
if
hasattr
(
r
,
"fgraph"
)
and
r
.
fgraph
is
not
None
and
r
.
fgraph
is
not
self
:
if
hasattr
(
var
,
"fgraph"
)
and
var
.
fgraph
is
not
None
and
var
.
fgraph
is
not
self
:
raise
Exception
(
"
%
s is already owned by another fgraph"
%
r
)
raise
Exception
(
"
%
s is already owned by another fgraph"
%
var
)
r
.
fgraph
=
self
var
.
fgraph
=
self
r
.
clients
=
[]
var
.
clients
=
[]
# self.execute_callbacks('on_setup_variable', r)
# self.execute_callbacks('on_setup_variable', var)
def
setup_node
(
self
,
node
):
"""Set up node so it belongs to this `FunctionGraph`.
Parameters
----------
node : theano.gof.graph.Apply
def
__setup_node__
(
self
,
node
):
"""
# sets up node so it belongs to this fgraph
if
hasattr
(
node
,
"fgraph"
)
and
node
.
fgraph
is
not
self
:
if
hasattr
(
node
,
"fgraph"
)
and
node
.
fgraph
is
not
self
:
raise
Exception
(
"
%
s is already owned by another fgraph"
%
node
)
raise
Exception
(
"
%
s is already owned by another fgraph"
%
node
)
if
hasattr
(
node
.
op
,
"view_map"
)
and
not
all
(
if
hasattr
(
node
.
op
,
"view_map"
)
and
not
all
(
...
@@ -226,125 +255,141 @@ class FunctionGraph(utils.object2):
...
@@ -226,125 +255,141 @@ class FunctionGraph(utils.object2):
self
.
profile
=
None
self
.
profile
=
None
self
.
update_mapping
=
None
self
.
update_mapping
=
None
# clients #
def
clients
(
self
,
var
):
def
clients
(
self
,
r
):
"""Return a list of all the `(node, i)` pairs such that `node.inputs[i]` is `var`.
"""
Set of all the (node, i) pairs such that node.inputs[i] is r.
Told differently, a list of (node,i) such that each node have
r as input at index i.
"""
Told differently, a `list` of `(node, i)` such that each node have
return
r
.
clients
`var` as input at index `i`.
def
__add_client__
(
self
,
r
,
new_client
):
"""
"""
Updates the list of clients of r with new_clients.
return
var
.
clients
def
add_client
(
self
,
var
,
new_client
):
"""Update the clients of `var` with `new_clients`.
Parameters
Parameters
----------
----------
r
var : Variable.
Variable.
new_client : (Apply, int)
new_client
A `(node, i)` pair such that `node.inputs[i]` is `var`.
(node, i) pair such that node.inputs[i] is r.
"""
"""
# Ne need to do the assert as it is always True. The logic
var
.
clients
.
append
(
new_client
)
# that call __add_client__ is valid. When the client list is
# long, the check it time consuming, so we don't enable it by
# default.
# assert not new_client in r.clients
r
.
clients
.
append
(
new_client
)
def
__remove_client__
(
self
,
r
,
client_to_remove
,
reason
=
None
):
def
remove_client
(
self
,
var
,
client_to_remove
,
reason
=
None
):
"""
"""Recursively removes clients of a variable.
Removes all from the clients list of r.
This is the main method to remove variable
or apply node
from
This is the main method to remove variable
s or `Apply` nodes
from
a
n FunctionGraph
.
a
`FunctionGraph`
.
Remove r from this fgraph if it don't have clients left. If it
This will remove `var` from the `FunctionGraph` if it doesn't have any
have an owner and all the outputs of the owner have no
clients remaining. If it has an owner and all the outputs of the owner
clients, it will
be removed.
have no clients, it will also
be removed.
Parameters
Parameters
----------
----------
r : Variable
var : Variable
The clients of r will be removed.
The clients of `var` that will be removed.
client_to_remove : (op, i) pair
client_to_remove : pair of (Apply, int)
(op, i) pair such that node.inputs[i] is not r anymore.
A `(node, i)` pair such that `node.inputs[i]` will no longer be
`var` in this `FunctionGraph`.
"""
l
=
[(
r
,
client_to_remove
)]
"""
while
l
:
r
,
client_to_remove
=
l
.
pop
()
removal_stack
=
[(
var
,
client_to_remove
)]
r
.
clients
.
remove
(
client_to_remove
)
while
removal_stack
:
# entry should be uniq in r. No need to assert it as it is
var
,
client_to_remove
=
removal_stack
.
pop
()
# already asserted in __add_client__.
# assert entry not in r.clients
try
:
if
r
.
clients
:
var
.
clients
.
remove
(
client_to_remove
)
except
ValueError
:
# In this case, the original `var` could've been removed from
# the current `var`'s client list before this call.
# There's nothing inherently wrong with that, so we continue as
# if it were removed here.
pass
if
var
.
clients
:
continue
continue
# r have no more clients, so check if we need to remove it
# Now, `var` has no more clients, so check if we need to remove it
# and its parent.
# and its `Apply` node
variable
=
r
if
not
var
.
owner
:
if
not
variable
.
owner
:
# The `var` is a `Constant` or an input without a client, so we
# A Constant or input without client. Remove it.
# remove it
self
.
variables
.
remove
(
variable
)
self
.
variables
.
remove
(
var
)
# This allow to quickly know if a var is still in the fgraph
# or not.
# This allows us to quickly determine if `var` is still in the
del
variable
.
fgraph
# `FunctionGraph`
# TODO: It's a poor approach; remove it
del
var
.
fgraph
else
:
else
:
apply_node
=
var
iable
.
owner
apply_node
=
var
.
owner
used
=
[
output
for
output
in
apply_node
.
outputs
if
output
.
clients
]
if
not
any
(
output
.
clients
for
output
in
apply_node
.
outputs
):
# If the apply node is not used and is not an output
# The `Apply` node is not used and is not an output, so we
if
not
used
:
# remove it and its outputs
if
not
hasattr
(
apply_node
.
tag
,
"removed_by"
):
if
not
hasattr
(
apply_node
.
tag
,
"removed_by"
):
apply_node
.
tag
.
removed_by
=
[]
apply_node
.
tag
.
removed_by
=
[]
apply_node
.
tag
.
removed_by
.
append
(
str
(
reason
))
apply_node
.
tag
.
removed_by
.
append
(
str
(
reason
))
self
.
apply_nodes
.
remove
(
apply_node
)
self
.
apply_nodes
.
remove
(
apply_node
)
# del apply_node.fgraph
# del apply_node.fgraph
self
.
variables
.
difference_update
(
apply_node
.
outputs
)
#
# for var in apply_node.outputs:
# for var in apply_node.outputs:
# del var.fgraph
# del var.fgraph
self
.
variables
.
difference_update
(
apply_node
.
outputs
)
self
.
execute_callbacks
(
"on_prune"
,
apply_node
,
reason
)
self
.
execute_callbacks
(
"on_prune"
,
apply_node
,
reason
)
for
i
,
in
put
in
enumerate
(
apply_node
.
inputs
):
for
i
,
in
_var
in
enumerate
(
apply_node
.
inputs
):
l
.
append
((
input
,
(
apply_node
,
i
)))
removal_stack
.
append
((
in_var
,
(
apply_node
,
i
)))
def
__import_r__
(
self
,
variable
,
reason
):
def
import_var
(
self
,
var
,
reason
):
"""
"""
Import variables into this `FunctionGraph`.
Import variables to this FunctionGraph and also their apply_node,
if those nodes are not in this graph
.
This will also import the `variable`'s `Apply` node
.
Parameters:
Parameters:
----------
----------
reason
variable : theano.gof.graph.Variable
reason is the name of the optimization or operation in progress.
The variable to be imported.
reason : str
The name of the optimization or operation in progress.
"""
"""
# Imports the owners of the variables
# Imports the owners of the variables
if
var
iable
.
owner
and
variable
.
owner
not
in
self
.
apply_nodes
:
if
var
.
owner
and
var
.
owner
not
in
self
.
apply_nodes
:
self
.
__import__
(
variable
.
owner
,
reason
=
reason
)
self
.
import_node
(
var
.
owner
,
reason
=
reason
)
elif
(
elif
(
var
iable
.
owner
is
None
var
.
owner
is
None
and
not
isinstance
(
var
iable
,
graph
.
Constant
)
and
not
isinstance
(
var
,
graph
.
Constant
)
and
var
iable
not
in
self
.
inputs
and
var
not
in
self
.
inputs
):
):
global
NullType
from
theano.gof.null_type
import
NullType
if
NullType
is
None
:
from
.null_type
import
NullType
if
isinstance
(
var
.
type
,
NullType
):
if
isinstance
(
variable
.
type
,
NullType
):
raise
TypeError
(
raise
TypeError
(
"Computation graph contains a NaN. "
+
var
iable
.
type
.
why_null
"Computation graph contains a NaN. "
+
var
.
type
.
why_null
)
)
raise
MissingInputError
(
"Undeclared input"
,
variable
=
var
iable
)
raise
MissingInputError
(
"Undeclared input"
,
variable
=
var
)
if
not
getattr
(
var
iable
,
"fgraph"
,
None
)
is
self
:
if
not
getattr
(
var
,
"fgraph"
,
None
)
is
self
:
self
.
__setup_r__
(
variable
)
self
.
setup_var
(
var
)
self
.
variables
.
add
(
var
iable
)
self
.
variables
.
add
(
var
)
def
__import__
(
self
,
apply_node
,
check
=
True
,
reason
=
None
):
def
import_node
(
self
,
apply_node
,
check
=
True
,
reason
=
None
):
"""
"""Recursively import everything between an `Apply` node and the `FunctionGraph`'s outputs.
Given an apply_node, recursively search from this node to know graph,
and then add all unknown variables and apply_nodes to this graph.
Parameters:
----------
apply_node : theano.gof.graph.Apply
The node to be imported.
check : bool
Check that the inputs for the imported nodes are also present in
the `FunctionGraph`.
reason : str
The name of the optimization or operation in progress.
"""
"""
node
=
apply_node
node
=
apply_node
...
@@ -358,13 +403,13 @@ class FunctionGraph(utils.object2):
...
@@ -358,13 +403,13 @@ class FunctionGraph(utils.object2):
for
node
in
new_nodes
:
for
node
in
new_nodes
:
if
hasattr
(
node
,
"fgraph"
)
and
node
.
fgraph
is
not
self
:
if
hasattr
(
node
,
"fgraph"
)
and
node
.
fgraph
is
not
self
:
raise
Exception
(
"
%
s is already owned by another fgraph"
%
node
)
raise
Exception
(
"
%
s is already owned by another fgraph"
%
node
)
for
r
in
node
.
inputs
:
for
va
r
in
node
.
inputs
:
if
hasattr
(
r
,
"fgraph"
)
and
r
.
fgraph
is
not
self
:
if
hasattr
(
var
,
"fgraph"
)
and
va
r
.
fgraph
is
not
self
:
raise
Exception
(
"
%
s is already owned by another fgraph"
%
r
)
raise
Exception
(
"
%
s is already owned by another fgraph"
%
va
r
)
if
(
if
(
r
.
owner
is
None
va
r
.
owner
is
None
and
not
isinstance
(
r
,
graph
.
Constant
)
and
not
isinstance
(
va
r
,
graph
.
Constant
)
and
r
not
in
self
.
inputs
and
va
r
not
in
self
.
inputs
):
):
# Standard error message
# Standard error message
error_msg
=
(
error_msg
=
(
...
@@ -373,51 +418,60 @@ class FunctionGraph(utils.object2):
...
@@ -373,51 +418,60 @@ class FunctionGraph(utils.object2):
"provided and not given a value. Use the "
"provided and not given a value. Use the "
"Theano flag exception_verbosity='high', "
"Theano flag exception_verbosity='high', "
"for more information on this error."
"for more information on this error."
%
(
node
.
inputs
.
index
(
r
),
str
(
node
))
%
(
node
.
inputs
.
index
(
va
r
),
str
(
node
))
)
)
raise
MissingInputError
(
error_msg
,
variable
=
r
)
raise
MissingInputError
(
error_msg
,
variable
=
va
r
)
for
node
in
new_nodes
:
for
node
in
new_nodes
:
assert
node
not
in
self
.
apply_nodes
assert
node
not
in
self
.
apply_nodes
self
.
__setup_node__
(
node
)
self
.
setup_node
(
node
)
self
.
apply_nodes
.
add
(
node
)
self
.
apply_nodes
.
add
(
node
)
if
not
hasattr
(
node
.
tag
,
"imported_by"
):
if
not
hasattr
(
node
.
tag
,
"imported_by"
):
node
.
tag
.
imported_by
=
[]
node
.
tag
.
imported_by
=
[]
node
.
tag
.
imported_by
.
append
(
str
(
reason
))
node
.
tag
.
imported_by
.
append
(
str
(
reason
))
for
output
in
node
.
outputs
:
for
output
in
node
.
outputs
:
self
.
__setup_r__
(
output
)
self
.
setup_var
(
output
)
self
.
variables
.
add
(
output
)
self
.
variables
.
add
(
output
)
for
i
,
input
in
enumerate
(
node
.
inputs
):
for
i
,
input
in
enumerate
(
node
.
inputs
):
if
input
not
in
self
.
variables
:
if
input
not
in
self
.
variables
:
self
.
__setup_r__
(
input
)
self
.
setup_var
(
input
)
self
.
variables
.
add
(
input
)
self
.
variables
.
add
(
input
)
self
.
__add_client__
(
input
,
(
node
,
i
))
self
.
add_client
(
input
,
(
node
,
i
))
assert
node
.
fgraph
is
self
assert
node
.
fgraph
is
self
self
.
execute_callbacks
(
"on_import"
,
node
,
reason
)
self
.
execute_callbacks
(
"on_import"
,
node
,
reason
)
# change input #
def
change_input
(
self
,
node
,
i
,
new_var
,
reason
=
None
):
def
change_input
(
self
,
node
,
i
,
new_r
,
reason
=
None
):
"""Change ``node.inputs[i]`` to `new_var`.
"""
Changes node.inputs[i] to new_r.
new_r.type == old_r.type must be True, where old_r
is the
``new_var.type == old_var.type`` must be ``True``, where ``old_var``
is the
current value of
node.inputs[i]
which we want to replace.
current value of
``node.inputs[i]``
which we want to replace.
For each feature that has a 'on_change_input' method, calls:
For each feature that has an `on_change_input` method, this method calls:
feature.on_change_input(function_graph, node, i, old_r, new_r, reason)
``feature.on_change_input(function_graph, node, i, old_var, new_var, reason)``
Parameters
----------
node : theano.gof.graph.Apply or str
The node for which an input is to be changed. If the value is
the string ``"output"`` then the ``self.outputs`` will be used
instead of ``node.inputs``.
i : int
The index in `node.inputs` that we want to change.
new_var : theano.gof.graph.Variable
The new variable to take the place of ``node.inputs[i]``.
"""
"""
# TODO: ERROR HANDLING FOR LISTENERS (should it complete the change or revert it?)
# TODO: ERROR HANDLING FOR LISTENERS (should it complete the change or revert it?)
if
node
==
"output"
:
if
node
==
"output"
:
r
=
self
.
outputs
[
i
]
r
=
self
.
outputs
[
i
]
if
not
r
.
type
==
new_r
.
type
:
if
not
r
.
type
==
new_
va
r
.
type
:
raise
TypeError
(
raise
TypeError
(
"The type of the replacement must be the"
"The type of the replacement must be the"
" same as the type of the original Variable."
,
" same as the type of the original Variable."
,
r
,
r
,
new_r
,
new_
va
r
,
)
)
self
.
outputs
[
i
]
=
new_r
self
.
outputs
[
i
]
=
new_
va
r
else
:
else
:
if
node
.
fgraph
is
not
self
:
if
node
.
fgraph
is
not
self
:
raise
Exception
(
raise
Exception
(
...
@@ -425,51 +479,63 @@ class FunctionGraph(utils.object2):
...
@@ -425,51 +479,63 @@ class FunctionGraph(utils.object2):
" belong to this FunctionGraph"
%
node
" belong to this FunctionGraph"
%
node
)
)
r
=
node
.
inputs
[
i
]
r
=
node
.
inputs
[
i
]
if
not
r
.
type
==
new_r
.
type
:
if
not
r
.
type
==
new_
va
r
.
type
:
raise
TypeError
(
raise
TypeError
(
"The type of the replacement must be the"
"The type of the replacement must be the"
" same as the type of the original Variable."
,
" same as the type of the original Variable."
,
r
,
r
,
new_r
,
new_
va
r
,
)
)
node
.
inputs
[
i
]
=
new_r
node
.
inputs
[
i
]
=
new_
va
r
if
r
is
new_r
:
if
r
is
new_
va
r
:
return
return
self
.
__import_r__
(
new_
r
,
reason
=
reason
)
self
.
import_var
(
new_va
r
,
reason
=
reason
)
self
.
__add_client__
(
new_
r
,
(
node
,
i
))
self
.
add_client
(
new_va
r
,
(
node
,
i
))
self
.
__remove_client__
(
r
,
(
node
,
i
),
reason
=
reason
)
self
.
remove_client
(
r
,
(
node
,
i
),
reason
=
reason
)
# Precondition: the substitution is semantically valid
# Precondition: the substitution is semantically valid
# However it may introduce cycles to the graph, in which case the
# However it may introduce cycles to the graph, in which case the
# transaction will be reverted later.
# transaction will be reverted later.
self
.
execute_callbacks
(
"on_change_input"
,
node
,
i
,
r
,
new_r
,
reason
=
reason
)
self
.
execute_callbacks
(
"on_change_input"
,
node
,
i
,
r
,
new_
va
r
,
reason
=
reason
)
# replace #
def
replace
(
self
,
var
,
new_var
,
reason
=
None
,
verbose
=
None
):
def
replace
(
self
,
r
,
new_r
,
reason
=
None
,
verbose
=
None
):
"""Replace a variable in the `FunctionGraph`.
"""
This is the main interface to manipulate the subgraph in FunctionGraph.
This is the main interface to manipulate the subgraph in `FunctionGraph`.
For every node that uses r as input, makes it use new_r instead.
For every node that uses `var` as input, makes it use `new_var` instead.
Parameters:
----------
var : theano.gof.graph.Variable
The variable to be replaced.
new_var : theano.gof.graph.Variable
The variable to replace `var`.
reason : str
The name of the optimization or operation in progress.
verbose : bool
Print `reason`, `var`, and `new_var`.
"""
"""
if
verbose
is
None
:
if
verbose
is
None
:
verbose
=
config
.
optimizer_verbose
verbose
=
config
.
optimizer_verbose
if
verbose
:
if
verbose
:
print
(
reason
,
r
,
new_r
)
print
(
reason
,
var
,
new_var
)
if
hasattr
(
r
,
"fgraph"
)
and
r
.
fgraph
is
not
self
:
if
hasattr
(
var
,
"fgraph"
)
and
var
.
fgraph
is
not
self
:
raise
Exception
(
raise
Exception
(
"Cannot replace
%
s because it does not belong "
"Cannot replace
%
s because it does not belong "
"to this FunctionGraph"
%
r
,
"to this FunctionGraph"
%
va
r
,
str
(
reason
),
str
(
reason
),
)
)
if
r
.
type
!=
new_
r
.
type
:
if
var
.
type
!=
new_va
r
.
type
:
new_
r2
=
r
.
type
.
convert_variable
(
new_
r
)
new_
var_2
=
var
.
type
.
convert_variable
(
new_va
r
)
# We still make sure that the type converts correctly
# We still make sure that the type converts correctly
if
new_
r2
is
None
or
new_r2
.
type
!=
r
.
type
:
if
new_
var_2
is
None
or
new_var_2
.
type
!=
va
r
.
type
:
done
=
dict
()
done
=
dict
()
used_ids
=
dict
()
used_ids
=
dict
()
old
=
theano
.
compile
.
debugmode
.
debugprint
(
old
=
theano
.
compile
.
debugmode
.
debugprint
(
r
,
va
r
,
prefix
=
" "
,
prefix
=
" "
,
depth
=
6
,
depth
=
6
,
file
=
StringIO
(),
file
=
StringIO
(),
...
@@ -478,7 +544,7 @@ class FunctionGraph(utils.object2):
...
@@ -478,7 +544,7 @@ class FunctionGraph(utils.object2):
used_ids
=
used_ids
,
used_ids
=
used_ids
,
)
.
getvalue
()
)
.
getvalue
()
new
=
theano
.
compile
.
debugmode
.
debugprint
(
new
=
theano
.
compile
.
debugmode
.
debugprint
(
new_r
,
new_
va
r
,
prefix
=
" "
,
prefix
=
" "
,
depth
=
6
,
depth
=
6
,
file
=
StringIO
(),
file
=
StringIO
(),
...
@@ -487,16 +553,17 @@ class FunctionGraph(utils.object2):
...
@@ -487,16 +553,17 @@ class FunctionGraph(utils.object2):
used_ids
=
used_ids
,
used_ids
=
used_ids
,
)
.
getvalue
()
)
.
getvalue
()
raise
toolbox
.
BadOptimization
(
raise
toolbox
.
BadOptimization
(
r
,
va
r
,
new_r
,
new_
va
r
,
None
,
None
,
None
,
None
,
str
(
reason
)
+
". The type of the replacement must be the same."
,
str
(
reason
)
+
". The type of the replacement must be the same."
,
old
,
old
,
new
,
new
,
)
)
new_r
=
new_r2
new_var
=
new_var_2
if
r
not
in
self
.
variables
:
if
var
not
in
self
.
variables
:
# this variable isn't in the graph... don't raise an
# this variable isn't in the graph... don't raise an
# exception here, just return silently because it makes it
# exception here, just return silently because it makes it
# easier to implement some optimizations for
# easier to implement some optimizations for
...
@@ -505,8 +572,8 @@ class FunctionGraph(utils.object2):
...
@@ -505,8 +572,8 @@ class FunctionGraph(utils.object2):
if
theano
.
config
.
compute_test_value
!=
"off"
:
if
theano
.
config
.
compute_test_value
!=
"off"
:
try
:
try
:
tval
=
theano
.
gof
.
op
.
get_test_value
(
r
)
tval
=
theano
.
gof
.
op
.
get_test_value
(
va
r
)
new_tval
=
theano
.
gof
.
op
.
get_test_value
(
new_r
)
new_tval
=
theano
.
gof
.
op
.
get_test_value
(
new_
va
r
)
except
TestValueError
:
except
TestValueError
:
pass
pass
else
:
else
:
...
@@ -518,27 +585,21 @@ class FunctionGraph(utils.object2):
...
@@ -518,27 +585,21 @@ class FunctionGraph(utils.object2):
"a shape different from the original variable's "
"a shape different from the original variable's "
"test value. Original:
%
s, new:
%
s"
"test value. Original:
%
s, new:
%
s"
%
(
tval_shape
,
new_tval_shape
),
%
(
tval_shape
,
new_tval_shape
),
r
,
va
r
,
new_r
,
new_
va
r
,
str
(
reason
),
str
(
reason
),
)
)
for
node
,
i
in
list
(
r
.
clients
):
# copy the client list for iteration
for
node
,
i
in
list
(
var
.
clients
):
# copy the client list for iteration
assert
(
node
==
"output"
and
self
.
outputs
[
i
]
is
r
)
or
(
node
.
inputs
[
i
]
is
r
)
assert
(
node
==
"output"
and
self
.
outputs
[
i
]
is
var
)
or
(
self
.
change_input
(
node
,
i
,
new_r
,
reason
=
reason
)
node
.
inputs
[
i
]
is
var
)
# sometimes the following is triggered. If you understand why, please explain to James.
self
.
change_input
(
node
,
i
,
new_var
,
reason
=
reason
)
# He's curious... -JB20090331
# if len(r.clients) != 0:
# print >> sys.stderr, "WARNING: CLIENTS LEFT AFTER REPLACE", r, r.clients
def
replace_all
(
self
,
pairs
,
reason
=
None
):
def
replace_all
(
self
,
pairs
,
reason
=
None
):
"""
"""Replace variables in the `FunctionGraph` according to `(var, new_var)` pairs in a list."""
For every node that uses r as input, makes it use new_r instead
for
var
,
new_var
in
pairs
:
self
.
replace
(
var
,
new_var
,
reason
=
reason
)
"""
for
r
,
new_r
in
pairs
:
self
.
replace
(
r
,
new_r
,
reason
=
reason
)
def
attach_feature
(
self
,
feature
):
def
attach_feature
(
self
,
feature
):
"""
"""
...
@@ -587,7 +648,6 @@ class FunctionGraph(utils.object2):
...
@@ -587,7 +648,6 @@ class FunctionGraph(utils.object2):
if
detach
is
not
None
:
if
detach
is
not
None
:
detach
(
self
)
detach
(
self
)
# callback utils #
def
execute_callbacks
(
self
,
name
,
*
args
,
**
kwargs
):
def
execute_callbacks
(
self
,
name
,
*
args
,
**
kwargs
):
"""Execute callbacks
"""Execute callbacks
...
@@ -625,7 +685,6 @@ class FunctionGraph(utils.object2):
...
@@ -625,7 +685,6 @@ class FunctionGraph(utils.object2):
d
[
feature
]
=
fn
(
*
args
)
d
[
feature
]
=
fn
(
*
args
)
return
d
return
d
# misc #
def
toposort
(
self
):
def
toposort
(
self
):
"""Toposort
"""Toposort
...
@@ -655,17 +714,16 @@ class FunctionGraph(utils.object2):
...
@@ -655,17 +714,16 @@ class FunctionGraph(utils.object2):
return
order
return
order
def
orderings
(
self
):
def
orderings
(
self
):
"""
"""Return `dict` `d` s.t. `d[node]` is a list of nodes that must be evaluated before `node` itself can be evaluated.
Return dict d s.t. d[node] is a list of nodes that must be evaluated
before node itself can be evaluated.
This is used primarily by the destroy_handler feature to ensure that
This is used primarily by the destroy_handler feature to ensure that
all clients of any destroyed inputs have already computed their outputs.
the clients of any destroyed inputs have already computed their
outputs.
Notes
Notes
-----
-----
This only calls the
orderings() fct
on all features. It does not
This only calls the
`orderings()` function
on all features. It does not
take care of computing dependencies by itself.
take care of computing
the
dependencies by itself.
"""
"""
assert
isinstance
(
self
.
_features
,
list
)
assert
isinstance
(
self
.
_features
,
list
)
...
@@ -769,7 +827,6 @@ class FunctionGraph(utils.object2):
...
@@ -769,7 +827,6 @@ class FunctionGraph(utils.object2):
def
__repr__
(
self
):
def
__repr__
(
self
):
return
self
.
__str__
()
return
self
.
__str__
()
# clone #
def
clone
(
self
,
check_integrity
=
True
):
def
clone
(
self
,
check_integrity
=
True
):
"""
"""
Clone the graph and get a memo( a dict )that map old node to new node
Clone the graph and get a memo( a dict )that map old node to new node
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论