Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
3f734be4
提交
3f734be4
authored
9月 11, 2012
作者:
Ian Goodfellow
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
pep8 theano/gradient.py
上级
84a3f6f1
隐藏空白字符变更
内嵌
并排
正在显示
1 个修改的文件
包含
22 行增加
和
19 行删除
+22
-19
gradient.py
theano/gradient.py
+22
-19
没有找到文件。
theano/gradient.py
浏览文件 @
3f734be4
...
@@ -266,7 +266,7 @@ def Rop(f, wrt, eval_points):
...
@@ -266,7 +266,7 @@ def Rop(f, wrt, eval_points):
# we have to make it be wrong for Rop to keep working
# we have to make it be wrong for Rop to keep working
# Rop should eventually be upgraded to handle integers
# Rop should eventually be upgraded to handle integers
# correctly, the same as grad
# correctly, the same as grad
y
=
theano
.
tensor
.
cast
(
y
,
x
.
type
.
dtype
)
y
=
theano
.
tensor
.
cast
(
y
,
x
.
type
.
dtype
)
y
=
x
.
type
.
filter_variable
(
y
)
y
=
x
.
type
.
filter_variable
(
y
)
assert
x
.
type
==
y
.
type
assert
x
.
type
==
y
.
type
same_type_eval_points
.
append
(
y
)
same_type_eval_points
.
append
(
y
)
...
@@ -493,7 +493,7 @@ def grad(cost, wrt, g_cost=None, consider_constant=None,
...
@@ -493,7 +493,7 @@ def grad(cost, wrt, g_cost=None, consider_constant=None,
# Make sure we didn't initialize the grad_dict with any ints
# Make sure we didn't initialize the grad_dict with any ints
for
var
in
grad_dict
:
for
var
in
grad_dict
:
g
=
grad_dict
[
var
]
g
=
grad_dict
[
var
]
if
hasattr
(
g
.
type
,
'dtype'
):
if
hasattr
(
g
.
type
,
'dtype'
):
assert
g
.
type
.
dtype
.
find
(
'float'
)
!=
-
1
assert
g
.
type
.
dtype
.
find
(
'float'
)
!=
-
1
rval
=
_populate_grad_dict
(
var_to_node_to_idx
,
rval
=
_populate_grad_dict
(
var_to_node_to_idx
,
...
@@ -509,6 +509,7 @@ def grad(cost, wrt, g_cost=None, consider_constant=None,
...
@@ -509,6 +509,7 @@ def grad(cost, wrt, g_cost=None, consider_constant=None,
rval
,
=
rval
rval
,
=
rval
return
rval
return
rval
def
_node_to_pattern
(
node
):
def
_node_to_pattern
(
node
):
""" given an apply node, obtain its connection pattern
""" given an apply node, obtain its connection pattern
this is just a wrapper around Op.connection_pattern
this is just a wrapper around Op.connection_pattern
...
@@ -516,7 +517,7 @@ def _node_to_pattern(node):
...
@@ -516,7 +517,7 @@ def _node_to_pattern(node):
if the method is not implemented
if the method is not implemented
"""
"""
if
hasattr
(
node
.
op
,
'connection_pattern'
):
if
hasattr
(
node
.
op
,
'connection_pattern'
):
connection_pattern
=
node
.
op
.
connection_pattern
(
node
)
connection_pattern
=
node
.
op
.
connection_pattern
(
node
)
if
not
isinstance
(
connection_pattern
,
list
):
if
not
isinstance
(
connection_pattern
,
list
):
...
@@ -538,7 +539,7 @@ def _node_to_pattern(node):
...
@@ -538,7 +539,7 @@ def _node_to_pattern(node):
connection_pattern
=
\
connection_pattern
=
\
[[
True
for
output
in
node
.
outputs
]
[[
True
for
output
in
node
.
outputs
]
for
ipt
in
node
.
inputs
]
for
ipt
in
node
.
inputs
]
assert
isinstance
(
connection_pattern
,
list
)
assert
isinstance
(
connection_pattern
,
list
)
assert
len
(
connection_pattern
)
==
len
(
node
.
inputs
)
assert
len
(
connection_pattern
)
==
len
(
node
.
inputs
)
for
ii
in
xrange
(
len
(
node
.
inputs
)):
for
ii
in
xrange
(
len
(
node
.
inputs
)):
assert
isinstance
(
connection_pattern
[
ii
],
list
)
assert
isinstance
(
connection_pattern
[
ii
],
list
)
...
@@ -546,6 +547,7 @@ def _node_to_pattern(node):
...
@@ -546,6 +547,7 @@ def _node_to_pattern(node):
len
(
node
.
outputs
)
len
(
node
.
outputs
)
return
connection_pattern
return
connection_pattern
def
_populate_var_to_node_to_idx
(
outputs
,
wrt
):
def
_populate_var_to_node_to_idx
(
outputs
,
wrt
):
"""
"""
Common code shared between grad and grad_sources_inputs
Common code shared between grad and grad_sources_inputs
...
@@ -583,7 +585,6 @@ def _populate_var_to_node_to_idx(outputs, wrt):
...
@@ -583,7 +585,6 @@ def _populate_var_to_node_to_idx(outputs, wrt):
# connection_pattern)
# connection_pattern)
accounted_for
=
set
([])
accounted_for
=
set
([])
def
account_for
(
var
):
def
account_for
(
var
):
if
var
in
accounted_for
:
if
var
in
accounted_for
:
return
return
...
@@ -693,16 +694,16 @@ def _populate_grad_dict(var_to_node_to_idx,
...
@@ -693,16 +694,16 @@ def _populate_grad_dict(var_to_node_to_idx,
output_grads
=
[
access_grad_cache
(
var
)
for
var
in
node
.
outputs
]
output_grads
=
[
access_grad_cache
(
var
)
for
var
in
node
.
outputs
]
# list of bools indicating if each output is connected to the cost
# list of bools indicating if each output is connected to the cost
outputs_connected
=
[
not
isinstance
(
g
.
type
,
DisconnectedType
)
outputs_connected
=
[
not
isinstance
(
g
.
type
,
DisconnectedType
)
for
g
in
output_grads
]
for
g
in
output_grads
]
connection_pattern
=
_node_to_pattern
(
node
)
connection_pattern
=
_node_to_pattern
(
node
)
# list of bools indicating if each input is connected to the cost
# list of bools indicating if each input is connected to the cost
inputs_connected
=
[
inputs_connected
=
[
(
True
in
[
input_to_output
and
output_to_cost
for
(
True
in
[
input_to_output
and
output_to_cost
for
input_to_output
,
output_to_cost
in
input_to_output
,
output_to_cost
in
zip
(
input_to_outputs
,
outputs_connected
)
])
for
zip
(
input_to_outputs
,
outputs_connected
)])
for
input_to_outputs
in
connection_pattern
input_to_outputs
in
connection_pattern
]
]
...
@@ -752,16 +753,16 @@ def _populate_grad_dict(var_to_node_to_idx,
...
@@ -752,16 +753,16 @@ def _populate_grad_dict(var_to_node_to_idx,
# Do type checking on the result
# Do type checking on the result
#List of bools indicating if each output is an integer dtype
#List of bools indicating if each output is an integer dtype
output_is_int
=
[
hasattr
(
output
.
type
,
'dtype'
)
and
output_is_int
=
[
hasattr
(
output
.
type
,
'dtype'
)
and
output
.
type
.
dtype
.
find
(
'int'
)
!=
-
1
output
.
type
.
dtype
.
find
(
'int'
)
!=
-
1
for
output
in
node
.
outputs
]
for
output
in
node
.
outputs
]
#List of bools indicating if each input only has integer outputs
#List of bools indicating if each input only has integer outputs
only_connected_to_int
=
[
(
True
not
in
only_connected_to_int
=
[(
True
not
in
[
in_to_out
and
out_to_cost
and
not
out_int
[
in_to_out
and
out_to_cost
and
not
out_int
for
in_to_out
,
out_to_cost
,
out_int
in
for
in_to_out
,
out_to_cost
,
out_int
in
zip
(
in_to_outs
,
outputs_connected
,
output_is_int
)
])
zip
(
in_to_outs
,
outputs_connected
,
output_is_int
)])
for
in_to_outs
in
connection_pattern
]
for
in_to_outs
in
connection_pattern
]
for
i
,
term
in
enumerate
(
input_grads
):
for
i
,
term
in
enumerate
(
input_grads
):
...
@@ -780,9 +781,9 @@ def _populate_grad_dict(var_to_node_to_idx,
...
@@ -780,9 +781,9 @@ def _populate_grad_dict(var_to_node_to_idx,
'functions.'
)
%
node
.
op
)
'functions.'
)
%
node
.
op
)
if
not
isinstance
(
term
.
type
,
if
not
isinstance
(
term
.
type
,
(
NullType
,
DisconnectedType
)):
(
NullType
,
DisconnectedType
)):
if
term
.
type
.
dtype
.
find
(
'float'
)
==
-
1
:
if
term
.
type
.
dtype
.
find
(
'float'
)
==
-
1
:
raise
TypeError
(
str
(
node
.
op
)
+
'.grad illegally '
raise
TypeError
(
str
(
node
.
op
)
+
'.grad illegally '
' returned an integer-valued variable.'
' returned an integer-valued variable.'
' (Input index
%
d, dtype
%
s)'
%
(
i
,
' (Input index
%
d, dtype
%
s)'
%
(
i
,
term
.
type
.
dtype
))
term
.
type
.
dtype
))
...
@@ -851,7 +852,6 @@ def _populate_grad_dict(var_to_node_to_idx,
...
@@ -851,7 +852,6 @@ def _populate_grad_dict(var_to_node_to_idx,
raise
ValueError
(
msg
)
raise
ValueError
(
msg
)
#Check that op.connection_pattern matches the connectivity
#Check that op.connection_pattern matches the connectivity
#logic driving the op.grad method
#logic driving the op.grad method
for
i
,
packed
in
\
for
i
,
packed
in
\
...
@@ -872,7 +872,7 @@ def _populate_grad_dict(var_to_node_to_idx,
...
@@ -872,7 +872,7 @@ def _populate_grad_dict(var_to_node_to_idx,
msg
=
"
%
s.grad returned DisconnectedType for input"
msg
=
"
%
s.grad returned DisconnectedType for input"
msg
+=
"
%
d."
msg
+=
"
%
d."
msg
=
msg
%
(
str
(
node
.
op
),
i
)
msg
=
msg
%
(
str
(
node
.
op
),
i
)
if
hasattr
(
node
.
op
,
'connection_pattern'
):
if
hasattr
(
node
.
op
,
'connection_pattern'
):
msg
+=
' Its connection_pattern method does not'
msg
+=
' Its connection_pattern method does not'
msg
+=
' allow this.'
msg
+=
' allow this.'
raise
TypeError
(
msg
)
raise
TypeError
(
msg
)
...
@@ -917,7 +917,7 @@ def _populate_grad_dict(var_to_node_to_idx,
...
@@ -917,7 +917,7 @@ def _populate_grad_dict(var_to_node_to_idx,
if
len
(
terms
)
>
0
:
if
len
(
terms
)
>
0
:
# the next line is like sum(terms) but doesn't add an
# the next line is like sum(terms) but doesn't add an
# extraneous TensorConstant(0)
# extraneous TensorConstant(0)
grad_dict
[
var
]
=
reduce
(
lambda
x
,
y
:
x
+
y
,
terms
)
grad_dict
[
var
]
=
reduce
(
lambda
x
,
y
:
x
+
y
,
terms
)
else
:
else
:
grad_dict
[
var
]
=
DisconnectedType
()()
grad_dict
[
var
]
=
DisconnectedType
()()
...
@@ -1029,6 +1029,7 @@ def grad_sources_inputs(sources, graph_inputs):
...
@@ -1029,6 +1029,7 @@ def grad_sources_inputs(sources, graph_inputs):
return
grad_dict
return
grad_dict
def
_float_zeros_like
(
x
):
def
_float_zeros_like
(
x
):
""" Like zeros_like, but forces the object to have a
""" Like zeros_like, but forces the object to have a
a floating point dtype """
a floating point dtype """
...
@@ -1040,6 +1041,7 @@ def _float_zeros_like(x):
...
@@ -1040,6 +1041,7 @@ def _float_zeros_like(x):
return
rval
.
astype
(
theano
.
config
.
floatX
)
return
rval
.
astype
(
theano
.
config
.
floatX
)
def
_float_ones_like
(
x
):
def
_float_ones_like
(
x
):
""" Like ones_like, but forces the object to have a
""" Like ones_like, but forces the object to have a
floating point dtype """
floating point dtype """
...
@@ -1051,6 +1053,7 @@ def _float_ones_like(x):
...
@@ -1051,6 +1053,7 @@ def _float_ones_like(x):
return
rval
.
astype
(
theano
.
config
.
floatX
)
return
rval
.
astype
(
theano
.
config
.
floatX
)
class
numeric_grad
(
object
):
class
numeric_grad
(
object
):
"""
"""
Compute the numeric derivative of a scalar-valued function at a particular
Compute the numeric derivative of a scalar-valued function at a particular
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论