testgroup / pytensor · Commits · b7cf7933

Commit b7cf7933, authored Aug 17, 2015 by Iban Harlouchet

    numpydoc for theano/scan_module/scan_utils.py

Parent: deabd346

Showing 1 changed file with 88 additions and 56 deletions:

theano/scan_module/scan_utils.py  (+88 −56)
"""
"""
This module provides utility functions for the Scan Op
This module provides utility functions for the Scan Op.
See scan.py for details on scan.
See scan.py for details on scan
"""
"""
__docformat__
=
'restructedtext en'
__docformat__
=
'restructedtext en'
__authors__
=
(
"Razvan Pascanu "
__authors__
=
(
"Razvan Pascanu "
...
@@ -43,6 +44,7 @@ def safe_new(x, tag='', dtype=None):
     by gradient, or the R-op to construct new variables for the inputs of
     the inner graph such that there is no interference between the original
     graph and the newly constructed graph.
+
     """
     if hasattr(x, 'name') and x.name is not None:
         nw_name = x.name + tag
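The `nw_name = x.name + tag` branch above is the naming half of `safe_new`. A minimal pure-Python sketch of just that logic, using a hypothetical stand-in class rather than a real Theano variable:

```python
class FakeVar:
    """Stand-in for a Theano variable; only the `name` attribute matters here."""
    def __init__(self, name=None):
        self.name = name

def tag_name(x, tag=''):
    # Only variables that carry a name get the tag appended,
    # mirroring the `if hasattr(x, 'name') and x.name is not None` branch.
    if hasattr(x, 'name') and x.name is not None:
        return x.name + tag
    return None

print(tag_name(FakeVar('W'), '_copy'))  # a tagged copy keeps a traceable name
```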
...
@@ -117,21 +119,28 @@ class until(object):
     between the condition and the list of outputs ( unless we enforce and
     order, but since this was not impose up to know it can make quite a bit
     of code to fail).
+
     """
+
     def __init__(self, condition):
         self.condition = tensor.as_tensor_variable(condition)
         assert self.condition.ndim == 0
 
 
 def traverse(out, x, x_copy, d, visited=None):
-    ''' Function used by scan to parse the tree and figure out which nodes
-    it needs to replace. There are two options :
+    """
+    Function used by scan to parse the tree and figure out which nodes
+    it needs to replace.
+
+    There are two options :
         1) x and x_copy or on host, then you would replace x with x_copy
         2) x is on gpu, x_copy on host, then you need to replace
            host_from_gpu(x) with x_copy
-    This happens because initially shared variables are on GPU .. which is
+
+    This happens because initially shared variables are on GPU... which is
     fine for the main computational graph but confuses things a bit for the
-    inner graph of scan '''
+    inner graph of scan.
+
+    """
     # ``visited`` is a set of nodes that are already known and don't need to be
     # checked again, speeding up the traversal of multiply-connected graphs.
     # if a ``visited`` set is given, it will be updated in-place so the callee
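The `until` wrapper in the hunk above holds a scalar stopping condition for scan. A hedged, pure-Python analogue of that semantics (illustrative names, not Theano's API): iterate a step function, collecting outputs, and stop once the condition returned alongside the output becomes true.

```python
def scan_until(step, init, n_steps):
    """Run `step` at most `n_steps` times, stopping early when the
    returned condition (the role of `until(condition)`) is true."""
    outputs, state = [], init
    for _ in range(n_steps):
        state, stop = step(state)
        outputs.append(state)
        if stop:  # the scalar condition -- note the ndim == 0 assert above
            break
    return outputs

def double(s):
    n = s * 2
    return n, n > 50  # (new state, stopping condition)

outs = scan_until(double, 1, 10)  # stops after 64, before 10 steps elapse
```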
...
@@ -191,25 +200,25 @@ def clone(output,
           share_inputs=True,
           copy_inputs=DEPRECATED_ARG):
     """
-    Function that allows replacing subgraphs of a computational
-    graph. It returns a copy of the initial subgraph with the corresponding
+    Function that allows replacing subgraphs of a computational graph.
+
+    It returns a copy of the initial subgraph with the corresponding
     substitutions.
 
-    :type output: Theano Variables (or Theano expressions)
-    :param outputs: Theano expression that represents the computational
-                    graph
-
-    :type replace: dict
-    :param replace: dictionary describing which subgraphs should be
-                    replaced by what
-
-    :type share_inputs: bool
-    :param share_inputs: If True, use the same inputs (and shared variables)
-        as the original graph. If False, clone them. Note that cloned
-        shared variables still use the same underlying storage, so they
-        will always have the same value.
-
-    :param copy_inputs: deprecated, use share_inputs.
+    Parameters
+    ----------
+    output : Theano Variables (or Theano expressions)
+        Theano expression that represents the computational graph.
+    replace : dict
+        Dictionary describing which subgraphs should be replaced by what.
+    share_inputs : bool
+        If True, use the same inputs (and shared variables) as the original
+        graph. If False, clone them. Note that cloned shared variables still
+        use the same underlying storage, so they will always have the same
+        value.
+    copy_inputs
+        Deprecated, use share_inputs.
+
     """
     if copy_inputs is not DEPRECATED_ARG:
         warnings.warn('In `clone()` function, the argument `copy_inputs` has been deprecated and renamed into `share_inputs`')
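What `clone(output, replace=...)` describes can be sketched on a toy expression representation (nested tuples standing in for Theano graphs; this is an illustration of the substitution semantics, not the real implementation):

```python
def clone_expr(expr, replace):
    """Return a copy of `expr` with every subgraph listed in `replace`
    substituted; untouched leaves are shared, like share_inputs=True."""
    if expr in replace:
        return replace[expr]            # substitute a whole subgraph
    if isinstance(expr, tuple):
        return tuple(clone_expr(e, replace) for e in expr)  # rebuild interior nodes
    return expr                         # leaves pass through unchanged

g = ('mul', 'x', ('add', 'x', 'b'))
print(clone_expr(g, {'x': 'y'}))  # every occurrence of 'x' becomes 'y'
```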
...
@@ -251,7 +260,7 @@ def get_updates_and_outputs(ls):
     """
     This function tries to recognize the updates OrderedDict, the
     list of outputs and the stopping condition returned by the
-    lambda expression and arrange them in a predefined order
+    lambda expression and arrange them in a predefined order.
 
     WRITEME: what is the type of ls? how is it formatted?
     if it's not in the predefined order already, how does
...
@@ -297,6 +306,7 @@ def get_updates_and_outputs(ls):
     Return True iff `x` is made only of lists, tuples, dictionaries, Theano
     variables or `theano.scan_module.until` objects.
+
     """
     # Is `x` a container we can iterate on?
     iter_on = None
...
@@ -390,10 +400,11 @@ def isNaN_or_Inf_or_None(x):
 
 
 def expand(tensor_var, size):
-    '''
+    """
     Transoforms the shape of a tensor from (d1, d2 ... ) to ( d1+size, d2, ..)
     by adding 0s at the end of the tensor.
-    '''
+
+    """
     # Corner case that I might use in an optimization
     if size == 0:
         return tensor_var
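The semantics the `expand` docstring describes can be illustrated with NumPy (not Theano): grow the first dimension from d1 to d1+size by appending zero rows.

```python
import numpy as np

def expand_np(arr, size):
    """NumPy illustration of `expand`: append `size` zero rows
    along the first dimension."""
    if size == 0:                 # same corner case as in the diff
        return arr
    pad = np.zeros((size,) + arr.shape[1:], dtype=arr.dtype)
    return np.concatenate([arr, pad], axis=0)

a = np.ones((2, 3))
b = expand_np(a, 2)               # shape (2, 3) -> (4, 3), last rows zero
```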
...
@@ -406,7 +417,7 @@ def expand(tensor_var, size):
 
 
 def equal_computations(xs, ys, in_xs=None, in_ys=None):
-    '''
+    """
     Checks if Theano graphs represent the same computations.
     The two lists `xs`, `ys` should have the same number of entries. The
     function checks if for any corresponding pair `(x,y)` from `zip(xs,ys)`
...
@@ -420,7 +431,7 @@ def equal_computations(xs, ys, in_xs=None, in_ys=None):
     `ys`, but also represent subgraphs of a computational graph in `xs`
     or `ys`.
-    '''
+    """
     assert len(xs) == len(ys)
     if in_xs is None:
         in_xs = []
...
@@ -460,14 +471,16 @@ def equal_computations(xs, ys, in_xs=None, in_ys=None):
     # Explore the two graphs, in parallel, depth first, comparing the nodes
     # along the way for equality.
     def compare_nodes(nd_x, nd_y, common, different):
-        ''' Compare two nodes to determine if they perform equal computation.
+        """
+        Compare two nodes to determine if they perform equal computation.
+
         This is done by comparing the ops, the number of inputs, outputs and
         by ensuring that the inputs themselves are the result of equal
         computation.
 
         NOTE : This function relies on the variable common to cache
         results to be more efficient.
-        '''
+        """
         if nd_x.op != nd_y.op:
             return False
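The pairwise, cached comparison that `compare_nodes` performs can be sketched on toy nodes (`(op, inputs...)` tuples, strings as leaves); this is an illustrative analogue, not Theano's implementation:

```python
def equal_expr(x, y, common=None):
    """Depth-first structural comparison with a cache of already-matched
    node pairs, playing the role of `common` in compare_nodes."""
    common = common if common is not None else set()
    if (id(x), id(y)) in common:
        return True                        # cached: known to be equal
    if isinstance(x, tuple) and isinstance(y, tuple):
        if x[0] != y[0] or len(x) != len(y):
            return False                   # ops and arity must match
        if all(equal_expr(a, b, common) for a, b in zip(x[1:], y[1:])):
            common.add((id(x), id(y)))     # memoize for shared subgraphs
            return True
        return False
    return x == y                          # leaves compared directly
```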
...
@@ -537,13 +550,14 @@ def equal_computations(xs, ys, in_xs=None, in_ys=None):
 
 
 def infer_shape(outs, inputs, input_shapes):
-    '''
+    """
     Compute the shape of the outputs given the shape of the inputs
     of a theano graph.
 
     We do it this way to avoid compiling the inner function just to get
     the shape. Changes to ShapeFeature could require changes in this function.
-    '''
+
+    """
     # We use a ShapeFeature because it has all the necessary logic
     # inside. We don't use the full ShapeFeature interface, but we
     # let it initialize itself with an empty fgraph, otherwise we will
...
@@ -560,10 +574,10 @@ def infer_shape(outs, inputs, input_shapes):
         shape_feature.set_shape(inp, inp_shp)
 
     def local_traverse(out):
-        '''
+        """
         Go back in the graph, from out, adding computable shapes to shape_of.
-        '''
+        """
         if out in shape_feature.shape_of:
             # Its shape is already known
             return
...
@@ -589,14 +603,18 @@ def infer_shape(outs, inputs, input_shapes):
 
 
 class Validator(object):
+    """
+    Check if variables can be expressed without using variables in invalid.
+
+    Parameters
+    ----------
+    valid_equivalent
+        Provides a dictionary mapping some invalid variables to valid ones
+        that can be used instead.
+
+    """
+
     def __init__(self, valid=None, invalid=None, valid_equivalent=None):
-        '''
-        Check if variables can be expressed without using variables in invalid.
-
-        init_valid_equivalent provides a dictionary mapping some invalid
-        variables to valid ones that can be used instead.
-        '''
         if valid is None:
             valid = []
         if invalid is None:
...
@@ -616,13 +634,14 @@ class Validator(object):
         self.invalid.update(list(valid_equivalent.keys()))
 
     def check(self, out):
-        '''
+        """
         Go backwards in the graph, from out, and check if out is valid.
+
         If out is a valid node, (out, True) is returned.
         If out is not valid, but has an equivalent e, (e, False) is returned.
         If out is not valid and has no equivalent, None is returned.
-        '''
+        """
         if out in self.valid:
             return out, True
         elif out in self.valid_equivalent:
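The three outcomes the `check` docstring enumerates can be sketched on a toy graph (a dict mapping each node to its inputs); names and the recursion are illustrative, not the real Validator:

```python
def check(out, valid, valid_equivalent, inputs_of):
    """Return (out, True) if valid, (equivalent, False) if a substitute
    exists, or None when neither holds -- the docstring's three cases."""
    if out in valid:
        return out, True
    if out in valid_equivalent:
        return valid_equivalent[out], False
    deps = inputs_of.get(out, [])
    # a derived node is valid when all of its inputs check out
    if deps and all(check(d, valid, valid_equivalent, inputs_of) is not None
                    for d in deps):
        return out, True
    return None
```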
...
@@ -667,12 +686,13 @@ class Validator(object):
 
 
 def scan_can_remove_outs(op, out_idxs):
-    '''
+    """
     Looks at all outputs defined by indices ``out_idxs`` and see whom can be
     removed from the scan op without affecting the rest. Return two lists,
     the first one with the indices of outs that can be removed, the second
     with the outputs that can not be removed.
-    '''
+
+    """
    non_removable = [o for i, o in enumerate(op.outputs) if i not in
                     out_idxs]
    required_inputs = gof.graph.inputs(non_removable)
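The removability test sketched above (an output can be dropped only if none of the remaining outputs depend on it) can be illustrated on toy dependency data; the dict-based graph here is an assumption for the example, not scan's real data layout:

```python
def can_remove(out_idxs, requires, n_outs):
    """Split `out_idxs` into (removable, not_removable), where an output is
    removable iff no kept output lists it in its `requires` set."""
    kept = set(range(n_outs)) - set(out_idxs)
    needed = set()
    for k in kept:                       # everything the remaining outputs use
        needed |= requires.get(k, set())
    removable = [i for i in out_idxs if i not in needed]
    not_removable = [i for i in out_idxs if i in needed]
    return removable, not_removable

# output 2 depends on output 0, so 0 must stay but 1 can go
print(can_remove([0, 1], {2: {0}}, 3))
```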
...
@@ -706,7 +726,7 @@ def scan_can_remove_outs(op, out_idxs):
 
 
 def compress_outs(op, not_required, inputs):
-    '''
+    """
     Helpful function that gets a Scan op, a list of indices indicating
     which outputs are not required anymore and should be removed, and
     a list of inputs to the apply node corresponding to the scan op and
...
@@ -714,7 +734,8 @@ def compress_outs(op, not_required, inputs):
     the indicated outputs are eliminated. Note that eliminating an output
     means removing its inputs from the inner funciton and from the
     node inputs, and changing the dictionary.
-    '''
+
+    """
     info = OrderedDict()
     info['tap_array'] = []
     info['n_seqs'] = op.info['n_seqs']
...
@@ -852,6 +873,7 @@ def compress_outs(op, not_required, inputs):
 def find_up(l_node, f_node):
     r"""
     Goes up in the graph and returns True if a node in nodes is found.
+
     """
     if isinstance(l_node, gof.Apply):
         l_outs = l_node.outputs
...
@@ -866,8 +888,9 @@ def reconstruct_graph(inputs, outputs, tag=None):
     """
     Different interface to clone, that allows you to pass inputs.
     Compared to clone, this method always replaces the inputs with
-    new variables of the same type, and returns those (
-    in the same
+    new variables of the same type, and returns those (in the same
     order as the original inputs).
+
     """
     if tag is None:
         tag = ''
...
@@ -885,7 +908,11 @@ def reconstruct_graph(inputs, outputs, tag=None):
 
 
 class scan_args(object):
-    """Parses the inputs and outputs of scan in an easy to manipulate format"""
+    """
+    Parses the inputs and outputs of scan in an easy to manipulate format.
+
+    """
+
     def __init__(self, outer_inputs, outer_outputs,
                  _inner_inputs, _inner_outputs, info):
         self.n_steps = outer_inputs[0]
...
@@ -1070,17 +1097,22 @@ class scan_args(object):
 def forced_replace(out, x, y):
     """
-    :param out: Theano Variable
-    :param x: Theano Variable
-    :param y: Theano Variable
-
-    This function checks all internal values of the graph that computes the
-    variable ``out`` for occurances of values identical with ``x``. If such
-    occurances are encountered then they are replaced with variable ``y``.
+    Check all internal values of the graph that compute the variable ``out``
+    for occurrences of values identical with ``x``. If such occurrences are
+    encountered then they are replaced with variable ``y``.
 
-    For example:
+    Parameters
+    ----------
+    out : Theano Variable
+    x : Theano Variable
+    y : Theano Variable
+
+    Examples
+    --------
         out := sigmoid(wu)*(1-sigmoid(wu))
         x := sigmoid(wu)
         forced_replace(out, x, y) := y*(1-y)
+
     """
     if out is None:
         return None
...
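The docstring's `y*(1-y)` example can be reproduced with a toy version of `forced_replace`, again using nested tuples as the graph (an illustrative analogue, not the Theano implementation):

```python
def forced_replace(out, x, y):
    """Replace every occurrence of subexpression `x` inside `out` with `y`."""
    if out is None:
        return None               # same corner case as in the diff
    if out == x:
        return y                  # an occurrence identical with x
    if isinstance(out, tuple):
        return tuple(forced_replace(o, x, y) for o in out)
    return out

sig = ('sigmoid', 'wu')
out = ('mul', sig, ('sub', 1, sig))       # sigmoid(wu)*(1-sigmoid(wu))
print(forced_replace(out, sig, 'y'))      # becomes y*(1-y) structurally
```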