testgroup / pytensor · Commits · 6217ef05

Commit 6217ef05, authored Nov 13, 2012 by Ian Goodfellow
Parent: 72ed94f5

    removed grad_sources_inputs

Showing 1 changed file with 0 additions and 107 deletions

theano/gradient.py  +0  −107
@@ -1045,113 +1045,6 @@ def _populate_grad_dict(var_to_node_to_idx,
     return rval
 
-def grad_sources_inputs(sources, graph_inputs):
-    """
-    Used to compute the gradient of a cost with respect to all the
-    variables between graph_input and cost, but in the special
-    case where you don't know the cost, you only know its gradient
-    on a set of intermediate values.
-
-    A gradient source is a pair (``v``, ``g_v``), in which ``v`` is
-    a `Variable`, and ``g_v`` is a `Variable` that is a gradient wrt
-    ``v``. More specifically, ``g_v`` is the gradient of an external
-    scalar cost, ``cost`` (that is not explicitly used), wrt ``v``.
-
-    This function traverses the graph backward from the ``r`` sources,
-    calling ``op.grad(...)`` for all ops with some non-None gradient
-    on an output, to compute gradients of ``cost`` wrt intermediate
-    variables and ``graph_inputs``.
-
-    The ``op.grad(...)`` functions are called like this:
-
-    .. code-block:: python
-
-        op.grad(op.inputs[:], [total_gradient(v) for v in op.outputs])
-
-    This call to ``op.grad`` should return a list or tuple: one symbolic
-    gradient per input. These gradients represent the gradients of
-    the same implicit ``cost`` mentioned above, wrt ``op.inputs``. Note
-    that this is **not** the same as the gradient of ``op.outputs`` wrt
-    ``op.inputs``.
-
-    If ``op`` has a single input, then ``op.grad`` should return a list
-    or tuple of length 1.
-
-    For each input wrt which ``op`` is not differentiable, it should
-    return ``None`` instead of a `Variable` instance.
-
-    If a source ``r`` receives a gradient from another source ``r2``,
-    then the effective gradient on ``r`` is the sum of both gradients.
-
-    :type sources: list of pairs of Variable: (v, gradient-on-v) to
-        initialize the total_gradient dictionary
-    :param sources: gradients to back-propagate using the chain rule
-
-    :type graph_inputs: list of Variable
-    :param graph_inputs: variables considered to be constant
-        (do not backpropagate through them)
-
-    :rtype: dictionary whose keys and values are of type Variable
-    :return: mapping from each Variable encountered in the backward
-        traversal to the gradient with respect to that Variable.
-
-    It is assumed that there is some objective J shared between all members of
-    sources, so that for each v, gradient-on-v is the gradient of J with
-    respect to v.
-    """
-    outputs, output_grads = zip(*sources)
-    for output_grad in output_grads:
-        if not hasattr(output_grad, 'type'):
-            raise TypeError('output grads must be theano variables. '
-                            'Ambiguous whether %s should be made into tensor '
-                            'or sparse theano variable' % str(type(output_grad)))
-    if graph_inputs is None:
-        graph_inputs = gof.graph.inputs(outputs)
-    wrt = graph_inputs
-    var_to_node_to_idx = _populate_var_to_node_to_idx(outputs, wrt, None)
-
-    # build a dict mapping var to the gradient of cost with respect to var
-    grad_dict = {}
-    for output, output_grad in sources:
-        # The gradient of the cost should always be 0 if the cost is of
-        # discrete (integer) dtype.
-        if getattr(output.type, 'dtype', '') not in theano.tensor.float_dtypes:
-            output_grad = output.zeros_like()
-        else:
-            # Cast the provided gradient so that it has the same dtype
-            # as the cost.
-            output_grad = output_grad.astype(output.type.dtype)
-        grad_dict[output] = output_grad
-
-    # Variables that do not influence the cost have zero gradient.
-    # If wrt is such a variable, populate the grad_dict with this info
-    # so that wrt not being in var_to_node_to_idx won't cause an error below.
-    # According to the flag, possibly raise an error if wrt is disconnected.
-    for elem in wrt:
-        if elem not in var_to_node_to_idx and elem not in outputs:
-            grad_dict[elem] = DisconnectedType()()
-
-    _populate_grad_dict(var_to_node_to_idx, grad_dict, wrt)
-
-    # post-process out the DisconnectedTypes
-    for key in grad_dict:
-        if isinstance(grad_dict[key].type, DisconnectedType):
-            if hasattr(key, 'zeros_like'):
-                grad_dict[key] = _float_zeros_like(key)
-
-    return grad_dict
 
 def _float_zeros_like(x):
     """ Like zeros_like, but forces the object to have a
     floating point dtype """
     ...
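The contract of the removed function's seeding step — zero out gradients for integer-dtype outputs, cast float gradients to the output's dtype, and sum contributions that arrive at the same variable — can be sketched in plain NumPy. This is an illustrative sketch only: the `seed_grad_dict` name and the named-triple `sources` format are not Theano API (Theano's `sources` are `(Variable, gradient)` pairs keyed by the variable objects themselves).

```python
import numpy as np

def seed_grad_dict(sources):
    """Seed a {name: gradient} dict from (name, value, grad) triples,
    mirroring grad_sources_inputs' seeding loop:
    - integer-dtype outputs get an identically-zero gradient;
    - float outputs keep their provided gradient, cast to the output's dtype;
    - gradients reaching the same variable from several sources are summed.
    """
    grad_dict = {}
    for name, value, grad in sources:
        if not np.issubdtype(value.dtype, np.floating):
            # The gradient of the cost is 0 if the output is discrete.
            grad = np.zeros_like(value, dtype='float64')
        else:
            # Cast the provided gradient to the output's dtype.
            grad = grad.astype(value.dtype)
        if name in grad_dict:
            # A variable fed by several sources accumulates their sum.
            grad_dict[name] = grad_dict[name] + grad
        else:
            grad_dict[name] = grad
    return grad_dict
```

For example, two sources on the same float32 variable are summed (and the result stays float32), while an integer-valued output receives a zero gradient regardless of the gradient passed in.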