Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
97317a50
提交
97317a50
authored
3月 20, 2024
作者:
Ricardo Vieira
提交者:
Ricardo Vieira
3月 22, 2024
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Allow defining an OpFromGraph from constant and shared inputs.
Also adds a strict flag
上级
339aab4d
隐藏空白字符变更
内嵌
并排
正在显示
2 个修改的文件
包含
57 行增加
和
37 行删除
+57
-37
builders.py
pytensor/compile/builders.py
+31
-30
test_builders.py
tests/compile/test_builders.py
+26
-7
没有找到文件。
pytensor/compile/builders.py
浏览文件 @
97317a50
...
...
@@ -92,38 +92,29 @@ def construct_nominal_fgraph(
dict
[
Variable
,
Variable
],
]:
"""Construct an inner-`FunctionGraph` with ordered nominal inputs."""
dummy_inputs
=
[]
for
n
,
inp
in
enumerate
(
inputs
):
if
(
not
isinstance
(
inp
,
Variable
)
or
isinstance
(
inp
,
Constant
)
or
isinstance
(
inp
,
SharedVariable
)
):
raise
TypeError
(
f
"Inputs and outputs must be non-Constant/shared Variable instances; got {inp}"
)
dummy_inputs
.
append
(
inp
.
type
())
implicit_shared_inputs
=
[]
dummy_
shared_inputs
=
[
]
shared_inputs
=
[]
dummy_
inputs
=
[
inp
.
type
()
for
inp
in
inputs
]
dummy_implicit_
shared_inputs
=
[]
for
var
in
graph_inputs
(
outputs
,
inputs
):
if
var
in
inputs
:
continue
if
isinstance
(
var
,
SharedVariable
):
#
To correctly support shared variables the inner-graph should
# not see them; otherwise, there will be problems with
# gradients.
# That's why we collect the shared variables and replace them
# with dummies.
shared_inputs
.
append
(
var
)
dummy_shared_inputs
.
append
(
var
.
type
())
elif
var
not
in
inputs
and
not
isinstance
(
var
,
Constant
):
raise
MissingInputError
(
f
"OpFromGraph is missing an input: {var}"
)
replacements
=
dict
(
zip
(
inputs
+
shared_inputs
,
dummy_inputs
+
dummy_shared_inputs
)
)
#
We allow shared inputs to be added automatically to the graph
implicit_shared_inputs
.
append
(
var
)
dummy_implicit_shared_inputs
.
append
(
var
.
type
())
elif
not
isinstance
(
var
,
Constant
):
raise
MissingInputError
(
f
"NominalGraph is missing an input: {var}"
)
replacements
=
dict
(
zip
(
inputs
+
implicit_shared_inputs
,
dummy_inputs
+
dummy_implicit_shared_inputs
)
)
new
=
rebuild_collect_shared
(
cast
(
Sequence
[
Variable
],
outputs
),
inputs
=
inputs
+
shared_inputs
,
inputs
=
inputs
+
implicit_
shared_inputs
,
replace
=
replacements
,
copy_inputs_over
=
False
,
)
...
...
@@ -133,7 +124,7 @@ def construct_nominal_fgraph(
(
clone_d
,
update_d
,
update_expr
,
new_shared_inputs
),
)
=
new
assert
len
(
local_inputs
)
==
len
(
inputs
)
+
len
(
shared_inputs
)
assert
len
(
local_inputs
)
==
len
(
inputs
)
+
len
(
implicit_
shared_inputs
)
assert
len
(
local_outputs
)
==
len
(
outputs
)
assert
not
update_d
assert
not
update_expr
...
...
@@ -155,7 +146,7 @@ def construct_nominal_fgraph(
fgraph
.
clients
.
pop
(
inp
,
None
)
fgraph
.
add_input
(
nom_inp
)
return
fgraph
,
shared_inputs
,
update_d
,
update_expr
return
fgraph
,
implicit_
shared_inputs
,
update_d
,
update_expr
class
OpFromGraph
(
Op
,
HasInnerGraph
):
...
...
@@ -177,8 +168,6 @@ class OpFromGraph(Op, HasInnerGraph):
- grad() make it support DisconnectedType and the new interface
- add support for NullType and DisconnectedType when R_op supports them
- check how it works with updates.
- add test with constant as input or inside the inner graph.
- Add support for the GPU? Probably just need an opt to remove transfer
- Add support to pickle this Op.
- Add support/test with random generator
- Add optimization to removing unused inputs/outputs
...
...
@@ -310,11 +299,13 @@ class OpFromGraph(Op, HasInnerGraph):
self
,
inputs
:
list
[
Variable
],
outputs
:
list
[
Variable
],
*
,
inline
:
bool
=
False
,
lop_overrides
:
str
=
"default"
,
grad_overrides
:
str
=
"default"
,
rop_overrides
:
str
=
"default"
,
connection_pattern
:
Optional
[
list
[
list
[
bool
]]]
=
None
,
strict
:
bool
=
False
,
name
:
Optional
[
str
]
=
None
,
**
kwargs
,
):
...
...
@@ -399,6 +390,10 @@ class OpFromGraph(Op, HasInnerGraph):
must be equal to number of outputs. connection_pattern If not
``None``, this will be used as the connection_pattern for this
:class:`Op`.
strict: bool, default False
If true, it raises when any variables needed to compute the inner graph
are not provided as explici inputs. This can only happen for graphs with
shared variables.
name
A name for debugging purposes.
kwargs
...
...
@@ -424,6 +419,12 @@ class OpFromGraph(Op, HasInnerGraph):
inputs
,
outputs
)
if
strict
and
self
.
shared_inputs
:
raise
ValueError
(
"All variables needed to compute inner-graph must be provided as inputs under strict=True. "
f
"The inner-graph implicitly depends on the following shared variables {self.shared_inputs}"
)
self
.
kwargs
=
kwargs
self
.
input_types
=
[
inp
.
type
for
inp
in
inputs
]
self
.
output_types
=
[
out
.
type
for
out
in
outputs
]
...
...
tests/compile/test_builders.py
浏览文件 @
97317a50
...
...
@@ -15,7 +15,7 @@ from pytensor.graph.null_type import NullType
from
pytensor.graph.rewriting.utils
import
rewrite_graph
from
pytensor.graph.utils
import
MissingInputError
from
pytensor.printing
import
debugprint
from
pytensor.tensor.basic
import
as_tensor
from
pytensor.tensor.basic
import
constant
from
pytensor.tensor.math
import
dot
,
exp
,
sigmoid
from
pytensor.tensor.math
import
round
as
pt_round
from
pytensor.tensor.math
import
sum
as
pt_sum
...
...
@@ -43,12 +43,6 @@ class TestOpFromGraph(unittest_tools.InferShapeTester):
with
pytest
.
raises
(
TypeError
):
OpFromGraph
([
1
],
[
1
])
with
pytest
.
raises
(
TypeError
):
OpFromGraph
([
x
,
as_tensor
(
1
)],
[
x
])
with
pytest
.
raises
(
TypeError
):
OpFromGraph
([
shared
(
1
)],
[
1
])
with
pytest
.
raises
(
NotImplementedError
):
OpFromGraph
([
x
],
[
x
],
updates
=
{})
...
...
@@ -559,6 +553,31 @@ class TestOpFromGraph(unittest_tools.InferShapeTester):
# The original `op.fgraph` outputs should stay the same, though
assert
equal_computations
(
op
.
inner_outputs
,
[
x
**
2
/
x
],
op
.
inner_inputs
,
[
x
])
def
test_explicit_input_from_constant
(
self
):
x
=
pt
.
dscalar
(
"x"
)
y
=
constant
(
1.0
,
name
=
"y"
)
test_ofg
=
OpFromGraph
([
x
,
y
],
[
x
+
y
])
out
=
test_ofg
(
x
,
y
)
assert
out
.
eval
({
x
:
5
})
==
6
def
test_explicit_input_from_shared
(
self
):
x
=
pt
.
dscalar
(
"x"
)
y
=
shared
(
1.0
,
name
=
"y"
)
with
pytest
.
raises
(
ValueError
,
match
=
r"The inner-graph implicitly depends on the following shared variables \[y\]"
,
):
OpFromGraph
([
x
],
[
x
+
y
],
strict
=
True
)
test_ofg
=
OpFromGraph
([
x
,
y
],
[
x
+
y
],
strict
=
True
)
out
=
test_ofg
(
x
,
y
)
assert
out
.
eval
({
x
:
5
})
==
6
y
.
set_value
(
2.0
)
assert
out
.
eval
({
x
:
6
})
@config.change_flags
(
floatX
=
"float64"
)
def
test_debugprint
():
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论