Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
394b355b
提交
394b355b
authored
4月 20, 2024
作者:
Ricardo Vieira
提交者:
Ricardo Vieira
5月 29, 2024
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Fix OpFromGraph L_op with related and/or disconnected outputs
上级
2eb8fca2
显示空白字符变更
内嵌
并排
正在显示
2 个修改的文件
包含
113 行增加
和
18 行删除
+113
-18
builders.py
pytensor/compile/builders.py
+69
-16
test_builders.py
tests/compile/test_builders.py
+44
-2
没有找到文件。
pytensor/compile/builders.py
浏览文件 @
394b355b
...
@@ -417,7 +417,10 @@ class OpFromGraph(Op, HasInnerGraph):
...
@@ -417,7 +417,10 @@ class OpFromGraph(Op, HasInnerGraph):
FutureWarning
,
FutureWarning
,
)
)
self
.
_lop_op_interface
=
False
self
.
_lop_op_interface
=
False
self
.
_lop_op_cache
:
Callable
|
None
=
None
# Dictionary where we cache OpFromGraph that represent the L_op
# A distinct OpFromGraph is needed to represent each pattern of output_grads connection
# It also returns a tuple that indicates which input_gradients are disconnected
self
.
_lop_op_cache
:
dict
[
tuple
[
bool
,
...
],
Callable
]
=
{}
self
.
_rop_op_cache
:
Callable
|
None
=
None
self
.
_rop_op_cache
:
Callable
|
None
=
None
self
.
_connection_pattern
=
connection_pattern
self
.
_connection_pattern
=
connection_pattern
...
@@ -480,24 +483,30 @@ class OpFromGraph(Op, HasInnerGraph):
...
@@ -480,24 +483,30 @@ class OpFromGraph(Op, HasInnerGraph):
return
outputs
return
outputs
@config.change_flags
(
compute_test_value
=
"off"
)
@config.change_flags
(
compute_test_value
=
"off"
)
def
_build_and_cache_lop_op
(
self
)
->
Callable
:
def
_build_and_cache_lop_op
(
"""converts lop_overrides (or grad_overrides) from user supplied form to type(self) instance.
self
,
disconnected_output_grads
:
tuple
[
bool
,
...
]
)
->
Callable
:
"""converts lop_overrides (or grad_overrides) from user supplied form to type(self) instance,
specialized for the pattern of disconnected_output_grads
Results are cached in self._lop_op_cache
Results are cached in self._lop_op_cache
"""
"""
if
self
.
_lop_op_cache
is
not
None
:
try
:
return
self
.
_lop_op_cache
return
self
.
_lop_op_cache
[
disconnected_output_grads
]
except
KeyError
:
pass
inner_inputs
=
self
.
inner_inputs
inner_inputs
=
self
.
inner_inputs
inner_outputs
=
self
.
inner_outputs
inner_outputs
=
self
.
inner_outputs
nin
=
len
(
inner_inputs
)
nin
=
len
(
inner_inputs
)
nout
=
len
(
inner_outputs
)
lop_overrides
=
(
lop_overrides
=
(
self
.
lop_overrides
if
self
.
_lop_op_interface
else
self
.
grad_overrides
self
.
lop_overrides
if
self
.
_lop_op_interface
else
self
.
grad_overrides
)
)
if
isinstance
(
lop_overrides
,
OpFromGraph
):
if
isinstance
(
lop_overrides
,
OpFromGraph
):
if
self
.
_lop_op_interface
:
if
self
.
_lop_op_interface
:
self
.
_lop_op_cache
=
lop_overrides
self
.
_lop_op_cache
[
disconnected_output_grads
]
=
lop_overrides
lop_overrides
.
kwargs
[
"on_unused_input"
]
=
"ignore"
lop_overrides
.
kwargs
[
"on_unused_input"
]
=
"ignore"
return
lop_overrides
return
lop_overrides
...
@@ -507,20 +516,42 @@ class OpFromGraph(Op, HasInnerGraph):
...
@@ -507,20 +516,42 @@ class OpFromGraph(Op, HasInnerGraph):
def
lop_overrides
(
inps
,
grads
):
def
lop_overrides
(
inps
,
grads
):
return
self
.
grad_overrides
(
*
inps
,
*
grads
)
return
self
.
grad_overrides
(
*
inps
,
*
grads
)
output_grads
=
[
out_t
()
for
out_t
in
self
.
output_types
]
# We try to compute the gradient with respect to connected outputs only
connected_inner_outputs
=
[
# We add an identity operation(copy) so that we don't override indirect
# gradient contributions to an inner output coming from other inner outputs
inner_out
.
copy
()
for
inner_out
,
disconnected
in
zip
(
inner_outputs
,
disconnected_output_grads
,
strict
=
True
)
if
not
disconnected
]
connected_output_grads
=
[
out_t
()
for
out_t
,
disconnected
in
zip
(
self
.
output_types
,
disconnected_output_grads
,
strict
=
True
)
if
not
disconnected
]
fn_grad
=
partial
(
fn_grad
=
partial
(
grad
,
grad
,
cost
=
None
,
cost
=
None
,
disconnected_inputs
=
"ignore"
,
disconnected_inputs
=
"ignore"
,
return_disconnected
=
"disconnected"
,
return_disconnected
=
"disconnected"
,
null_gradients
=
"return"
,
null_gradients
=
"return"
,
known_grads
=
dict
(
zip
(
inner_outputs
,
output_grads
)),
known_grads
=
dict
(
zip
(
connected_inner_outputs
,
connected_output_grads
,
strict
=
True
)
),
)
)
if
self
.
_lop_op_interface
:
if
self
.
_lop_op_interface
:
callable_args
=
(
inner_inputs
,
inner_outputs
,
output_grads
)
callable_args
=
(
inner_inputs
,
connected_inner_outputs
,
connected_output_grads
,
)
else
:
else
:
callable_args
=
(
inner_inputs
,
output_grads
)
callable_args
=
(
inner_inputs
,
connected_
output_grads
)
# we need to convert _lop_op into an OfG instance
# we need to convert _lop_op into an OfG instance
if
lop_overrides
is
None
:
if
lop_overrides
is
None
:
...
@@ -544,14 +575,15 @@ class OpFromGraph(Op, HasInnerGraph):
...
@@ -544,14 +575,15 @@ class OpFromGraph(Op, HasInnerGraph):
else
:
else
:
input_grads
=
self
.
_call_custom_override
(
lop_overrides
,
callable_args
,
nin
)
input_grads
=
self
.
_call_custom_override
(
lop_overrides
,
callable_args
,
nin
)
# Filter out disconnected input and output gradients
# Filter out disconnected/null input generated from the inner graph grad
# We append them in the outer wrapper function below
connected_input_grads
=
[
connected_input_grads
=
[
inp_grad
inp_grad
for
inp_grad
in
input_grads
for
inp_grad
in
input_grads
if
not
isinstance
(
inp_grad
.
type
,
DisconnectedType
|
NullType
)
if
not
isinstance
(
inp_grad
.
type
,
DisconnectedType
|
NullType
)
]
]
lop_op
=
type
(
self
)(
lop_op
=
type
(
self
)(
inputs
=
inner_inputs
+
inner_outputs
+
output_grads
,
inputs
=
inner_inputs
+
connected_inner_outputs
+
connected_
output_grads
,
outputs
=
connected_input_grads
,
outputs
=
connected_input_grads
,
inline
=
self
.
is_inline
,
inline
=
self
.
is_inline
,
name
=
(
None
if
self
.
name
is
None
else
f
"{self.name}_LOp"
),
name
=
(
None
if
self
.
name
is
None
else
f
"{self.name}_LOp"
),
...
@@ -559,9 +591,27 @@ class OpFromGraph(Op, HasInnerGraph):
...
@@ -559,9 +591,27 @@ class OpFromGraph(Op, HasInnerGraph):
on_unused_input
=
"ignore"
,
on_unused_input
=
"ignore"
,
)
)
# Return a wrapper that combines connected and disconnected input gradients
# Return a wrapper that combines connected and disconnected/null input gradients
# And also filters out disconnected/null output gradients
def
wrapper
(
*
inputs
:
Variable
,
**
kwargs
)
->
list
[
Variable
]:
def
wrapper
(
*
inputs
:
Variable
,
**
kwargs
)
->
list
[
Variable
]:
connected_input_grads
=
iter
(
lop_op
(
*
inputs
,
**
kwargs
))
inputs
,
outputs
,
output_grads
=
(
inputs
[:
-
nout
*
2
],
inputs
[
-
nout
*
2
:
-
nout
],
inputs
[
-
nout
:],
)
connected_outputs
=
[
output
for
output
,
output_grad
in
zip
(
outputs
,
output_grads
,
strict
=
True
)
if
not
isinstance
(
output_grad
.
type
,
DisconnectedType
|
NullType
)
]
connected_output_grads
=
[
output_grad
for
output_grad
in
output_grads
if
not
isinstance
(
output_grad
.
type
,
DisconnectedType
)
]
connected_input_grads
=
iter
(
lop_op
(
*
inputs
,
*
connected_outputs
,
*
connected_output_grads
,
**
kwargs
)
)
return
[
return
[
input_grad
input_grad
if
isinstance
(
input_grad
.
type
,
DisconnectedType
|
NullType
)
if
isinstance
(
input_grad
.
type
,
DisconnectedType
|
NullType
)
...
@@ -569,7 +619,7 @@ class OpFromGraph(Op, HasInnerGraph):
...
@@ -569,7 +619,7 @@ class OpFromGraph(Op, HasInnerGraph):
for
input_grad
in
input_grads
for
input_grad
in
input_grads
]
]
self
.
_lop_op_cache
=
wrapper
self
.
_lop_op_cache
[
disconnected_output_grads
]
=
wrapper
return
wrapper
return
wrapper
@config.change_flags
(
compute_test_value
=
"off"
)
@config.change_flags
(
compute_test_value
=
"off"
)
...
@@ -652,7 +702,10 @@ class OpFromGraph(Op, HasInnerGraph):
...
@@ -652,7 +702,10 @@ class OpFromGraph(Op, HasInnerGraph):
return
wrapper
return
wrapper
def
L_op
(
self
,
inputs
,
outputs
,
output_grads
):
def
L_op
(
self
,
inputs
,
outputs
,
output_grads
):
lop_op
=
self
.
_build_and_cache_lop_op
()
disconnected_output_grads
=
tuple
(
isinstance
(
og
.
type
,
DisconnectedType
)
for
og
in
output_grads
)
lop_op
=
self
.
_build_and_cache_lop_op
(
disconnected_output_grads
)
return
lop_op
(
*
inputs
,
*
outputs
,
*
output_grads
,
return_list
=
True
)
return
lop_op
(
*
inputs
,
*
outputs
,
*
output_grads
,
return_list
=
True
)
def
R_op
(
self
,
inputs
,
eval_points
):
def
R_op
(
self
,
inputs
,
eval_points
):
...
...
tests/compile/test_builders.py
浏览文件 @
394b355b
...
@@ -8,7 +8,13 @@ from pytensor.compile import shared
...
@@ -8,7 +8,13 @@ from pytensor.compile import shared
from
pytensor.compile.builders
import
OpFromGraph
from
pytensor.compile.builders
import
OpFromGraph
from
pytensor.compile.function
import
function
from
pytensor.compile.function
import
function
from
pytensor.configdefaults
import
config
from
pytensor.configdefaults
import
config
from
pytensor.gradient
import
DisconnectedType
,
Rop
,
disconnected_type
,
grad
from
pytensor.gradient
import
(
DisconnectedType
,
Rop
,
disconnected_type
,
grad
,
verify_grad
,
)
from
pytensor.graph.basic
import
equal_computations
from
pytensor.graph.basic
import
equal_computations
from
pytensor.graph.fg
import
FunctionGraph
from
pytensor.graph.fg
import
FunctionGraph
from
pytensor.graph.null_type
import
NullType
,
null_type
from
pytensor.graph.null_type
import
NullType
,
null_type
...
@@ -22,7 +28,15 @@ from pytensor.tensor.math import sum as pt_sum
...
@@ -22,7 +28,15 @@ from pytensor.tensor.math import sum as pt_sum
from
pytensor.tensor.random.utils
import
RandomStream
from
pytensor.tensor.random.utils
import
RandomStream
from
pytensor.tensor.rewriting.shape
import
ShapeOptimizer
from
pytensor.tensor.rewriting.shape
import
ShapeOptimizer
from
pytensor.tensor.shape
import
specify_shape
from
pytensor.tensor.shape
import
specify_shape
from
pytensor.tensor.type
import
TensorType
,
matrices
,
matrix
,
scalar
,
vector
,
vectors
from
pytensor.tensor.type
import
(
TensorType
,
dscalars
,
matrices
,
matrix
,
scalar
,
vector
,
vectors
,
)
from
tests
import
unittest_tools
from
tests
import
unittest_tools
from
tests.graph.utils
import
MyVariable
from
tests.graph.utils
import
MyVariable
...
@@ -638,6 +652,34 @@ class TestOpFromGraph(unittest_tools.InferShapeTester):
...
@@ -638,6 +652,34 @@ class TestOpFromGraph(unittest_tools.InferShapeTester):
out
=
test_ofg
(
y
,
y
)
out
=
test_ofg
(
y
,
y
)
assert
out
.
eval
()
==
4
assert
out
.
eval
()
==
4
def
test_L_op_disconnected_output_grad
(
self
):
x
,
y
=
dscalars
(
"x"
,
"y"
)
rng
=
np
.
random
.
default_rng
(
594
)
point
=
list
(
rng
.
normal
(
size
=
(
2
,)))
out1
=
x
+
y
out2
=
x
*
y
out3
=
out1
*
out2
# Create dependency between outputs
op
=
OpFromGraph
([
x
,
y
],
[
out1
,
out2
,
out3
])
verify_grad
(
lambda
x
,
y
:
pt
.
add
(
*
op
(
x
,
y
)),
point
,
rng
=
rng
)
verify_grad
(
lambda
x
,
y
:
pt
.
add
(
*
op
(
x
,
y
)[:
-
1
]),
point
,
rng
=
rng
)
verify_grad
(
lambda
x
,
y
:
pt
.
add
(
*
op
(
x
,
y
)[
1
:]),
point
,
rng
=
rng
)
verify_grad
(
lambda
x
,
y
:
pt
.
add
(
*
op
(
x
,
y
)[::
2
]),
point
,
rng
=
rng
)
verify_grad
(
lambda
x
,
y
:
op
(
x
,
y
)[
0
],
point
,
rng
=
rng
)
verify_grad
(
lambda
x
,
y
:
op
(
x
,
y
)[
1
],
point
,
rng
=
rng
)
verify_grad
(
lambda
x
,
y
:
op
(
x
,
y
)[
2
],
point
,
rng
=
rng
)
# Test disconnected graphs are handled correctly
op
=
OpFromGraph
([
x
,
y
],
[
x
**
2
,
y
**
3
])
with
pytest
.
warns
(
UserWarning
):
grad_x_wrt_y
=
grad
(
op
(
x
,
y
)[
0
],
wrt
=
y
,
return_disconnected
=
"disconnected"
,
disconnected_inputs
=
"warn"
,
)
assert
isinstance
(
grad_x_wrt_y
.
type
,
DisconnectedType
)
def
test_repeated_inputs
(
self
):
def
test_repeated_inputs
(
self
):
x
=
pt
.
dscalar
(
"x"
)
x
=
pt
.
dscalar
(
"x"
)
y
=
pt
.
dscalar
(
"y"
)
y
=
pt
.
dscalar
(
"y"
)
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论