Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
2eb8fca2
提交
2eb8fca2
authored
4月 20, 2024
作者:
Ricardo Vieira
提交者:
Ricardo Vieira
5月 29, 2024
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Refactor code that handles disconnected L_op/R_op outputs in OpFromGraph
上级
be799d8f
隐藏空白字符变更
内嵌
并排
正在显示
1 个修改的文件
包含
187 行增加
和
331 行删除
+187
-331
builders.py
pytensor/compile/builders.py
+187
-331
没有找到文件。
pytensor/compile/builders.py
浏览文件 @
2eb8fca2
"""Define new Ops from existing Ops"""
"""Define new Ops from existing Ops"""
import
warnings
import
warnings
from
collections
import
OrderedDict
from
collections.abc
import
Callable
,
Sequence
from
collections.abc
import
Callable
,
Sequence
from
copy
import
copy
from
copy
import
copy
from
functools
import
partial
from
functools
import
partial
from
typing
import
Union
,
cast
from
typing
import
Union
,
cast
import
pytensor.tensor
as
pt
from
pytensor.compile.function
import
function
from
pytensor.compile.function
import
function
from
pytensor.compile.function.pfunc
import
rebuild_collect_shared
from
pytensor.compile.function.pfunc
import
rebuild_collect_shared
from
pytensor.compile.mode
import
optdb
from
pytensor.compile.mode
import
optdb
...
@@ -251,57 +249,6 @@ class OpFromGraph(Op, HasInnerGraph):
...
@@ -251,57 +249,6 @@ class OpFromGraph(Op, HasInnerGraph):
"""
"""
TYPE_ERR_MSG
=
(
"L_op/gradient override should be (single or list of)"
"None | OpFromGraph | callable | Variable "
"with NullType or DisconnectedType, got
%
s"
)
STYPE_ERR_MSG
=
(
"Overriding Variable instance can only have type"
" of DisconnectedType or NullType, got
%
s"
)
LOP_TYPE_ERR_MSG
=
'L_op type can only be "grad" or "lop", got
%
s.'
OV_INP_LEN_ERR_MSG
=
"expect overrider with
%
d inputs, got
%
d"
@staticmethod
def
_filter_grad_var
(
grad
,
inp
):
# Returns (filtered_var, overrider_var)
# Args:
# grad: gradient Variable
# inp: the corresponding input of gradient Variable
#
# a grad() call could return instance of NullType() or DisconnectedType()
# which cannot be directly used in OfG
#
# Since we always use an OfG instance as self._lop_op, the current
# workaround is to "remember" the special cases of the gradient and
# replace them after self._lop_op is called.
#
# This helper function changes invalid types into a filtered_var,
# and provides a overrider_var to be replaced at grad() call
#
# For now, this converts NullType or DisconnectedType into zeros_like.
# other types are unmodified: overrider_var -> None
if
isinstance
(
grad
.
type
,
NullType
|
DisconnectedType
):
if
hasattr
(
inp
,
"zeros_like"
):
return
inp
.
zeros_like
(),
grad
else
:
return
pt
.
constant
(
0.0
),
grad
else
:
return
grad
,
None
@staticmethod
def
_filter_rop_var
(
inpJ
,
out
):
# mostly similar to _filter_grad_var
if
isinstance
(
inpJ
.
type
,
NullType
):
return
out
.
zeros_like
(),
inpJ
if
isinstance
(
inpJ
.
type
,
DisconnectedType
):
# since R_op does not have DisconnectedType yet, we will just
# make them zeros.
return
out
.
zeros_like
(),
None
else
:
return
inpJ
,
None
def
__init__
(
def
__init__
(
self
,
self
,
inputs
:
list
[
Variable
],
inputs
:
list
[
Variable
],
...
@@ -322,8 +269,10 @@ class OpFromGraph(Op, HasInnerGraph):
...
@@ -322,8 +269,10 @@ class OpFromGraph(Op, HasInnerGraph):
----------
----------
inputs
inputs
The inputs to the graph.
The inputs to the graph.
outputs
outputs
The outputs to the graph.
The outputs to the graph.
inline
inline
Defaults to ``False``
Defaults to ``False``
...
@@ -332,6 +281,7 @@ class OpFromGraph(Op, HasInnerGraph):
...
@@ -332,6 +281,7 @@ class OpFromGraph(Op, HasInnerGraph):
graph but rather its internal graph.
graph but rather its internal graph.
``False`` : will use a pre-compiled function inside.
``False`` : will use a pre-compiled function inside.
grad_overrides
grad_overrides
Defaults to ``None``.
Defaults to ``None``.
This argument is mutually exclusive with ``lop_overrides``.
This argument is mutually exclusive with ``lop_overrides``.
...
@@ -345,6 +295,7 @@ class OpFromGraph(Op, HasInnerGraph):
...
@@ -345,6 +295,7 @@ class OpFromGraph(Op, HasInnerGraph):
`callable`: Should take two args: ``inputs`` and ``output_grads``.
`callable`: Should take two args: ``inputs`` and ``output_grads``.
Each argument is expected to be a list of :class:`Variable `.
Each argument is expected to be a list of :class:`Variable `.
Must return list of :class:`Variable `.
Must return list of :class:`Variable `.
lop_overrides
lop_overrides
Defaults to ``None``.
Defaults to ``None``.
...
@@ -364,10 +315,6 @@ class OpFromGraph(Op, HasInnerGraph):
...
@@ -364,10 +315,6 @@ class OpFromGraph(Op, HasInnerGraph):
Each argument is expected to be a list of :class:`Variable`.
Each argument is expected to be a list of :class:`Variable`.
Must return list of :class:`Variable`.
Must return list of :class:`Variable`.
`NullType` instance: Treat as non-differentiable
`DisconnectedType` instance: Treat as disconnected gradient,
numerically gives zero
``list``: Each `OpFromGraph`/callable must return a single
``list``: Each `OpFromGraph`/callable must return a single
:class:`Variable`. Each list element corresponds to gradient of
:class:`Variable`. Each list element corresponds to gradient of
a specific input, length of list must be equal to number of inputs.
a specific input, length of list must be equal to number of inputs.
...
@@ -387,10 +334,6 @@ class OpFromGraph(Op, HasInnerGraph):
...
@@ -387,10 +334,6 @@ class OpFromGraph(Op, HasInnerGraph):
Each argument is expected to be a list of :class:`Variable`. Must
Each argument is expected to be a list of :class:`Variable`. Must
return list of :class:`Variable`.
return list of :class:`Variable`.
`NullType` instance: Treat as non-differentiable `DisconnectedType`
instance: Treat as zero since `DisconnectedType` is not yet supported
in :meth:`Op.R_op`.
``list``:
``list``:
Each :class:`OpFromGraph`/callable must return a single
Each :class:`OpFromGraph`/callable must return a single
:class:`Variable <pytensor.graph.basic.Variable>`. Each list element
:class:`Variable <pytensor.graph.basic.Variable>`. Each list element
...
@@ -398,12 +341,15 @@ class OpFromGraph(Op, HasInnerGraph):
...
@@ -398,12 +341,15 @@ class OpFromGraph(Op, HasInnerGraph):
must be equal to number of outputs. connection_pattern If not
must be equal to number of outputs. connection_pattern If not
``None``, this will be used as the connection_pattern for this
``None``, this will be used as the connection_pattern for this
:class:`Op`.
:class:`Op`.
strict: bool, default False
strict: bool, default False
If true, it raises when any variables needed to compute the inner graph
If true, it raises when any variables needed to compute the inner graph
are not provided as explici inputs. This can only happen for graphs with
are not provided as explici inputs. This can only happen for graphs with
shared variables.
shared variables.
name
name
A name for debugging purposes.
A name for debugging purposes.
kwargs
kwargs
Check :func:`pytensor.function` for more arguments, only works when not
Check :func:`pytensor.function` for more arguments, only works when not
inline.
inline.
...
@@ -460,26 +406,19 @@ class OpFromGraph(Op, HasInnerGraph):
...
@@ -460,26 +406,19 @@ class OpFromGraph(Op, HasInnerGraph):
self
.
grad_overrides
=
grad_overrides
self
.
grad_overrides
=
grad_overrides
self
.
rop_overrides
=
rop_overrides
self
.
rop_overrides
=
rop_overrides
if
lop_overrides
is
not
None
:
self
.
_lop_op_interface
=
True
if
grad_overrides
is
not
None
:
if
grad_overrides
is
not
None
:
if
lop_overrides
is
not
None
:
raise
ValueError
(
raise
ValueError
(
"lop_overrides and grad_overrides are mutually exclusive"
"lop_overrides and grad_overrides are mutually exclusive"
)
)
else
:
self
.
set_lop_overrides
(
lop_overrides
)
self
.
_lop_type
=
"lop"
elif
grad_overrides
is
not
None
:
warnings
.
warn
(
warnings
.
warn
(
"grad_overrides is deprecated in favor of lop_overrides. Using it will lead to an error in the future."
,
"grad_overrides is deprecated in favor of lop_overrides. Using it will lead to an error in the future."
,
FutureWarning
,
FutureWarning
,
)
)
self
.
set_lop_overrides
(
grad_overrides
)
self
.
_lop_op_interface
=
False
self
.
_lop_type
=
"grad"
self
.
_lop_op_cache
:
Callable
|
None
=
None
else
:
self
.
_rop_op_cache
:
Callable
|
None
=
None
self
.
set_lop_overrides
(
None
)
self
.
_lop_type
=
"lop"
self
.
set_rop_overrides
(
rop_overrides
)
self
.
_connection_pattern
=
connection_pattern
self
.
_connection_pattern
=
connection_pattern
...
@@ -501,307 +440,224 @@ class OpFromGraph(Op, HasInnerGraph):
...
@@ -501,307 +440,224 @@ class OpFromGraph(Op, HasInnerGraph):
is_inline
=
self
.
is_inline
is_inline
=
self
.
is_inline
return
"{name}{{inline={is_inline}}}"
.
format
(
**
locals
())
return
"{name}{{inline={is_inline}}}"
.
format
(
**
locals
())
def
_combine_list_overrides
(
self
,
default_outs
,
custom_outs
,
callable_args
):
"""Combines default and custom overrides into a single list of outputs."""
default_out_iter
=
iter
(
default_outs
)
combined_outs
=
[]
for
custom_out
in
custom_outs
:
if
custom_out
is
None
:
combined_outs
.
append
(
next
(
default_out_iter
))
elif
isinstance
(
custom_out
,
Variable
):
if
not
isinstance
(
custom_out
.
type
,
NullType
|
DisconnectedType
):
raise
ValueError
(
f
"Override list can only contain NullType or DisconnectedType Variable instances, got {custom_out.type}"
)
combined_outs
.
append
(
custom_out
)
elif
callable
(
custom_out
):
combined_outs
.
append
(
custom_out
(
*
callable_args
))
else
:
raise
ValueError
(
f
"Override list should contain None, Variable or callable, got {type(custom_out)}"
)
return
combined_outs
def
_call_custom_override
(
self
,
op_overrides
,
callable_args
,
nout
):
"""Calls custom override function and provides informative error messages."""
if
not
callable
(
op_overrides
):
raise
TypeError
(
f
"L_op/R_op override should be None, a list or a Callable, got {type(op_overrides)}"
)
outputs
=
op_overrides
(
*
callable_args
)
if
not
isinstance
(
outputs
,
list
):
raise
TypeError
(
f
"Lop/Rop overriding function should return a list, got {type(outputs)}"
)
if
len
(
outputs
)
!=
nout
:
raise
ValueError
(
f
"Lop/Rop overriding function {self.rop_overrides} should return "
f
"a list of {nout} outputs, got {len(outputs)}"
)
return
outputs
@config.change_flags
(
compute_test_value
=
"off"
)
@config.change_flags
(
compute_test_value
=
"off"
)
def
_recompute_lop_op
(
self
):
def
_build_and_cache_lop_op
(
self
)
->
Callable
:
"""
"""converts lop_overrides (or grad_overrides) from user supplied form to type(self) instance.
converts self._lop_op from user supplied form to type(self) instance
Results are cached in self._lop_op_cache
"""
"""
local_inputs
=
self
.
inner_inputs
if
self
.
_lop_op_cache
is
not
None
:
local_outputs
=
self
.
inner_outputs
return
self
.
_lop_op_cache
inp_len
=
len
(
local_inputs
)
lop_op
=
self
.
_lop_op
inner_inputs
=
self
.
inner_inputs
inner_outputs
=
self
.
inner_outputs
if
isinstance
(
lop_op
,
OpFromGraph
):
nin
=
len
(
inner_inputs
)
if
self
.
_lop_op_is_cached
:
lop_overrides
=
(
return
self
.
lop_overrides
if
self
.
_lop_op_interface
else
self
.
grad_overrides
assert
self
.
_lop_type
in
(
"lop"
,
"grad"
),
(
)
self
.
LOP_TYPE_ERR_MSG
%
self
.
_lop_type
)
if
isinstance
(
lop_overrides
,
OpFromGraph
):
if
self
.
_lop_type
==
"grad"
:
if
self
.
_lop_op_interface
:
needed_ninps
=
inp_len
+
len
(
local_outputs
)
self
.
_lop_op_cache
=
lop_overrides
ninps
=
len
(
lop_op
.
inner_inputs
)
lop_overrides
.
kwargs
[
"on_unused_input"
]
=
"ignore"
if
needed_ninps
!=
ninps
:
return
lop_overrides
raise
ValueError
(
self
.
OV_INP_LEN_ERR_MSG
%
(
needed_ninps
,
ninps
))
# make a wrapper callable
else
:
# We need to add a wrapper for the different input signature
def
lop_op
(
inps
,
grads
):
# TODO: Remove this once the grad interface is gone
return
self
.
_lop_op
(
*
(
inps
+
grads
))
def
lop_overrides
(
inps
,
grads
):
return
self
.
grad_overrides
(
*
inps
,
*
grads
)
elif
self
.
_lop_type
==
"lop"
:
# OfG can be directly used in L_op format
needed_ninps
=
inp_len
+
2
*
len
(
local_outputs
)
ninps
=
len
(
lop_op
.
inner_inputs
)
if
needed_ninps
!=
ninps
:
raise
ValueError
(
self
.
OV_INP_LEN_ERR_MSG
%
(
needed_ninps
,
ninps
))
self
.
_lop_op_is_cached
=
True
self
.
_lop_op_stypes_l
=
[
None
]
*
inp_len
self
.
_lop_op
.
kwargs
[
"on_unused_input"
]
=
"ignore"
return
output_grads
=
[
out_t
()
for
out_t
in
self
.
output_types
]
output_grads
=
[
out_t
()
for
out_t
in
self
.
output_types
]
fn_grad
=
partial
(
fn_grad
=
partial
(
grad
,
grad
,
cost
=
None
,
cost
=
None
,
disconnected_inputs
=
"ignore"
,
disconnected_inputs
=
"ignore"
,
return_disconnected
=
"
D
isconnected"
,
return_disconnected
=
"
d
isconnected"
,
null_gradients
=
"return"
,
null_gradients
=
"return"
,
known_grads
=
OrderedDict
(
zip
(
local
_outputs
,
output_grads
)),
known_grads
=
dict
(
zip
(
inner
_outputs
,
output_grads
)),
)
)
assert
self
.
_lop_type
in
(
"lop"
,
"grad"
),
self
.
LOP_TYPE_ERR_MSG
%
self
.
_lop_type
if
self
.
_lop_op_interface
:
if
self
.
_lop_type
==
"lop"
:
callable_args
=
(
inner_inputs
,
inner_outputs
,
output_grads
)
callable_args
=
(
local_inputs
,
local_outputs
,
output_grads
)
else
:
elif
self
.
_lop_type
==
"grad"
:
callable_args
=
(
inner_inputs
,
output_grads
)
callable_args
=
(
local_inputs
,
output_grads
)
# we need to convert _lop_op into an OfG instance
# we need to convert _lop_op into an OfG instance
if
lop_op
is
None
:
if
lop_overrides
is
None
:
gdefaults_l
=
fn_grad
(
wrt
=
local_inputs
)
input_grads
=
fn_grad
(
wrt
=
inner_inputs
)
all_grads_l
,
all_grads_ov_l
=
zip
(
elif
isinstance
(
lop_overrides
,
list
):
*
[
custom_input_grads
=
lop_overrides
OpFromGraph
.
_filter_grad_var
(
grad
,
inp
)
if
len
(
custom_input_grads
)
!=
nin
:
for
grad
,
inp
in
zip
(
gdefaults_l
,
local_inputs
)
]
)
all_grads_l
=
list
(
all_grads_l
)
all_grads_ov_l
=
list
(
all_grads_ov_l
)
elif
isinstance
(
lop_op
,
list
):
goverrides_l
=
lop_op
if
len
(
goverrides_l
)
!=
inp_len
:
raise
ValueError
(
raise
ValueError
(
f
"Need to override {
int(inp_len)} gradients, got {len(goverrides_l
)}"
,
f
"Need to override {
nin} gradients, got {len(custom_input_grads
)}"
,
goverrides_l
,
custom_input_grads
,
)
)
# compute non-overriding downsteam grads from upstreams grads
# compute non-overriding downsteam grads from upstreams grads
# it's normal some input may be disconnected, thus the 'ignore'
# it's normal some input may be disconnected, thus the 'ignore'
wrt_l
=
[
lin
for
lin
,
gov
in
zip
(
local_inputs
,
goverrides_l
)
if
gov
is
None
]
wrt
=
[
gdefaults
=
iter
(
fn_grad
(
wrt
=
wrt_l
)
if
wrt_l
else
[])
lin
for
lin
,
gov
in
zip
(
inner_inputs
,
custom_input_grads
)
if
gov
is
None
# combine overriding gradients
]
all_grads_l
=
[]
default_input_grads
=
fn_grad
(
wrt
=
wrt
)
if
wrt
else
[]
all_grads_ov_l
=
[]
input_grads
=
self
.
_combine_list_overrides
(
for
inp
,
fn_gov
in
zip
(
local_inputs
,
goverrides_l
):
default_input_grads
,
custom_input_grads
,
callable_args
if
fn_gov
is
None
:
gnext
,
gnext_ov
=
OpFromGraph
.
_filter_grad_var
(
next
(
gdefaults
),
inp
)
all_grads_l
.
append
(
gnext
)
all_grads_ov_l
.
append
(
gnext_ov
)
elif
isinstance
(
fn_gov
,
Variable
):
if
isinstance
(
fn_gov
.
type
,
DisconnectedType
|
NullType
):
all_grads_l
.
append
(
inp
.
zeros_like
())
all_grads_ov_l
.
append
(
fn_gov
.
type
())
else
:
raise
ValueError
(
self
.
STYPE_ERR_MSG
%
fn_gov
.
type
)
else
:
if
not
callable
(
fn_gov
):
raise
TypeError
(
self
.
TYPE_ERR_MSG
%
fn_gov
)
gov
,
gov_ov
=
OpFromGraph
.
_filter_grad_var
(
fn_gov
(
*
callable_args
),
inp
)
all_grads_l
.
append
(
gov
)
all_grads_ov_l
.
append
(
gov_ov
)
else
:
# callable case
if
not
callable
(
lop_op
):
raise
TypeError
(
self
.
TYPE_ERR_MSG
%
lop_op
)
goverrides_l
=
lop_op
(
*
callable_args
)
if
not
isinstance
(
goverrides_l
,
list
):
raise
TypeError
(
"Gradient/L_op overriding function should return a list, "
f
'got "{type(goverrides_l)}"'
)
all_grads_l
,
all_grads_ov_l
=
zip
(
*
[
OpFromGraph
.
_filter_grad_var
(
grad
,
inp
)
for
grad
,
inp
in
zip
(
goverrides_l
,
local_inputs
)
]
)
)
if
len
(
all_grads_l
)
!=
len
(
local_inputs
):
else
:
raise
ValueError
(
input_grads
=
self
.
_call_custom_override
(
lop_overrides
,
callable_args
,
nin
)
"Gradient/L_op overriding function should return list of "
f
"{int(inp_len)} outputs, got {len(all_grads_l)}"
# Filter out disconnected input and output gradients
)
connected_input_grads
=
[
all_grads_l
=
list
(
all_grads_l
)
inp_grad
all_grads_ov_l
=
list
(
all_grads_ov_l
)
for
inp_grad
in
input_grads
self
.
_lop_op
=
type
(
self
)(
if
not
isinstance
(
inp_grad
.
type
,
DisconnectedType
|
NullType
)
inputs
=
local_inputs
+
local_outputs
+
output_grads
,
]
outputs
=
all_grads_l
,
lop_op
=
type
(
self
)(
inputs
=
inner_inputs
+
inner_outputs
+
output_grads
,
outputs
=
connected_input_grads
,
inline
=
self
.
is_inline
,
inline
=
self
.
is_inline
,
name
=
(
None
if
self
.
name
is
None
else
self
.
name
+
"_"
+
self
.
_lop_type
),
name
=
(
None
if
self
.
name
is
None
else
f
"{self.name}_LOp"
),
# TODO: We can be eager here and exclude unused inputs in the OFG
on_unused_input
=
"ignore"
,
on_unused_input
=
"ignore"
,
)
)
self
.
_lop_op_stypes_l
=
all_grads_ov_l
self
.
_lop_op_is_cached
=
True
# Return a wrapper that combines connected and disconnected input gradients
self
.
_lop_type
=
"lop"
def
wrapper
(
*
inputs
:
Variable
,
**
kwargs
)
->
list
[
Variable
]:
connected_input_grads
=
iter
(
lop_op
(
*
inputs
,
**
kwargs
))
return
[
input_grad
if
isinstance
(
input_grad
.
type
,
DisconnectedType
|
NullType
)
else
next
(
connected_input_grads
)
for
input_grad
in
input_grads
]
self
.
_lop_op_cache
=
wrapper
return
wrapper
@config.change_flags
(
compute_test_value
=
"off"
)
@config.change_flags
(
compute_test_value
=
"off"
)
def
_recompute_rop_op
(
self
):
def
_build_and_cache_rop_op
(
self
):
"""
"""Converts rop_overrides from user supplied form to type(self) instance.
converts self._rop_op from user supplied form to type(self) instance
Results are cached in self._rop_op_cache
"""
"""
local_inputs
=
self
.
inner_inputs
if
self
.
_rop_op_cache
is
not
None
:
local_outputs
=
self
.
inner_outputs
return
self
.
_rop_op_cache
out_len
=
len
(
local_outputs
)
rop_op
=
self
.
_rop_op
inner_inputs
=
self
.
inner_inputs
inner_outputs
=
self
.
inner_outputs
if
isinstance
(
rop_op
,
OpFromGraph
):
nout
=
len
(
inner_outputs
)
if
not
self
.
_rop_op_is_cached
:
rop_overrides
=
self
.
rop_overrides
self
.
_rop_op_is_cached
=
True
self
.
_rop_op_stypes_l
=
[
None
]
*
out_len
if
isinstance
(
rop_overrides
,
OpFromGraph
):
return
self
.
_rop_op_cache
=
rop_overrides
return
rop_overrides
eval_points
=
[
inp_t
()
for
inp_t
in
self
.
input_types
]
eval_points
=
[
inp_t
()
for
inp_t
in
self
.
input_types
]
fn_rop
=
partial
(
Rop
,
wrt
=
local_inputs
,
eval_points
=
eval_points
)
fn_rop
=
partial
(
Rop
,
wrt
=
inner_inputs
,
eval_points
=
eval_points
)
TYPE_ERR_MSG
=
(
"R_op overrides should be (single or list of)"
callable_args
=
(
inner_inputs
,
eval_points
)
"OpFromGraph, None, a list or a callable, got
%
s"
if
rop_overrides
is
None
:
)
output_grads
=
fn_rop
(
f
=
inner_outputs
)
STYPE_ERR_MSG
=
(
elif
isinstance
(
rop_overrides
,
list
):
"Overriding Variable instance can only have type"
custom_output_grads
=
rop_overrides
" of DisconnectedType or NullType, got
%
s"
if
len
(
custom_output_grads
)
!=
nout
:
)
if
rop_op
is
None
:
rdefaults_l
=
fn_rop
(
f
=
local_outputs
)
all_rops_l
,
all_rops_ov_l
=
zip
(
*
[
OpFromGraph
.
_filter_rop_var
(
rop
,
out
)
for
rop
,
out
in
zip
(
rdefaults_l
,
local_outputs
)
]
)
all_rops_l
=
list
(
all_rops_l
)
all_rops_ov_l
=
list
(
all_rops_ov_l
)
elif
isinstance
(
rop_op
,
list
):
roverrides_l
=
rop_op
if
len
(
roverrides_l
)
!=
out_len
:
raise
ValueError
(
raise
ValueError
(
f
"Need to override {int(
out_len)} Rop, got {len(roverrides_l
)}"
,
f
"Need to override {int(
nout)} Rop, got {len(custom_output_grads
)}"
,
roverrides_l
,
custom_output_grads
,
)
)
# get outputs that does not have Rop override
# get outputs that does not have Rop override
odefaults_l
=
[
f
=
[
lo
for
lo
,
rov
in
zip
(
local_outputs
,
roverrides_l
)
if
rov
is
None
output
for
output
,
custom_output_grad
in
zip
(
inner_outputs
,
custom_output_grads
)
if
custom_output_grad
is
None
]
]
rdefaults_l
=
fn_rop
(
f
=
odefaults_l
)
default_output_grads
=
fn_rop
(
f
=
f
)
if
f
else
[]
rdefaults
=
iter
(
rdefaults_l
if
odefaults_l
else
[])
output_grads
=
self
.
_combine_list_overrides
(
# combine overriding Rops
default_output_grads
,
custom_output_grads
,
callable_args
all_rops_l
=
[]
)
all_rops_ov_l
=
[]
for
out
,
fn_rov
in
zip
(
local_outputs
,
roverrides_l
):
if
fn_rov
is
None
:
rnext
,
rnext_ov
=
OpFromGraph
.
_filter_rop_var
(
next
(
rdefaults
),
out
)
all_rops_l
.
append
(
rnext
)
all_rops_ov_l
.
append
(
rnext_ov
)
elif
isinstance
(
fn_rov
,
Variable
):
if
isinstance
(
fn_rov
.
type
,
NullType
):
all_rops_l
.
append
(
out
.
zeros_like
())
all_rops_ov_l
.
append
(
fn_rov
.
type
())
if
isinstance
(
fn_rov
.
type
,
DisconnectedType
):
all_rops_l
.
append
(
out
.
zeros_like
())
all_rops_ov_l
.
append
(
None
)
else
:
raise
ValueError
(
STYPE_ERR_MSG
%
fn_rov
.
type
)
else
:
if
not
callable
(
fn_rov
):
raise
TypeError
(
TYPE_ERR_MSG
%
fn_rov
)
rov
,
rov_ov
=
OpFromGraph
.
_filter_rop_var
(
fn_rov
(
local_inputs
,
eval_points
),
out
)
all_rops_l
.
append
(
rov
)
all_rops_ov_l
.
append
(
rov_ov
)
else
:
else
:
if
not
callable
(
rop_op
):
output_grads
=
self
.
_call_custom_override
(
raise
TypeError
(
TYPE_ERR_MSG
%
rop_op
)
rop_overrides
,
callable_args
,
nout
roverrides_l
=
rop_op
(
local_inputs
,
eval_points
)
if
not
isinstance
(
roverrides_l
,
list
):
raise
TypeError
(
"Rop overriding function should return a list, "
f
'got "{type(roverrides_l)}"'
)
all_rops_l
,
all_rops_ov_l
=
zip
(
*
[
OpFromGraph
.
_filter_rop_var
(
rop
,
out
)
for
rop
,
out
in
zip
(
roverrides_l
,
local_outputs
)
]
)
)
if
len
(
all_rops_l
)
!=
out_len
:
raise
ValueError
(
# Filter out disconnected output gradients
(
filtered_output_grads
=
[
f
"Rop overriding function {self._rop_op} should return list of "
out_grad
f
"{int(out_len)} outputs, got {len(all_rops_l)}"
,
for
out_grad
in
output_grads
),
if
not
isinstance
(
out_grad
.
type
,
DisconnectedType
|
NullType
)
rop_op
,
]
)
rop_op
=
type
(
self
)(
all_rops_l
=
list
(
all_rops_l
)
inputs
=
inner_inputs
+
eval_points
,
all_rops_ov_l
=
list
(
all_rops_ov_l
)
outputs
=
filtered_output_grads
,
self
.
_rop_op
=
type
(
self
)(
inputs
=
local_inputs
+
eval_points
,
outputs
=
all_rops_l
,
inline
=
self
.
is_inline
,
inline
=
self
.
is_inline
,
name
=
(
None
if
self
.
name
is
None
else
self
.
name
+
"_rop"
),
name
=
(
None
if
self
.
name
is
None
else
self
.
name
+
"_rop"
),
on_unused_input
=
"ignore"
,
on_unused_input
=
"ignore"
,
)
)
self
.
_rop_op_stypes_l
=
all_rops_ov_l
self
.
_rop_op_is_cached
=
True
def
get_lop_op
(
self
):
if
not
self
.
_lop_op_is_cached
:
self
.
_recompute_lop_op
()
return
self
.
_lop_op
def
get_rop_op
(
self
):
if
not
self
.
_rop_op_is_cached
:
self
.
_recompute_rop_op
()
return
self
.
_rop_op
def
set_grad_overrides
(
self
,
grad_overrides
):
"""
Set gradient overrides.
This will completely remove any previously set L_op/gradient overrides
"""
self
.
_lop_op
=
grad_overrides
self
.
_lop_op_is_cached
=
False
self
.
_lop_type
=
"grad"
def
set_lop_overrides
(
self
,
lop_overrides
):
"""
Set L_op overrides
This will completely remove any previously set L_op/gradient overrides
"""
self
.
_lop_op
=
lop_overrides
self
.
_lop_op_is_cached
=
False
self
.
_lop_type
=
"lop"
def
set_rop_overrides
(
self
,
rop_overrides
):
# Return a wrapper that combines connected and disconnected output gradients
"""
def
wrapper
(
*
inputs
:
Variable
,
**
kwargs
)
->
list
[
Variable
|
None
]:
Set R_op overrides
connected_output_grads
=
iter
(
rop_op
(
*
inputs
,
**
kwargs
))
This will completely remove any previously set R_op overrides
all_output_grads
=
[]
for
out_grad
in
output_grads
:
if
isinstance
(
out_grad
.
type
,
DisconnectedType
):
# R_Op does not have DisconnectedType yet, None should be used instead
all_output_grads
.
append
(
None
)
elif
isinstance
(
out_grad
.
type
,
NullType
):
all_output_grads
.
append
(
out_grad
)
else
:
all_output_grads
.
append
(
next
(
connected_output_grads
))
return
all_output_grads
"""
self
.
_rop_op_cache
=
wrapper
self
.
_rop_op
=
rop_overrides
return
wrapper
self
.
_rop_op_is_cached
=
False
def
L_op
(
self
,
inputs
,
outputs
,
output_grads
):
def
L_op
(
self
,
inputs
,
outputs
,
output_grads
):
if
not
self
.
_lop_op_is_cached
:
lop_op
=
self
.
_build_and_cache_lop_op
()
self
.
_recompute_lop_op
()
return
lop_op
(
*
inputs
,
*
outputs
,
*
output_grads
,
return_list
=
True
)
inps
=
list
(
inputs
)
+
list
(
outputs
)
+
list
(
output_grads
)
ret_ofg_l
=
self
.
_lop_op
(
*
inps
,
return_list
=
True
)
ret_l
=
[
ret_ofg
if
ov
is
None
else
ov
for
ret_ofg
,
ov
in
zip
(
ret_ofg_l
,
self
.
_lop_op_stypes_l
)
]
return
ret_l
def
R_op
(
self
,
inputs
,
eval_points
):
def
R_op
(
self
,
inputs
,
eval_points
):
if
not
self
.
_rop_op_is_cached
:
rop_op
=
self
.
_build_and_cache_rop_op
()
self
.
_recompute_rop_op
()
return
rop_op
(
*
inputs
,
*
eval_points
,
return_list
=
True
)
ret_ofg_l
=
self
.
_rop_op
(
*
(
list
(
inputs
)
+
list
(
eval_points
)),
return_list
=
True
)
ret_l
=
[
ret_ofg
if
ov
is
None
else
ov
for
ret_ofg
,
ov
in
zip
(
ret_ofg_l
,
self
.
_rop_op_stypes_l
)
]
return
ret_l
def
__call__
(
self
,
*
inputs
,
**
kwargs
):
def
__call__
(
self
,
*
inputs
,
**
kwargs
):
# The user interface doesn't expect the shared variable inputs of the
# The user interface doesn't expect the shared variable inputs of the
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论