Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
ab304cb9
提交
ab304cb9
authored
5月 29, 2024
作者:
Ricardo Vieira
提交者:
Ricardo Vieira
5月 29, 2024
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Deprecate use of "default" and Variable as OpFromGrah overrides
上级
6dfc811f
隐藏空白字符变更
内嵌
并排
正在显示
2 个修改的文件
包含
52 行增加
和
50 行删除
+52
-50
builders.py
pytensor/compile/builders.py
+36
-46
test_builders.py
tests/compile/test_builders.py
+16
-4
没有找到文件。
pytensor/compile/builders.py
浏览文件 @
ab304cb9
...
...
@@ -2,10 +2,10 @@
import
warnings
from
collections
import
OrderedDict
from
collections.abc
import
Sequence
from
collections.abc
import
Callable
,
Sequence
from
copy
import
copy
from
functools
import
partial
from
typing
import
cast
from
typing
import
Union
,
cast
import
pytensor.tensor
as
pt
from
pytensor.compile.function
import
function
...
...
@@ -225,7 +225,7 @@ class OpFromGraph(Op, HasInnerGraph):
e2 = op(x, y, z) + op(z, y, x)
fn = function([x, y, z], [e2])
Example 3 override L_op
Example 3 override
second output of
L_op
.. code-block:: python
...
...
@@ -241,7 +241,7 @@ class OpFromGraph(Op, HasInnerGraph):
op = OpFromGraph(
[x, y, z],
[e],
lop_overrides=[
'default', rescale_dy, 'default'
],
lop_overrides=[
None, rescale_dy, None
],
)
e2 = op(x, y, z)
dx, dy, dz = grad(e2, [x, y, z])
...
...
@@ -253,7 +253,7 @@ class OpFromGraph(Op, HasInnerGraph):
TYPE_ERR_MSG
=
(
"L_op/gradient override should be (single or list of)"
"
'default'
| OpFromGraph | callable | Variable "
"
None
| OpFromGraph | callable | Variable "
"with NullType or DisconnectedType, got
%
s"
)
STYPE_ERR_MSG
=
(
...
...
@@ -308,9 +308,9 @@ class OpFromGraph(Op, HasInnerGraph):
outputs
:
list
[
Variable
],
*
,
inline
:
bool
=
False
,
lop_overrides
:
str
=
"default"
,
grad_overrides
:
str
=
"default"
,
rop_overrides
:
str
=
"default"
,
lop_overrides
:
Union
[
Callable
,
"OpFromGraph"
,
None
]
=
None
,
grad_overrides
:
Union
[
Callable
,
"OpFromGraph"
,
None
]
=
None
,
rop_overrides
:
Union
[
Callable
,
"OpFromGraph"
,
None
]
=
None
,
connection_pattern
:
list
[
list
[
bool
]]
|
None
=
None
,
strict
:
bool
=
False
,
name
:
str
|
None
=
None
,
...
...
@@ -333,10 +333,10 @@ class OpFromGraph(Op, HasInnerGraph):
``False`` : will use a pre-compiled function inside.
grad_overrides
Defaults to ``
'default'
``.
Defaults to ``
None
``.
This argument is mutually exclusive with ``lop_overrides``.
``
'default'
`` : Do not override, use default grad() result
``
None
`` : Do not override, use default grad() result
`OpFromGraph`: Override with another `OpFromGraph`, should
accept inputs as the same order and types of ``inputs`` and ``output_grads``
...
...
@@ -346,14 +346,14 @@ class OpFromGraph(Op, HasInnerGraph):
Each argument is expected to be a list of :class:`Variable `.
Must return list of :class:`Variable `.
lop_overrides
Defaults to ``
'default'
``.
Defaults to ``
None
``.
This argument is mutually exclusive with ``grad_overrides``.
These options are similar to the ``grad_overrides`` above, but for
the :meth:`Op.L_op` method.
``
'default'
``: Do not override, use the default :meth:`Op.L_op` result
``
None
``: Do not override, use the default :meth:`Op.L_op` result
`OpFromGraph`: Override with another `OpFromGraph`, should
accept inputs as the same order and types of ``inputs``,
...
...
@@ -373,11 +373,11 @@ class OpFromGraph(Op, HasInnerGraph):
a specific input, length of list must be equal to number of inputs.
rop_overrides
One of ``{
'default'
, OpFromGraph, callable, Variable}``.
One of ``{
None
, OpFromGraph, callable, Variable}``.
Defaults to ``
'default'
``.
Defaults to ``
None
``.
``
'default'
``: Do not override, use the default :meth:`Op.R_op` result
``
None
``: Do not override, use the default :meth:`Op.R_op` result
`OpFromGraph`: Override with another `OpFromGraph`, should
accept inputs as the same order and types of ``inputs`` and ``eval_points``
...
...
@@ -446,19 +446,29 @@ class OpFromGraph(Op, HasInnerGraph):
self
.
input_types
=
[
inp
.
type
for
inp
in
inputs
]
self
.
output_types
=
[
out
.
type
for
out
in
outputs
]
for
override
in
(
lop_overrides
,
grad_overrides
,
rop_overrides
):
if
override
==
"default"
:
raise
ValueError
(
"'default' is no longer a valid value for overrides. Use None instead."
)
if
isinstance
(
override
,
Variable
):
raise
TypeError
(
"Variables are no longer valid types for overrides. Return them in a list for each output instead"
)
self
.
lop_overrides
=
lop_overrides
self
.
grad_overrides
=
grad_overrides
self
.
rop_overrides
=
rop_overrides
if
lop_overrides
!=
"default"
:
if
grad_overrides
!=
"default"
:
if
lop_overrides
is
not
None
:
if
grad_overrides
is
not
None
:
raise
ValueError
(
"lop_overrides and grad_overrides are mutually exclusive"
)
else
:
self
.
set_lop_overrides
(
lop_overrides
)
self
.
_lop_type
=
"lop"
elif
grad_overrides
!=
"default"
:
elif
grad_overrides
is
not
None
:
warnings
.
warn
(
"grad_overrides is deprecated in favor of lop_overrides. Using it will lead to an error in the future."
,
FutureWarning
,
...
...
@@ -466,7 +476,7 @@ class OpFromGraph(Op, HasInnerGraph):
self
.
set_lop_overrides
(
grad_overrides
)
self
.
_lop_type
=
"grad"
else
:
self
.
set_lop_overrides
(
"default"
)
self
.
set_lop_overrides
(
None
)
self
.
_lop_type
=
"lop"
self
.
set_rop_overrides
(
rop_overrides
)
...
...
@@ -546,7 +556,7 @@ class OpFromGraph(Op, HasInnerGraph):
callable_args
=
(
local_inputs
,
output_grads
)
# we need to convert _lop_op into an OfG instance
if
lop_op
==
"default"
:
if
lop_op
is
None
:
gdefaults_l
=
fn_grad
(
wrt
=
local_inputs
)
all_grads_l
,
all_grads_ov_l
=
zip
(
*
[
...
...
@@ -556,12 +566,6 @@ class OpFromGraph(Op, HasInnerGraph):
)
all_grads_l
=
list
(
all_grads_l
)
all_grads_ov_l
=
list
(
all_grads_ov_l
)
elif
isinstance
(
lop_op
,
Variable
):
if
isinstance
(
lop_op
.
type
,
DisconnectedType
|
NullType
):
all_grads_l
=
[
inp
.
zeros_like
()
for
inp
in
local_inputs
]
all_grads_ov_l
=
[
lop_op
.
type
()
for
_
in
range
(
inp_len
)]
else
:
raise
ValueError
(
self
.
STYPE_ERR_MSG
%
lop_op
.
type
)
elif
isinstance
(
lop_op
,
list
):
goverrides_l
=
lop_op
if
len
(
goverrides_l
)
!=
inp_len
:
...
...
@@ -571,15 +575,13 @@ class OpFromGraph(Op, HasInnerGraph):
)
# compute non-overriding downsteam grads from upstreams grads
# it's normal some input may be disconnected, thus the 'ignore'
wrt_l
=
[
lin
for
lin
,
gov
in
zip
(
local_inputs
,
goverrides_l
)
if
gov
==
"default"
]
wrt_l
=
[
lin
for
lin
,
gov
in
zip
(
local_inputs
,
goverrides_l
)
if
gov
is
None
]
gdefaults
=
iter
(
fn_grad
(
wrt
=
wrt_l
)
if
wrt_l
else
[])
# combine overriding gradients
all_grads_l
=
[]
all_grads_ov_l
=
[]
for
inp
,
fn_gov
in
zip
(
local_inputs
,
goverrides_l
):
if
fn_gov
==
"default"
:
if
fn_gov
is
None
:
gnext
,
gnext_ov
=
OpFromGraph
.
_filter_grad_var
(
next
(
gdefaults
),
inp
)
all_grads_l
.
append
(
gnext
)
all_grads_ov_l
.
append
(
gnext_ov
)
...
...
@@ -652,13 +654,13 @@ class OpFromGraph(Op, HasInnerGraph):
fn_rop
=
partial
(
Rop
,
wrt
=
local_inputs
,
eval_points
=
eval_points
)
TYPE_ERR_MSG
=
(
"R_op overrides should be (single or list of)"
"OpFromGraph
| 'default' | None | 0 |
callable, got
%
s"
"OpFromGraph
, None, a list or a
callable, got
%
s"
)
STYPE_ERR_MSG
=
(
"Overriding Variable instance can only have type"
" of DisconnectedType or NullType, got
%
s"
)
if
rop_op
==
"default"
:
if
rop_op
is
None
:
rdefaults_l
=
fn_rop
(
f
=
local_outputs
)
all_rops_l
,
all_rops_ov_l
=
zip
(
*
[
...
...
@@ -668,15 +670,6 @@ class OpFromGraph(Op, HasInnerGraph):
)
all_rops_l
=
list
(
all_rops_l
)
all_rops_ov_l
=
list
(
all_rops_ov_l
)
elif
isinstance
(
rop_op
,
Variable
):
if
isinstance
(
rop_op
.
type
,
NullType
):
all_rops_l
=
[
inp
.
zeros_like
()
for
inp
in
local_inputs
]
all_rops_ov_l
=
[
rop_op
.
type
()
for
_
in
range
(
out_len
)]
elif
isinstance
(
rop_op
.
type
,
DisconnectedType
):
all_rops_l
=
[
inp
.
zeros_like
()
for
inp
in
local_inputs
]
all_rops_ov_l
=
[
None
]
*
out_len
else
:
raise
ValueError
(
STYPE_ERR_MSG
%
rop_op
.
type
)
elif
isinstance
(
rop_op
,
list
):
roverrides_l
=
rop_op
if
len
(
roverrides_l
)
!=
out_len
:
...
...
@@ -686,7 +679,7 @@ class OpFromGraph(Op, HasInnerGraph):
)
# get outputs that does not have Rop override
odefaults_l
=
[
lo
for
lo
,
rov
in
zip
(
local_outputs
,
roverrides_l
)
if
rov
==
"default"
lo
for
lo
,
rov
in
zip
(
local_outputs
,
roverrides_l
)
if
rov
is
None
]
rdefaults_l
=
fn_rop
(
f
=
odefaults_l
)
rdefaults
=
iter
(
rdefaults_l
if
odefaults_l
else
[])
...
...
@@ -694,7 +687,7 @@ class OpFromGraph(Op, HasInnerGraph):
all_rops_l
=
[]
all_rops_ov_l
=
[]
for
out
,
fn_rov
in
zip
(
local_outputs
,
roverrides_l
):
if
fn_rov
==
"default"
:
if
fn_rov
is
None
:
rnext
,
rnext_ov
=
OpFromGraph
.
_filter_rop_var
(
next
(
rdefaults
),
out
)
all_rops_l
.
append
(
rnext
)
all_rops_ov_l
.
append
(
rnext_ov
)
...
...
@@ -769,7 +762,6 @@ class OpFromGraph(Op, HasInnerGraph):
self
.
_lop_op
=
grad_overrides
self
.
_lop_op_is_cached
=
False
self
.
_lop_type
=
"grad"
self
.
_lop_is_default
=
grad_overrides
==
"default"
def
set_lop_overrides
(
self
,
lop_overrides
):
"""
...
...
@@ -780,7 +772,6 @@ class OpFromGraph(Op, HasInnerGraph):
self
.
_lop_op
=
lop_overrides
self
.
_lop_op_is_cached
=
False
self
.
_lop_type
=
"lop"
self
.
_lop_is_default
=
lop_overrides
==
"default"
def
set_rop_overrides
(
self
,
rop_overrides
):
"""
...
...
@@ -790,7 +781,6 @@ class OpFromGraph(Op, HasInnerGraph):
"""
self
.
_rop_op
=
rop_overrides
self
.
_rop_op_is_cached
=
False
self
.
_rop_is_default
=
rop_overrides
==
"default"
def
L_op
(
self
,
inputs
,
outputs
,
output_grads
):
if
not
self
.
_lop_op_is_cached
:
...
...
tests/compile/test_builders.py
浏览文件 @
ab304cb9
...
...
@@ -11,7 +11,7 @@ from pytensor.configdefaults import config
from
pytensor.gradient
import
DisconnectedType
,
Rop
,
disconnected_type
,
grad
from
pytensor.graph.basic
import
equal_computations
from
pytensor.graph.fg
import
FunctionGraph
from
pytensor.graph.null_type
import
NullType
from
pytensor.graph.null_type
import
NullType
,
null_type
from
pytensor.graph.rewriting.utils
import
rewrite_graph
from
pytensor.graph.utils
import
MissingInputError
from
pytensor.printing
import
debugprint
...
...
@@ -93,6 +93,20 @@ class TestOpFromGraph(unittest_tools.InferShapeTester):
assert
res
.
shape
==
(
2
,
5
)
assert
np
.
all
(
180.0
==
res
)
def
test_overrides_deprecated_api
(
self
):
inp
=
scalar
(
"x"
)
out
=
inp
+
1
for
kwarg
in
(
"lop_overrides"
,
"grad_overrides"
,
"rop_overrides"
):
with
pytest
.
raises
(
ValueError
,
match
=
"'default' is no longer a valid value for overrides"
):
OpFromGraph
([
inp
],
[
out
],
**
{
kwarg
:
"default"
})
with
pytest
.
raises
(
TypeError
,
match
=
"Variables are no longer valid types for overrides"
):
OpFromGraph
([
inp
],
[
out
],
**
{
kwarg
:
null_type
()})
@pytest.mark.parametrize
(
"cls_ofg"
,
[
OpFromGraph
,
partial
(
OpFromGraph
,
inline
=
True
)]
)
...
...
@@ -211,9 +225,7 @@ class TestOpFromGraph(unittest_tools.InferShapeTester):
w
,
b
=
vectors
(
"wb"
)
# we make the 3rd gradient default (no override)
with
pytest
.
warns
(
FutureWarning
,
match
=
"grad_overrides is deprecated"
):
op_linear
=
cls_ofg
(
[
x
,
w
,
b
],
[
x
*
w
+
b
],
grad_overrides
=
[
go1
,
go2
,
"default"
]
)
op_linear
=
cls_ofg
([
x
,
w
,
b
],
[
x
*
w
+
b
],
grad_overrides
=
[
go1
,
go2
,
None
])
xx
,
ww
,
bb
=
vector
(
"xx"
),
vector
(
"yy"
),
vector
(
"bb"
)
zz
=
pt_sum
(
op_linear
(
xx
,
ww
,
bb
))
dx
,
dw
,
db
=
grad
(
zz
,
[
xx
,
ww
,
bb
])
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论