Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
c0f7334c
提交
c0f7334c
authored
1月 27, 2017
作者:
khaotik
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
change R_op overriding to new format
上级
8e8758a3
隐藏空白字符变更
内嵌
并排
正在显示
1 个修改的文件
包含
119 行增加
和
79 行删除
+119
-79
builders.py
theano/compile/builders.py
+119
-79
没有找到文件。
theano/compile/builders.py
浏览文件 @
c0f7334c
"""Define new Ops from existing Ops"""
from
__future__
import
absolute_import
,
print_function
,
division
from
functools
import
reduce
from
functools
import
reduce
,
partial
from
collections
import
OrderedDict
import
theano
...
...
@@ -62,7 +62,7 @@ class OpFromGraph(gof.Op):
Defaults to ``None``.
``None`` : No value, gives NullType()
``0`` : zero value, gives
DisconnectedType(
)
``0`` : zero value, gives
zeros_like(...
)
``...`` : Do not override, use default R_op() result
OpFromGraph instance : Override with another OpFromGraph, should
...
...
@@ -92,14 +92,15 @@ class OpFromGraph(gof.Op):
local_outputs)
- c_code() to remove the double overhead?
- grad() make it support DisconnectedType and the new interface
- extend to lop_overrides?
- extend grad() to L_op
- add support for NullType and DisconnectedType when R_op supports them
- check how it works with updates.
- add test with constant as input or inside the inner graph.
- Add support for the GPU? Probably just need an opt to remove transfer
- Add support to pickle this Op.
- Add support/test with random generator
- Add optimization to removing unused inputs/outputs
- Add optimization to work inplace when not inline
- Add optimization to work inplace
on inputs
when not inline
Notes
-----
...
...
@@ -113,7 +114,7 @@ class OpFromGraph(gof.Op):
``fast_run`` mode.
- It's recommanded to provide pure functions (no side effects like
setting global variable) as callable(s). The callable(s) supplied
for overrding gradient/rop will be called only once at the first
for overr
i
ding gradient/rop will be called only once at the first
call to grad/R_op, and will be converted to OpFromGraph instances.
Examples
...
...
@@ -176,25 +177,35 @@ class OpFromGraph(gof.Op):
# grad: gradient Variable
# inp: the corresponding input of gradient Variable
#
# Some Variable types cannot be used directly as OfG output such as
# NullType, or DisconnectedType.
#
# However a grad() call could return these types
# a grad() call could return instance of NullType() or DisconnectedType()
# which cannot be directly used in OfG
#
# Since we always use an OfG instance as self._grad_op, the current
# workaround is to "remember" the special cases of the gradient and
# replace them after self._grad_op is called.
#
# This helper function changes invalid types into a filtered_
type
,
# and provides a overrider_
type
to be replaced at grad() call
# This helper function changes invalid types into a filtered_
var
,
# and provides a overrider_
var
to be replaced at grad() call
#
# For now, this converts NullType or DisconnectedType into zeros_like.
# other types are unmodified
with overrider_type
-> None
# other types are unmodified
: overrider_var
-> None
if
isinstance
(
grad
.
type
,
(
NullType
,
DisconnectedType
)):
return
inp
.
zeros_like
(),
grad
.
type
return
inp
.
zeros_like
(),
grad
else
:
return
grad
,
None
@staticmethod
def
_filter_rop_var
(
inpJ
,
out
):
# mostly similar to _filter_grad_var
if
isinstance
(
inpJ
.
type
,
NullType
):
return
out
.
zeros_like
(),
inpJ
if
isinstance
(
inpJ
.
type
,
DisconnectedType
):
# since R_op does not have DisconnectedType yet, we will just
# make them zeros.
return
out
.
zeros_like
(),
None
else
:
return
inpJ
,
None
def
__init__
(
self
,
inputs
,
outputs
,
inline
=
False
,
...
...
@@ -202,7 +213,7 @@ class OpFromGraph(gof.Op):
name
=
None
,
**
kwargs
):
if
not
isinstance
(
outputs
,
list
):
raise
TypeError
(
'outputs must be list, got
%
s'
%
outputs
,
outputs
)
raise
TypeError
(
'outputs must be list, got
%
s'
%
type
(
outputs
)
,
outputs
)
for
i
in
inputs
+
outputs
:
if
not
isinstance
(
i
,
gof
.
Variable
):
raise
TypeError
(
...
...
@@ -263,21 +274,26 @@ class OpFromGraph(gof.Op):
if
isinstance
(
grad_op
,
OpFromGraph
):
self
.
_grad_op_is_cached
=
True
self
.
_grad_op_overrides_l
=
[
None
]
*
len
(
self
.
local_inputs
)
self
.
_grad_op_overrides_l
=
[
None
]
*
inp_len
return
output_grads
=
[
out_t
()
for
out_t
in
self
.
output_types
]
fn_grad
=
partial
(
theano
.
gradient
.
grad
,
cost
=
None
,
disconnected_inputs
=
'ignore'
,
return_disconnected
=
'Disconnected'
,
null_gradients
=
'return'
,
known_grads
=
OrderedDict
(
izip
(
local_outputs
,
output_grads
)))
TYPE_ERR_MSG
=
'Gradient override should be (single or list of)'
\
'OpFromGraph | Ellipsis | None | 0 | callable, got
%
s'
# we need to convert _grad_op into an OfG instance
if
grad_op
is
Ellipsis
:
self
.
_grad_op_tflags
=
bytes
(
inp_len
)
all_grads_l
=
theano
.
gradient
.
grad
(
cost
=
None
,
known_grads
=
OrderedDict
(
izip
(
local_outputs
,
output_grads
)),
wrt
=
local_inputs
,
disconnected_inputs
=
'ignore'
)
gdefaults_l
=
fn_grad
(
wrt
=
local_inputs
)
all_grads_ov_l
=
[
None
]
*
inp_len
all_grads_l
,
all_grads_ov_l
=
izip
(
*
[
OpFromGraph
.
_filter_grad_var
(
grad
,
inp
)
for
grad
,
inp
in
izip
(
gdefaults_l
,
local_inputs
)])
elif
grad_op
is
None
:
all_grads_l
=
[
inp
.
zeros_like
()
for
inp
in
local_inputs
]
all_grads_ov_l
=
[
self
.
ofg_null_t
()]
*
inp_len
...
...
@@ -285,7 +301,7 @@ class OpFromGraph(gof.Op):
all_grads_l
=
[
inp
.
zeros_like
()
for
inp
in
local_inputs
]
all_grads_ov_l
=
[
self
.
ofg_discon_t
()]
*
inp_len
elif
isinstance
(
grad_op
,
list
):
goverrides_l
=
self
.
_
grad_op
goverrides_l
=
grad_op
if
len
(
goverrides_l
)
!=
inp_len
:
raise
ValueError
(
'Need to override
%
d gradients, got
%
d'
%
(
...
...
@@ -293,18 +309,15 @@ class OpFromGraph(gof.Op):
# compute non-overriding downsteam grads from upstreams grads
# it's normal some input may be disconnected, thus the 'ignore'
wrt_l
=
[
lin
for
lin
,
gov
in
izip
(
self
.
local_inputs
,
goverrides_l
)
if
gov
is
Ellipsis
]
gdefaults
=
iter
(
theano
.
gradient
.
grad
(
cost
=
None
,
known_grads
=
OrderedDict
(
izip
(
self
.
local_outputs
,
output_grads
)),
wrt
=
wrt_l
,
disconnected_inputs
=
'ignore'
)
if
wrt_l
else
[])
local_inputs
,
goverrides_l
)
if
gov
is
Ellipsis
]
gdefaults
=
iter
(
fn_grad
(
wrt
=
wrt_l
)
if
wrt_l
else
[])
# combine overriding gradients
all_grads_l
=
[]
all_grads_ov_l
=
[]
for
i
,
(
inp
,
fn_gov
)
in
enumerate
(
izip
(
local_inputs
,
goverrides_l
)
):
for
i
np
,
fn_gov
in
izip
(
local_inputs
,
goverrides_l
):
if
fn_gov
is
Ellipsis
:
gnext
,
gnext_ov
=
OpFromGraph
.
_filter_grad_var
(
next
(
gdefaults
),
inp
)
gnext
,
gnext_ov
=
OpFromGraph
.
_filter_grad_var
(
next
(
gdefaults
),
inp
)
all_grads_l
.
append
(
gnext
)
all_grads_ov_l
.
append
(
gnext_ov
)
elif
fn_gov
is
0
:
...
...
@@ -330,13 +343,14 @@ class OpFromGraph(gof.Op):
'Gradient overriding function should return a list, '
'got "
%
s"'
%
type
(
goverrides_l
))
all_grads_l
,
all_grads_ov_l
=
izip
(
*
[
OpFromGraph
.
_filter_grad_var
(
grad
,
inp
)
for
grad
,
inp
in
izip
(
goverrides_l
,
local_inputs
)])
*
[
OpFromGraph
.
_filter_grad_var
(
grad
,
inp
)
for
grad
,
inp
in
izip
(
goverrides_l
,
local_inputs
)])
if
len
(
all_grads_l
)
!=
len
(
local_inputs
):
raise
ValueError
(
'Gradient overriding function should return list of '
'
%
d outputs, got
%
d'
%
(
inp_len
,
len
(
all_grads_l
)))
all_grads_l
=
list
(
all_grads_l
)
all_grads_ov_l
=
list
(
all_grads_ov_l
)
all_grads_l
=
list
(
all_grads_l
)
all_grads_ov_l
=
list
(
all_grads_ov_l
)
self
.
_grad_op
=
type
(
self
)(
inputs
=
local_inputs
+
output_grads
,
outputs
=
all_grads_l
,
...
...
@@ -347,65 +361,92 @@ class OpFromGraph(gof.Op):
self
.
_grad_op_is_cached
=
True
def
_recompute_rop_op
(
self
):
local_inputs
=
self
.
local_inputs
local_outputs
=
self
.
local_outputs
out_len
=
len
(
local_outputs
)
rop_op
=
self
.
_rop_op
if
isinstance
(
self
.
_rop_op
,
OpFromGraph
):
self
.
_rop_op_is_cached
=
True
self
.
_rop_op_overrides_l
=
[
None
]
*
out_len
return
eval_points
=
[
inp_t
()
for
inp_t
in
self
.
input_types
]
if
self
.
_rop_op
is
None
:
self
.
_rop_op
=
[]
if
isinstance
(
self
.
_rop_op
,
list
):
roverrides_l
=
self
.
_rop_op
if
len
(
roverrides_l
)
>
len
(
self
.
local_outputs
):
eval_points
=
[
inp_t
()
for
inp_t
in
self
.
input_types
]
fn_rop
=
partial
(
theano
.
gradient
.
Rop
,
wrt
=
local_inputs
,
eval_points
=
eval_points
)
TYPE_ERR_MSG
=
'R_op overrides should be (single or list of)'
\
'OpFromGraph | Ellipsis | None | 0 | callable, got
%
s'
if
rop_op
is
Ellipsis
:
all_rops_l
=
fn_rop
(
f
=
local_outputs
)
all_rops_ov_l
=
[
None
]
*
out_len
elif
rop_op
is
None
:
all_rops_l
=
[
out
.
zeros_like
()
for
out
in
local_outputs
]
all_rops_ov_l
=
[
self
.
ofg_null_t
()]
*
out_len
elif
rop_op
is
0
:
all_rops_l
=
[
out
.
zeros_like
()
for
out
in
local_outputs
]
all_rops_ov_l
=
[
None
]
*
out_len
elif
isinstance
(
rop_op
,
list
):
roverrides_l
=
rop_op
if
len
(
roverrides_l
)
!=
out_len
:
raise
ValueError
(
'Can override
%
d gradients at most, got
%
d'
%
(
len
(
self
.
local_onputs
),
len
(
roverrides_l
)),
roverrides_l
)
if
len
(
roverrides_l
)
<
len
(
self
.
local_outputs
):
roverrides_l
+=
[
None
]
*
(
len
(
self
.
local_outputs
)
-
len
(
roverrides_l
))
'Need to override
%
d Rop, got
%
d'
%
(
out_len
,
len
(
roverrides_l
)),
roverrides_l
)
# get outputs that does not have Rop override
odefaults_l
=
[
lo
for
lo
,
rov
in
izip
(
self
.
local_outputs
,
roverrides_l
)
if
not
rov
]
rdefaults_li
=
theano
.
gradient
.
Rop
(
f
=
odefaults_l
,
wrt
=
self
.
local_inputs
,
eval_points
=
eval_points
)
rdefaults
=
iter
(
rdefaults_li
if
odefaults_l
else
[])
lo
for
lo
,
rov
in
izip
(
local_outputs
,
roverrides_l
)
if
rov
is
Ellipsis
]
rdefaults_l
=
fn_rop
(
f
=
odefaults_l
)
rdefaults
=
iter
(
rdefaults_l
if
odefaults_l
else
[])
# combine overriding Rops
all_rops_l
=
[]
for
out
,
rov
in
izip
(
self
.
local_outputs
,
roverrides_l
):
if
rov
is
None
:
all_rops_l
.
append
(
next
(
rdefaults
))
elif
rov
is
undef
:
all_rops_l
.
append
(
out
.
zeros_like
()
.
astype
(
theano
.
config
.
floatX
))
all_rops_ov_l
=
[]
for
out
,
fn_rov
in
izip
(
local_outputs
,
roverrides_l
):
if
fn_rov
is
Ellipsis
:
rnext
,
rnext_ov
=
OpFromGraph
.
_filter_rop_var
(
next
(
rdefaults
),
out
)
all_rops_l
.
append
(
rnext
)
all_rops_ov_l
.
append
(
rnext_ov
)
elif
fn_rov
is
0
:
all_rops_l
.
append
(
out
.
zeros_like
())
all_rops_ov_l
.
append
(
None
)
elif
fn_rov
is
None
:
all_rops_l
.
append
(
out
.
zeros_like
())
all_rops_ov_l
.
append
(
self
.
ofg_null_t
())
else
:
all_rops_l
.
append
(
rov
(
self
.
local_inputs
,
eval_points
))
elif
self
.
_rop_op
is
undef
:
all_rops_l
=
[
out
.
zeros_like
()
.
astype
(
theano
.
config
.
floatX
)
for
out
in
self
.
local_outputs
]
if
not
hasattr
(
fn_rov
,
'__call__'
):
raise
TypeError
(
TYPE_ERR_MSG
%
fn_rov
)
rov
,
rov_ov
=
OpFromGraph
.
_filter_rop_var
(
fn_rov
(
local_inputs
,
eval_points
),
out
)
all_rops_l
.
append
(
rov
)
all_rops_ov_l
.
append
(
rov_ov
)
else
:
all_rops_l
=
self
.
_rop_op
(
self
.
local_inputs
,
eval_points
)
if
not
isinstance
(
all_rops_l
,
(
tuple
,
list
)):
all_rops_l
=
[
all_rops_l
]
if
len
(
all_rops_l
)
!=
len
(
self
.
local_outputs
):
if
not
hasattr
(
rop_op
,
'__call__'
):
raise
TypeError
(
TYPE_ERR_MSG
%
rop_op
)
roverrides_l
=
rop_op
(
local_inputs
,
eval_points
)
if
not
isinstance
(
roverrides_l
,
list
):
raise
TypeError
(
'Rop overriding function should return a list, '
'got "
%
s"'
%
type
(
roverrides_l
))
all_rops_l
,
all_rops_ov_l
=
izip
(
*
[
OpFromGraph
.
_filter_rop_var
(
rop
,
out
)
for
rop
,
out
in
izip
(
roverrides_l
,
local_outputs
)])
if
len
(
all_rops_l
)
!=
out_len
:
raise
ValueError
(
'Rop overriding function
%
s should return list of '
'
%
d outputs, got
%
d'
%
(
self
.
_rop_op
,
len
(
self
.
local_outputs
),
len
(
all_rops_l
)),
self
.
_rop_op
)
self
.
_rop_op
,
out_len
,
len
(
all_rops_l
)),
rop_op
)
all_rops_l
=
list
(
all_rops_l
)
all_rops_ov_l
=
list
(
all_rops_ov_l
)
self
.
_rop_op
=
type
(
self
)(
inputs
=
self
.
local_inputs
+
eval_points
,
inputs
=
local_inputs
+
eval_points
,
outputs
=
all_rops_l
,
inline
=
self
.
is_inline
,
name
=
(
None
if
self
.
name
is
None
else
self
.
name
+
'_rop'
),
on_unused_input
=
'ignore'
)
self
.
_rop_op_overrides_l
=
all_rops_ov_l
self
.
_rop_op_is_cached
=
True
def
get_grad_op
(
self
):
...
...
@@ -447,10 +488,9 @@ class OpFromGraph(gof.Op):
self
.
_recompute_rop_op
()
ret_ofg_l
=
self
.
_rop_op
(
*
(
list
(
inputs
)
+
list
(
eval_points
)),
return_list
=
True
)
ret_l
=
[{
self
.
TFLAG_NULL_T
:
self
.
ofg_null_t
(),
self
.
TFLAG_DISCON_T
:
self
.
ofg_discon_t
()
}[
flag
]
if
flag
else
ret_ofg
for
ret_ofg
,
flag
in
izip
(
ret_ofg_l
,
self
.
_grad_tflags
)]
ret_l
=
[
ret_ofg
if
ov
is
None
else
ov
for
ret_ofg
,
ov
in
izip
(
ret_ofg_l
,
self
.
_rop_op_overrides_l
)]
return
ret_l
def
grad
(
self
,
inputs
,
output_grads
):
...
...
@@ -459,10 +499,10 @@ class OpFromGraph(gof.Op):
ret_ofg_l
=
self
.
_grad_op
(
*
(
list
(
inputs
)
+
list
(
output_grads
)),
return_list
=
True
)
ret_l
=
[
ret_ofg
if
ov
is
None
else
ov
for
ret_ofg
,
ov
in
izip
(
ret_ofg_l
,
self
.
_grad_op_overrides_l
)]
ret_ofg
if
ov
is
None
else
ov
for
ret_ofg
,
ov
in
izip
(
ret_ofg_l
,
self
.
_grad_op_overrides_l
)]
return
ret_l
def
make_node
(
self
,
*
inputs
):
num_expected_inps
=
len
(
self
.
local_inputs
)
-
len
(
self
.
shared_inputs
)
if
len
(
inputs
)
!=
num_expected_inps
:
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论