Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
bacd93af
提交
bacd93af
authored
8月 11, 2015
作者:
Iban Harlouchet
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
numpydoc for theano/sandbox/scan_module/scan_utils.py
上级
8d4e690a
隐藏空白字符变更
内嵌
并排
正在显示
1 个修改的文件
包含
48 行增加
和
33 行删除
+48
-33
scan_utils.py
theano/sandbox/scan_module/scan_utils.py
+48
-33
没有找到文件。
theano/sandbox/scan_module/scan_utils.py
浏览文件 @
bacd93af
"""
This module provides utility functions for the Scan Op
This module provides utility functions for the Scan Op.
See scan.py for details on scan.
See scan.py for details on scan
"""
from
__future__
import
print_function
__docformat__
=
'restructedtext en'
...
...
@@ -41,8 +42,11 @@ def expand(tensor_var, size):
``tensor_var``, namely:
rval[:d1] = tensor_var
:param tensor_var: Theano tensor variable
:param size: int
Parameters
----------
tensor_var : Theano tensor variable.
size : int
"""
# Corner case that I might use in an optimization
if
size
==
0
:
...
...
@@ -57,7 +61,8 @@ def expand(tensor_var, size):
def
to_list
(
ls
):
"""
Converts ``ls`` to list if it is a tuple, or wraps ``ls`` into a list if
it is not a list already
it is not a list already.
"""
if
isinstance
(
ls
,
(
list
,
tuple
)):
return
list
(
ls
)
...
...
@@ -70,7 +75,9 @@ class until(object):
Theano can end on a condition. In order to differentiate this condition
from the other outputs of scan, this class is used to wrap the condition
around it.
"""
def
__init__
(
self
,
condition
):
self
.
condition
=
tensor
.
as_tensor_variable
(
condition
)
assert
self
.
condition
.
ndim
==
0
...
...
@@ -78,10 +85,12 @@ class until(object):
def
get_updates_and_outputs
(
ls
):
"""
Parses the list ``ls`` into outputs and updates. The semantics
of ``ls`` is defined by the constructive function of scan.
Parses the list ``ls`` into outputs and updates.
The semantics of ``ls`` is defined by the constructive function of scan.
The elemets of ``ls`` are either a list of expressions representing the
outputs/states, a dictionary of updates or a condition.
"""
def
is_list_outputs
(
elem
):
if
(
isinstance
(
elem
,
(
list
,
tuple
))
and
...
...
@@ -150,23 +159,23 @@ def get_updates_and_outputs(ls):
def
clone
(
output
,
replace
=
None
,
strict
=
True
,
share_inputs
=
True
):
"""
Function that allows replacing subgraphs of a computational
graph. It returns a copy of the initial subgraph with the corresponding
Function that allows replacing subgraphs of a computational graph.
It returns a copy of the initial subgraph with the corresponding
substitutions.
:type output: Theano Variables (or Theano expressions)
:param outputs: Theano expression that represents the computational
graph
:type replace: dict
:param replace: dictionary describing which subgraphs should be
replaced by what
Parameters
----------
output : Theano Variables (or Theano expressions)
Theano expression that represents the computational graph.
replace: dict
Dictionary describing which subgraphs should be replaced by what.
share_inputs : bool
If True, use the same inputs (and shared variables) as the original
graph. If False, clone them. Note that cloned shared variables still
use the same underlying storage, so they will always have the same
value.
:type share_inputs: bool
:param share_inputs: If True, use the same inputs (and shared variables)
as the original graph. If False, clone them. Note that cloned
shared variables still use the same underlying storage, so they
will always have the same value.
"""
inps
,
outs
,
other_stuff
=
rebuild_collect_shared
(
output
,
[],
...
...
@@ -189,6 +198,7 @@ def canonical_arguments(sequences,
Mainly it makes sure that arguments are given as lists of dictionaries,
and that the different fields of of a dictionary are set to default
value if the user has not provided any.
"""
states_info
=
to_list
(
outputs_info
)
parameters
=
[
tensor
.
as_tensor_variable
(
x
)
for
x
in
to_list
(
non_sequences
)]
...
...
@@ -303,13 +313,14 @@ def canonical_arguments(sequences,
def
infer_shape
(
outs
,
inputs
,
input_shapes
):
'''
"""
Compute the shape of the outputs given the shape of the inputs
of a theano graph.
We do it this way to avoid compiling the inner function just to get
the shape. Changes to ShapeFeature could require changes in this function.
'''
We do it this way to avoid compiling the inner function just to get the
shape. Changes to ShapeFeature could require changes in this function.
"""
# We use a ShapeFeature because it has all the necessary logic
# inside. We don't use the full ShapeFeature interface, but we
# let it initialize itself with an empty fgraph, otherwise we will
...
...
@@ -326,9 +337,10 @@ def infer_shape(outs, inputs, input_shapes):
shape_feature
.
set_shape
(
inp
,
inp_shp
)
def
local_traverse
(
out
):
'''
"""
Go back in the graph, from out, adding computable shapes to shape_of.
'''
"""
if
out
in
shape_feature
.
shape_of
:
# Its shape is already known
...
...
@@ -358,14 +370,17 @@ def allocate_memory(T, y_info, y):
"""
Allocates memory for an output of scan.
:param T: scalar
Variable representing the number of steps scan will run
:param y_info: dict
Parameters
----------
T : scalar
Variable representing the number of steps scan will run.
y_info : dict
Dictionary describing the output (more specifically describing shape
information for the output
:param y
: Tensor variable
information for the output
.
y
: Tensor variable
Expression describing the computation resulting in out entry of y.
It can be used to infer the shape of y
It can be used to infer the shape of y.
"""
if
'shape'
in
y_info
:
return
tensor
.
zeros
([
T
,
]
+
list
(
y_info
[
'shape'
]),
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论