Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
b7cf7933
提交
b7cf7933
authored
8月 17, 2015
作者:
Iban Harlouchet
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
numpydoc for theano/scan_module/scan_utils.py
上级
deabd346
隐藏空白字符变更
内嵌
并排
正在显示
1 个修改的文件
包含
92 行增加
和
60 行删除
+92
-60
scan_utils.py
theano/scan_module/scan_utils.py
+92
-60
没有找到文件。
theano/scan_module/scan_utils.py
浏览文件 @
b7cf7933
"""
This module provides utility functions for the Scan Op
This module provides utility functions for the Scan Op.
See scan.py for details on scan.
See scan.py for details on scan
"""
__docformat__
=
'restructedtext en'
__authors__
=
(
"Razvan Pascanu "
...
...
@@ -43,6 +44,7 @@ def safe_new(x, tag='', dtype=None):
by gradient, or the R-op to construct new variables for the inputs of
the inner graph such that there is no interference between the original
graph and the newly constructed graph.
"""
if
hasattr
(
x
,
'name'
)
and
x
.
name
is
not
None
:
nw_name
=
x
.
name
+
tag
...
...
@@ -117,21 +119,28 @@ class until(object):
between the condition and the list of outputs ( unless we enforce and
order, but since this was not impose up to know it can make quite a bit
of code to fail).
"""
def
__init__
(
self
,
condition
):
self
.
condition
=
tensor
.
as_tensor_variable
(
condition
)
assert
self
.
condition
.
ndim
==
0
def
traverse
(
out
,
x
,
x_copy
,
d
,
visited
=
None
):
''' Function used by scan to parse the tree and figure out which nodes
it needs to replace. There are two options :
"""
Function used by scan to parse the tree and figure out which nodes
it needs to replace.
There are two options :
1) x and x_copy or on host, then you would replace x with x_copy
2) x is on gpu, x_copy on host, then you need to replace
host_from_gpu(x) with x_copy
This happens because initially shared variables are on GPU
.. which is
This happens because initially shared variables are on GPU
.
.. which is
fine for the main computational graph but confuses things a bit for the
inner graph of scan '''
inner graph of scan.
"""
# ``visited`` is a set of nodes that are already known and don't need to be
# checked again, speeding up the traversal of multiply-connected graphs.
# if a ``visited`` set is given, it will be updated in-place so the callee
...
...
@@ -191,25 +200,25 @@ def clone(output,
share_inputs
=
True
,
copy_inputs
=
DEPRECATED_ARG
):
"""
Function that allows replacing subgraphs of a computational
graph. It returns a copy of the initial subgraph with the corresponding
Function that allows replacing subgraphs of a computational graph.
It returns a copy of the initial subgraph with the corresponding
substitutions.
:type output: Theano Variables (or Theano expressions)
:param outputs: Theano expression that represents the computational
graph
Parameters
----------
output : Theano Variables (or Theano expressions)
Theano expression that represents the computational graph.
replace : dict
Dictionary describing which subgraphs should be replaced by what.
share_inputs : bool
If True, use the same inputs (and shared variables) as the original
graph. If False, clone them. Note that cloned shared variables still
use the same underlying storage, so they will always have the same
value.
copy_inputs
Deprecated, use share_inputs.
:type replace: dict
:param replace: dictionary describing which subgraphs should be
replaced by what
:type share_inputs: bool
:param share_inputs: If True, use the same inputs (and shared variables)
as the original graph. If False, clone them. Note that cloned
shared variables still use the same underlying storage, so they
will always have the same value.
:param copy_inputs: deprecated, use share_inputs.
"""
if
copy_inputs
is
not
DEPRECATED_ARG
:
warnings
.
warn
(
'In `clone()` function, the argument `copy_inputs` has been deprecated and renamed into `share_inputs`'
)
...
...
@@ -251,7 +260,7 @@ def get_updates_and_outputs(ls):
"""
This function tries to recognize the updates OrderedDict, the
list of outputs and the stopping condition returned by the
lambda expression and arrange them in a predefined order
lambda expression and arrange them in a predefined order
.
WRITEME: what is the type of ls? how is it formatted?
if it's not in the predefined order already, how does
...
...
@@ -297,6 +306,7 @@ def get_updates_and_outputs(ls):
Return True iff `x` is made only of lists, tuples, dictionaries, Theano
variables or `theano.scan_module.until` objects.
"""
# Is `x` a container we can iterate on?
iter_on
=
None
...
...
@@ -390,10 +400,11 @@ def isNaN_or_Inf_or_None(x):
def
expand
(
tensor_var
,
size
):
'''
"""
Transoforms the shape of a tensor from (d1, d2 ... ) to ( d1+size, d2, ..)
by adding 0s at the end of the tensor.
'''
"""
# Corner case that I might use in an optimization
if
size
==
0
:
return
tensor_var
...
...
@@ -406,7 +417,7 @@ def expand(tensor_var, size):
def
equal_computations
(
xs
,
ys
,
in_xs
=
None
,
in_ys
=
None
):
'''
Checks if Theano graphs represent the same computations.
"""
Checks if Theano graphs represent the same computations.
The two lists `xs`, `ys` should have the same number of entries. The
function checks if for any corresponding pair `(x,y)` from `zip(xs,ys)`
...
...
@@ -420,7 +431,7 @@ def equal_computations(xs, ys, in_xs=None, in_ys=None):
`ys`, but also represent subgraphs of a computational graph in `xs`
or `ys`.
'''
"""
assert
len
(
xs
)
==
len
(
ys
)
if
in_xs
is
None
:
in_xs
=
[]
...
...
@@ -460,14 +471,16 @@ def equal_computations(xs, ys, in_xs=None, in_ys=None):
# Explore the two graphs, in parallel, depth first, comparing the nodes
# along the way for equality.
def
compare_nodes
(
nd_x
,
nd_y
,
common
,
different
):
''' Compare two nodes to determine if they perform equal computation.
"""
Compare two nodes to determine if they perform equal computation.
This is done by comparing the ops, the number of inputs, outputs and
by ensuring that the inputs themselves are the result of equal
computation.
NOTE : This function relies on the variable common to cache
results to be more efficient.
'''
"""
if
nd_x
.
op
!=
nd_y
.
op
:
return
False
...
...
@@ -537,13 +550,14 @@ def equal_computations(xs, ys, in_xs=None, in_ys=None):
def
infer_shape
(
outs
,
inputs
,
input_shapes
):
'''
Compute the shape of the outputs given the shape of the inputs
of a theano
graph.
"""
Compute the shape of the outputs given the shape of the inputs
of a theano
graph.
We do it this way to avoid compiling the inner function just to get
the shape. Changes to ShapeFeature could require changes in this function.
'''
"""
# We use a ShapeFeature because it has all the necessary logic
# inside. We don't use the full ShapeFeature interface, but we
# let it initialize itself with an empty fgraph, otherwise we will
...
...
@@ -560,10 +574,10 @@ def infer_shape(outs, inputs, input_shapes):
shape_feature
.
set_shape
(
inp
,
inp_shp
)
def
local_traverse
(
out
):
'''
"""
Go back in the graph, from out, adding computable shapes to shape_of.
'''
"""
if
out
in
shape_feature
.
shape_of
:
# Its shape is already known
return
...
...
@@ -589,14 +603,18 @@ def infer_shape(outs, inputs, input_shapes):
class
Validator
(
object
):
def
__init__
(
self
,
valid
=
None
,
invalid
=
None
,
valid_equivalent
=
None
):
'''
Check if variables can be expressed without using variables in invalid.
"""
Check if variables can be expressed without using variables in invalid.
Parameters
----------
valid_equivalent
Provides a dictionary mapping some invalid variables to valid ones that
can be used instead.
init_valid_equivalent provides a dictionary mapping some invalid
variables to valid ones that can be used instead.
'''
"""
def
__init__
(
self
,
valid
=
None
,
invalid
=
None
,
valid_equivalent
=
None
):
if
valid
is
None
:
valid
=
[]
if
invalid
is
None
:
...
...
@@ -616,13 +634,14 @@ class Validator(object):
self
.
invalid
.
update
(
list
(
valid_equivalent
.
keys
()))
def
check
(
self
,
out
):
'''
"""
Go backwards in the graph, from out, and check if out is valid.
If out is a valid node, (out, True) is returned.
If out is not valid, but has an equivalent e, (e, False) is returned.
If out is not valid and has no equivalent, None is returned.
'''
"""
if
out
in
self
.
valid
:
return
out
,
True
elif
out
in
self
.
valid_equivalent
:
...
...
@@ -667,12 +686,13 @@ class Validator(object):
def
scan_can_remove_outs
(
op
,
out_idxs
):
'''
"""
Looks at all outputs defined by indices ``out_idxs`` and see whom can be
removed from the scan op without affecting the rest. Return two lists,
the first one with the indices of outs that can be removed, the second
with the outputs that can not be removed.
'''
"""
non_removable
=
[
o
for
i
,
o
in
enumerate
(
op
.
outputs
)
if
i
not
in
out_idxs
]
required_inputs
=
gof
.
graph
.
inputs
(
non_removable
)
...
...
@@ -706,7 +726,7 @@ def scan_can_remove_outs(op, out_idxs):
def
compress_outs
(
op
,
not_required
,
inputs
):
'''
"""
Helpful function that gets a Scan op, a list of indices indicating
which outputs are not required anymore and should be removed, and
a list of inputs to the apply node corresponding to the scan op and
...
...
@@ -714,7 +734,8 @@ def compress_outs(op, not_required, inputs):
the indicated outputs are eliminated. Note that eliminating an output
means removing its inputs from the inner funciton and from the
node inputs, and changing the dictionary.
'''
"""
info
=
OrderedDict
()
info
[
'tap_array'
]
=
[]
info
[
'n_seqs'
]
=
op
.
info
[
'n_seqs'
]
...
...
@@ -852,6 +873,7 @@ def compress_outs(op, not_required, inputs):
def
find_up
(
l_node
,
f_node
):
r"""
Goes up in the graph and returns True if a node in nodes is found.
"""
if
isinstance
(
l_node
,
gof
.
Apply
):
l_outs
=
l_node
.
outputs
...
...
@@ -866,8 +888,9 @@ def reconstruct_graph(inputs, outputs, tag=None):
"""
Different interface to clone, that allows you to pass inputs.
Compared to clone, this method always replaces the inputs with
new variables of the same type, and returns those (
in the same
new variables of the same type, and returns those (in the same
order as the original inputs).
"""
if
tag
is
None
:
tag
=
''
...
...
@@ -885,7 +908,11 @@ def reconstruct_graph(inputs, outputs, tag=None):
class
scan_args
(
object
):
"""Parses the inputs and outputs of scan in an easy to manipulate format"""
"""
Parses the inputs and outputs of scan in an easy to manipulate format.
"""
def
__init__
(
self
,
outer_inputs
,
outer_outputs
,
_inner_inputs
,
_inner_outputs
,
info
):
self
.
n_steps
=
outer_inputs
[
0
]
...
...
@@ -1070,17 +1097,22 @@ class scan_args(object):
def
forced_replace
(
out
,
x
,
y
):
"""
:param out: Theano Variable
:param x: Theano Variable
:param y: Theano Variable
This function checks all internal values of the graph that computes the
variable ``out`` for occurances of values identical with ``x``. If such
occurances are encountered then they are replaced with variable ``y``.
For example:
out := sigmoid(wu)*(1-sigmoid(wu))
x := sigmoid(wu)
forced_replace(out, x, y) := y*(1-y)
Check all internal values of the graph that compute the variable ``out``
for occurrences of values identical with ``x``. If such occurrences are
encountered then they are replaced with variable ``y``.
Parameters
----------
out : Theano Variable
x : Theano Variable
y : Theano Variable
Examples
--------
out := sigmoid(wu)*(1-sigmoid(wu))
x := sigmoid(wu)
forced_replace(out, x, y) := y*(1-y)
"""
if
out
is
None
:
return
None
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论