Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
9ba6d99f
提交
9ba6d99f
authored
5月 30, 2024
作者:
Ricardo Vieira
提交者:
Ricardo Vieira
7月 08, 2024
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Replace str "output" by a dummy Op in the clients of the FunctionGraph
上级
7f623fef
全部展开
隐藏空白字符变更
内嵌
并排
正在显示
18 个修改的文件
包含
84 行增加
和
91 行删除
+84
-91
debugmode.py
pytensor/compile/debugmode.py
+8
-9
types.py
pytensor/compile/function/types.py
+15
-18
profiling.py
pytensor/compile/profiling.py
+4
-7
destroyhandler.py
pytensor/graph/destroyhandler.py
+2
-3
fg.py
pytensor/graph/fg.py
+0
-0
basic.py
pytensor/graph/rewriting/basic.py
+3
-5
utils.py
pytensor/graph/rewriting/utils.py
+4
-6
basic.py
pytensor/link/c/basic.py
+1
-3
vm.py
pytensor/link/vm.py
+1
-2
printing.py
pytensor/printing.py
+9
-4
rewriting.py
pytensor/scan/rewriting.py
+4
-4
basic.py
pytensor/tensor/basic.py
+2
-2
basic.py
pytensor/tensor/random/rewriting/basic.py
+5
-8
elemwise.py
pytensor/tensor/rewriting/elemwise.py
+3
-4
linalg.py
pytensor/tensor/rewriting/linalg.py
+0
-2
math.py
pytensor/tensor/rewriting/math.py
+2
-5
shape.py
pytensor/tensor/rewriting/shape.py
+6
-3
test_fg.py
tests/graph/test_fg.py
+15
-6
没有找到文件。
pytensor/compile/debugmode.py
浏览文件 @
9ba6d99f
...
...
@@ -30,6 +30,7 @@ from pytensor.configdefaults import config
from
pytensor.graph.basic
import
Variable
,
io_toposort
from
pytensor.graph.destroyhandler
import
DestroyHandler
from
pytensor.graph.features
import
AlreadyThere
,
BadOptimization
from
pytensor.graph.fg
import
Output
from
pytensor.graph.op
import
HasInnerGraph
,
Op
from
pytensor.graph.utils
import
InconsistencyError
,
MethodNotDefined
from
pytensor.link.basic
import
Container
,
LocalLinker
...
...
@@ -628,7 +629,9 @@ def _is_used_in_graph(fgraph, var):
True if `var` is used by another node in the graph.
"""
return
not
(
fgraph
.
clients
[
var
]
==
[(
"output"
,
1
)]
or
fgraph
.
clients
[
var
]
==
[])
return
any
(
client
for
client
,
_
in
fgraph
.
clients
[
var
]
if
not
isinstance
(
client
.
op
,
Output
)
)
def
_check_strides_match
(
a
,
b
,
warn_err
,
op
):
...
...
@@ -977,7 +980,7 @@ def _check_preallocated_output(
# disable memory checks in that mode, since they were already run.
try
:
changed_inner_mode
=
False
if
isinstance
(
getattr
(
node
,
"op"
,
None
)
,
HasInnerGraph
):
if
isinstance
(
node
.
op
,
HasInnerGraph
):
fn
=
node
.
op
.
fn
if
not
(
fn
and
hasattr
(
fn
,
"maker"
)
and
hasattr
(
fn
.
maker
,
"mode"
)):
_logger
.
warning
(
f
"Expected pytensor function not found in {node.op}.fn"
)
...
...
@@ -1132,18 +1135,14 @@ class _FunctionGraphEvent:
def
__init__
(
self
,
kind
,
node
,
idx
=
None
,
reason
=
None
):
self
.
kind
=
kind
if
node
==
"output"
:
self
.
node
=
"output"
self
.
op
=
"output"
else
:
self
.
node
=
node
self
.
op
=
node
.
op
self
.
node
=
node
self
.
op
=
node
.
op
self
.
idx
=
idx
self
.
reason
=
str
(
reason
)
def
__str__
(
self
):
if
self
.
kind
==
"change"
:
if
self
.
op
!=
"output"
:
if
not
isinstance
(
self
.
op
,
Output
)
:
msg
=
str
(
len
(
self
.
node
.
inputs
))
else
:
msg
=
""
...
...
pytensor/compile/function/types.py
浏览文件 @
9ba6d99f
...
...
@@ -78,8 +78,6 @@ def view_tree_set(fgraph, v, treeset):
"""
treeset
.
add
(
v
)
for
cl
,
v_input_pos_to_cl
in
fgraph
.
clients
[
v
]:
if
cl
==
"output"
:
continue
vmap
=
cl
.
op
.
view_map
dmap
=
cl
.
op
.
destroy_map
for
opos
,
iposlist
in
chain
(
vmap
.
items
(),
dmap
.
items
()):
...
...
@@ -1202,8 +1200,11 @@ def insert_deepcopy(fgraph, wrapped_inputs, wrapped_outputs):
has_destroyers_attr
=
hasattr
(
fgraph
,
"has_destroyers"
)
for
i
in
range
(
len
(
fgraph
.
outputs
)):
original_out
=
fgraph
.
outputs
[
i
]
output_client
=
fgraph
.
get_output_client
(
i
)
views_of_output_i
=
set
()
view_tree_set
(
fgraph
,
alias_root
(
fgraph
.
outputs
[
i
]
),
views_of_output_i
)
view_tree_set
(
fgraph
,
alias_root
(
original_out
),
views_of_output_i
)
copied
=
False
# do not allow outputs to be aliased
for
j
in
range
(
i
+
1
,
len
(
fgraph
.
outputs
)):
...
...
@@ -1212,16 +1213,16 @@ def insert_deepcopy(fgraph, wrapped_inputs, wrapped_outputs):
if
fgraph
.
outputs
[
j
]
in
views_of_output_i
:
if
wrapped_outputs
[
i
]
.
borrow
and
wrapped_outputs
[
j
]
.
borrow
:
fgraph
.
change_node_input
(
"output"
,
i
,
view_op
(
fgraph
.
outputs
[
i
]
),
reason
=
reason
*
output_client
,
view_op
(
original_out
),
reason
=
reason
)
else
:
fgraph
.
change_node_input
(
"output"
,
i
,
deep_copy_op
(
fgraph
.
outputs
[
i
]
),
reason
=
reason
*
output_client
,
deep_copy_op
(
original_out
),
reason
=
reason
)
copied
=
True
break
if
not
copied
:
if
not
copied
:
# no-break
for
input_j
in
all_graph_inputs
:
# do not allow outputs to be aliased to an inputs (j), unless
# a) that j'th input has been 'destroyed' by
...
...
@@ -1239,33 +1240,29 @@ def insert_deepcopy(fgraph, wrapped_inputs, wrapped_outputs):
j
=
fgraph
.
inputs
.
index
(
input_j
)
if
wrapped_outputs
[
i
]
.
borrow
and
wrapped_inputs
[
j
]
.
borrow
:
fgraph
.
change_node_input
(
"output"
,
i
,
view_op
(
fgraph
.
outputs
[
i
]),
*
output_client
,
view_op
(
original_out
),
reason
=
reason
,
)
break
else
:
fgraph
.
change_node_input
(
"output"
,
i
,
deep_copy_op
(
fgraph
.
outputs
[
i
]),
*
output_client
,
deep_copy_op
(
original_out
),
reason
=
reason
,
)
break
elif
wrapped_outputs
[
i
]
.
borrow
:
fgraph
.
change_node_input
(
"output"
,
i
,
view_op
(
fgraph
.
outputs
[
i
]),
*
output_client
,
view_op
(
original_out
),
reason
=
reason
,
)
break
else
:
fgraph
.
change_node_input
(
"output"
,
i
,
deep_copy_op
(
fgraph
.
outputs
[
i
]),
*
output_client
,
deep_copy_op
(
original_out
),
reason
=
reason
,
)
break
...
...
pytensor/compile/profiling.py
浏览文件 @
9ba6d99f
...
...
@@ -16,20 +16,17 @@ import sys
import
time
from
collections
import
Counter
,
defaultdict
from
contextlib
import
contextmanager
from
typing
import
TYPE_CHECKING
,
Any
from
typing
import
Any
import
numpy
as
np
import
pytensor
from
pytensor.configdefaults
import
config
from
pytensor.graph.basic
import
Apply
,
Constant
,
Variable
from
pytensor.graph.fg
import
FunctionGraph
,
Output
from
pytensor.link.utils
import
get_destroy_dependencies
if
TYPE_CHECKING
:
from
pytensor.graph.fg
import
FunctionGraph
@contextmanager
def
extended_open
(
filename
,
mode
=
"r"
):
if
filename
==
"<stdout>"
:
...
...
@@ -1038,7 +1035,7 @@ class ProfileStats:
executable_nodes
=
set
()
for
var
in
fgraph
.
inputs
:
for
c
,
_
in
fgraph
.
clients
[
var
]:
if
c
!=
"output"
:
if
not
isinstance
(
c
.
op
,
Output
)
:
deps
=
c
.
inputs
+
destroy_dependencies
[
c
]
if
all
(
compute_map
[
v
][
0
]
for
v
in
deps
):
executable_nodes
.
add
(
c
)
...
...
@@ -1166,7 +1163,7 @@ class ProfileStats:
for
var
in
node
.
outputs
:
for
c
,
_
in
fgraph
.
clients
[
var
]:
if
c
!=
"output"
:
if
not
isinstance
(
c
.
op
,
Output
)
:
deps
=
c
.
inputs
+
destroy_dependencies
[
c
]
if
all
(
compute_map
[
v
][
0
]
for
v
in
deps
):
new_exec_nodes
.
add
(
c
)
...
...
pytensor/graph/destroyhandler.py
浏览文件 @
9ba6d99f
...
...
@@ -11,6 +11,7 @@ import pytensor
from
pytensor.configdefaults
import
config
from
pytensor.graph.basic
import
Constant
from
pytensor.graph.features
import
AlreadyThere
,
Bookkeeper
from
pytensor.graph.fg
import
Output
from
pytensor.graph.utils
import
InconsistencyError
from
pytensor.misc.ordered_set
import
OrderedSet
...
...
@@ -401,8 +402,6 @@ class DestroyHandler(Bookkeeper):
def
recursive_destroys_finder
(
protected_var
):
# protected_var is the idx'th input of app.
for
app
,
idx
in
fgraph
.
clients
[
protected_var
]:
if
app
==
"output"
:
continue
destroy_maps
=
app
.
op
.
destroy_map
.
values
()
# If True means that the apply node, destroys the protected_var.
if
idx
in
[
dmap
for
sublist
in
destroy_maps
for
dmap
in
sublist
]:
...
...
@@ -578,7 +577,7 @@ class DestroyHandler(Bookkeeper):
app.inputs[i] changed from old_r to new_r.
"""
if
app
==
"output"
:
if
isinstance
(
app
.
op
,
Output
)
:
# app == 'output' is special key that means FunctionGraph is redefining which nodes are being
# considered 'outputs' of the graph.
pass
...
...
pytensor/graph/fg.py
浏览文件 @
9ba6d99f
差异被折叠。
点击展开。
pytensor/graph/rewriting/basic.py
浏览文件 @
9ba6d99f
...
...
@@ -30,7 +30,7 @@ from pytensor.graph.basic import (
vars_between
,
)
from
pytensor.graph.features
import
AlreadyThere
,
Feature
,
NodeFinder
from
pytensor.graph.fg
import
FunctionGraph
from
pytensor.graph.fg
import
FunctionGraph
,
Output
from
pytensor.graph.op
import
Op
from
pytensor.graph.utils
import
AssocList
,
InconsistencyError
from
pytensor.misc.ordered_set
import
OrderedSet
...
...
@@ -738,7 +738,7 @@ class MergeOptimizer(GraphRewriter):
if
any
(
i
in
flatten
(
c
.
op
.
destroy_map
.
values
())
for
c
,
i
in
clients
if
c
!=
"output"
and
c
.
op
.
destroy_map
if
c
.
op
.
destroy_map
):
continue
...
...
@@ -1612,8 +1612,6 @@ class PatternNodeRewriter(NodeRewriter):
if
get_nodes
and
self
.
get_nodes
is
not
None
:
for
real_node
in
self
.
get_nodes
(
fgraph
,
node
):
if
real_node
==
"output"
:
continue
ret
=
self
.
transform
(
fgraph
,
real_node
,
get_nodes
=
False
)
if
ret
is
not
False
and
ret
is
not
None
:
return
dict
(
zip
(
real_node
.
outputs
,
ret
))
...
...
@@ -2399,7 +2397,7 @@ class EquilibriumGraphRewriter(NodeProcessingGraphRewriter):
if
self
.
tracks_on_change_inputs
:
def
chin_
(
node
,
i
,
r
,
new_r
,
reason
):
if
node
is
not
current_node
and
not
isinstance
(
node
,
str
):
if
node
is
not
current_node
and
not
isinstance
(
node
.
op
,
Output
):
q
.
append
(
node
)
chin
=
chin_
...
...
pytensor/graph/rewriting/utils.py
浏览文件 @
9ba6d99f
import
copy
from
collections.abc
import
Generator
,
Sequence
from
typing
import
TYPE_CHECKING
,
Optional
,
cast
from
typing
import
TYPE_CHECKING
,
Optional
import
pytensor
from
pytensor.graph.basic
import
(
...
...
@@ -10,7 +10,7 @@ from pytensor.graph.basic import (
graph_inputs
,
vars_between
,
)
from
pytensor.graph.fg
import
FunctionGraph
from
pytensor.graph.fg
import
FunctionGraph
,
Output
from
pytensor.graph.rewriting.db
import
RewriteDatabaseQuery
...
...
@@ -230,11 +230,9 @@ def get_clients_at_depth(
for
var
in
node
.
outputs
:
if
depth
>
0
:
for
out_node
,
_
in
fgraph
.
clients
[
var
]:
if
out_node
==
"output"
:
if
isinstance
(
out_node
.
op
,
Output
)
:
continue
yield
from
get_clients_at_depth
(
fgraph
,
cast
(
Apply
,
out_node
),
depth
-
1
)
yield
from
get_clients_at_depth
(
fgraph
,
out_node
,
depth
-
1
)
else
:
assert
var
.
owner
is
not
None
yield
var
.
owner
pytensor/link/c/basic.py
浏览文件 @
9ba6d99f
...
...
@@ -354,9 +354,7 @@ def get_c_declare(fgraph, r, name, sub):
# it means they need `r`'s dtype to be declared, so
# we have to pass `check_input=True` to `c_declare`.
if
any
(
getattr
(
c
.
op
,
"check_input"
,
config
.
check_input
)
for
(
c
,
_
)
in
fgraph
.
clients
[
r
]
if
not
isinstance
(
c
,
str
)
getattr
(
c
.
op
,
"check_input"
,
config
.
check_input
)
for
(
c
,
_
)
in
fgraph
.
clients
[
r
]
)
or
(
r
.
owner
and
getattr
(
r
.
owner
.
op
,
"check_input"
,
config
.
check_input
)):
c_declare
=
r
.
type
.
c_declare
(
name
,
sub
,
True
)
else
:
...
...
pytensor/link/vm.py
浏览文件 @
9ba6d99f
...
...
@@ -954,8 +954,7 @@ class VMLinker(LocalLinker):
if
k
.
owner
and
self
.
fgraph
.
clients
[
k
]:
ls
=
[]
for
cl
in
self
.
fgraph
.
clients
[
k
]:
if
cl
[
0
]
!=
"output"
:
ls
+=
cl
[
0
]
.
outputs
ls
+=
cl
[
0
]
.
outputs
dependencies
[
k
]
+=
ls
return
dependencies
...
...
pytensor/printing.py
浏览文件 @
9ba6d99f
...
...
@@ -437,10 +437,15 @@ N.B.:
for
out
in
inner_outputs
:
if
(
isinstance
(
getattr
(
out
.
owner
,
"op"
,
None
),
HasInnerGraph
)
or
hasattr
(
getattr
(
out
.
owner
,
"op"
,
None
),
"scalar_op"
)
and
isinstance
(
out
.
owner
.
op
.
scalar_op
,
HasInnerGraph
)
)
and
out
not
in
inner_graph_vars
:
out
.
owner
is
not
None
and
(
isinstance
(
out
.
owner
.
op
,
HasInnerGraph
)
or
isinstance
(
getattr
(
out
.
owner
.
op
,
"scalar_op"
,
None
),
HasInnerGraph
)
)
and
out
not
in
inner_graph_vars
):
inner_graph_vars
.
append
(
out
)
_debugprint
(
...
...
pytensor/scan/rewriting.py
浏览文件 @
9ba6d99f
...
...
@@ -27,7 +27,7 @@ from pytensor.graph.basic import (
)
from
pytensor.graph.destroyhandler
import
DestroyHandler
from
pytensor.graph.features
import
ReplaceValidate
from
pytensor.graph.fg
import
FunctionGraph
from
pytensor.graph.fg
import
FunctionGraph
,
Output
from
pytensor.graph.op
import
compute_test_value
from
pytensor.graph.replace
import
clone_replace
from
pytensor.graph.rewriting.basic
import
(
...
...
@@ -1303,7 +1303,7 @@ def scan_save_mem(fgraph, node):
for
cl
,
_
in
fgraph
.
clients
[
out
]:
# 2.1 outputs of the function
# => output needs all its intermediate values
if
isinstance
(
cl
,
str
):
if
isinstance
(
cl
.
op
,
Output
):
# if the node is actually an output, then
# we need to store the entire thing
global_nsteps
=
None
...
...
@@ -1412,7 +1412,7 @@ def scan_save_mem(fgraph, node):
for
i
,
out
in
enumerate
(
node
.
outputs
[:
c_outs
]):
# look at all its clients
for
cl
,
_
in
fgraph
.
clients
[
out
]:
if
isinstance
(
cl
,
str
):
if
isinstance
(
cl
.
op
,
Output
):
store_steps
[
i
]
=
0
break
elif
not
isinstance
(
cl
.
op
,
Subtensor
):
...
...
@@ -2309,7 +2309,7 @@ def scan_push_out_dot1(fgraph, node):
and
isinstance
(
out
.
owner
.
op
.
scalar_op
,
ps
.
Add
)
and
inp
in
out
.
owner
.
inputs
and
len
(
fgraph
.
clients
[
outer_out
])
==
1
and
not
isinstance
(
fgraph
.
clients
[
outer_out
][
0
][
0
],
str
)
and
not
isinstance
(
fgraph
.
clients
[
outer_out
][
0
][
0
],
Output
)
and
isinstance
(
fgraph
.
clients
[
outer_out
][
0
][
0
]
.
op
,
Subtensor
)
and
fgraph
.
clients
[
outer_out
][
0
][
0
]
.
op
.
idx_list
==
(
-
1
,)
):
...
...
pytensor/tensor/basic.py
浏览文件 @
9ba6d99f
...
...
@@ -24,7 +24,7 @@ from pytensor import scalar as ps
from
pytensor.gradient
import
DisconnectedType
,
grad_undefined
from
pytensor.graph
import
RewriteDatabaseQuery
from
pytensor.graph.basic
import
Apply
,
Constant
,
Variable
,
equal_computations
from
pytensor.graph.fg
import
FunctionGraph
from
pytensor.graph.fg
import
FunctionGraph
,
Output
from
pytensor.graph.op
import
Op
from
pytensor.graph.replace
import
_vectorize_node
from
pytensor.graph.rewriting.db
import
EquilibriumDB
...
...
@@ -1681,7 +1681,7 @@ class Alloc(COp):
return
False
for
client
,
idx
in
clients
:
if
client
==
"output"
:
if
isinstance
(
client
.
op
,
Output
)
:
# If the output is a constant, it will have to be deepcopied
# each time the function is called. So we do not fold.
return
False
...
...
pytensor/tensor/random/rewriting/basic.py
浏览文件 @
9ba6d99f
...
...
@@ -31,15 +31,12 @@ def is_rv_used_in_graph(base_rv, node, fgraph):
TODO: We should apply all the shape rewrites before these rewrites, since
that would properly remove the unnecessary dependencies on `base_rv` (when
possible).
"""
def
_node_check
(
n
,
i
):
if
n
==
"output"
:
n
=
fgraph
.
outputs
[
i
]
.
owner
return
n
==
node
or
isinstance
(
n
.
op
,
Shape
|
Shape_i
)
return
not
all
(
_node_check
(
n
,
i
)
for
n
,
i
in
fgraph
.
clients
.
get
(
base_rv
,
()))
return
any
(
n
for
n
,
i
in
fgraph
.
clients
.
get
(
base_rv
,
())
if
not
(
n
is
node
or
isinstance
(
n
.
op
,
Shape
|
Shape_i
))
)
@node_rewriter
([
RandomVariable
],
inplace
=
True
)
...
...
pytensor/tensor/rewriting/elemwise.py
浏览文件 @
9ba6d99f
...
...
@@ -14,7 +14,7 @@ from pytensor.configdefaults import config
from
pytensor.graph
import
FunctionGraph
from
pytensor.graph.basic
import
Apply
,
Constant
,
Variable
,
ancestors
,
io_toposort
from
pytensor.graph.features
import
ReplaceValidate
from
pytensor.graph.fg
import
ApplyOr
Output
from
pytensor.graph.fg
import
Output
from
pytensor.graph.rewriting.basic
import
(
EquilibriumGraphRewriter
,
GraphRewriter
,
...
...
@@ -688,7 +688,7 @@ class FusionOptimizer(GraphRewriter):
"""
FUSEABLE_MAPPING
=
defaultdict
[
Variable
,
list
[
Apply
]]
UNFUSEABLE_MAPPING
=
defaultdict
[
Variable
,
set
[
Apply
OrOutput
]]
UNFUSEABLE_MAPPING
=
defaultdict
[
Variable
,
set
[
Apply
]]
def
initialize_fuseable_mappings
(
*
,
fg
:
FunctionGraph
...
...
@@ -727,7 +727,6 @@ class FusionOptimizer(GraphRewriter):
for
client
,
_
in
clients
:
if
(
out_maybe_fuseable
and
not
isinstance
(
client
,
str
)
# "output"
and
isinstance
(
client
.
op
,
Elemwise
)
# and not isinstance(client.op.scalar_op, ps.Composite)
and
len
(
client
.
outputs
)
==
1
...
...
@@ -841,7 +840,7 @@ class FusionOptimizer(GraphRewriter):
implied_unfuseable_clients
=
{
c
for
client
in
unfuseable_clients_clone
.
get
(
next_out
,
())
if
not
isinstance
(
client
,
str
)
# "output"
if
not
isinstance
(
client
.
op
,
Output
)
for
c
in
client
.
outputs
}
...
...
pytensor/tensor/rewriting/linalg.py
浏览文件 @
9ba6d99f
...
...
@@ -299,8 +299,6 @@ def local_det_chol(fgraph, node):
"""
(
x
,)
=
node
.
inputs
for
cl
,
xpos
in
fgraph
.
clients
[
x
]:
if
cl
==
"output"
:
continue
if
isinstance
(
cl
.
op
,
Blockwise
)
and
isinstance
(
cl
.
op
.
core_op
,
Cholesky
):
L
=
cl
.
outputs
[
0
]
return
[
prod
(
diagonal
(
L
,
axis1
=-
2
,
axis2
=-
1
)
**
2
,
axis
=-
1
)]
...
...
pytensor/tensor/rewriting/math.py
浏览文件 @
9ba6d99f
...
...
@@ -1126,14 +1126,11 @@ class AlgebraicCanonizer(NodeRewriter):
# this canonized graph... if so, we do nothing and wait for
# them to be transformed.
for
c
,
c_idx
in
out_clients
:
if
c
==
"output"
:
continue
while
(
isinstance
(
getattr
(
c
,
"op"
,
None
),
DimShuffle
)
and
len
(
fgraph
.
clients
[
c
.
outputs
[
0
]])
<=
1
isinstance
(
c
.
op
,
DimShuffle
)
and
len
(
fgraph
.
clients
[
c
.
outputs
[
0
]])
<=
1
):
c
=
fgraph
.
clients
[
c
.
outputs
[
0
]][
0
][
0
]
if
getattr
(
c
,
"op"
,
""
)
in
[
self
.
main
,
self
.
inverse
,
self
.
reciprocal
]:
if
c
.
op
in
[
self
.
main
,
self
.
inverse
,
self
.
reciprocal
]:
return
False
# Here we make the canonical version of the graph around this node
...
...
pytensor/tensor/rewriting/shape.py
浏览文件 @
9ba6d99f
...
...
@@ -401,7 +401,7 @@ class ShapeFeature(Feature):
merged_shape
.
append
(
other_shape
[
i
])
elif
(
ps
.
owner
and
isinstance
(
getattr
(
ps
.
owner
,
"op"
,
None
)
,
Shape_i
)
and
isinstance
(
ps
.
owner
.
op
,
Shape_i
)
and
ps
.
owner
.
op
.
i
==
i
and
ps
.
owner
.
inputs
[
0
]
in
(
r
,
other_r
)
):
...
...
@@ -602,7 +602,7 @@ class ShapeFeature(Feature):
# r is *scheduled*.
# At that point, node is no longer a client of r, but of new_r
for
shpnode
,
idx
in
fgraph
.
clients
[
r
]
+
[(
node
,
i
)]:
if
isinstance
(
getattr
(
shpnode
,
"op"
,
None
)
,
Shape_i
):
if
isinstance
(
shpnode
.
op
,
Shape_i
):
idx
=
shpnode
.
op
.
i
repl
=
self
.
shape_of
[
new_r
][
idx
]
if
repl
.
owner
is
shpnode
:
...
...
@@ -1057,7 +1057,10 @@ def local_Shape_of_SpecifyShape(fgraph, node):
specified_shape
=
node
.
inputs
[
0
]
if
not
isinstance
(
getattr
(
specified_shape
.
owner
,
"op"
,
None
),
SpecifyShape
):
if
not
(
specified_shape
.
owner
is
not
None
and
isinstance
(
specified_shape
.
owner
.
op
,
SpecifyShape
)
):
return
False
x
,
*
shape
=
specified_shape
.
owner
.
inputs
...
...
tests/graph/test_fg.py
浏览文件 @
9ba6d99f
...
...
@@ -6,7 +6,7 @@ import pytest
from
pytensor.configdefaults
import
config
from
pytensor.graph.basic
import
NominalVariable
from
pytensor.graph.fg
import
FunctionGraph
from
pytensor.graph.fg
import
FunctionGraph
,
Output
from
pytensor.graph.utils
import
MissingInputError
from
pytensor.printing
import
debugprint
from
tests.graph.utils
import
(
...
...
@@ -78,8 +78,13 @@ class TestFunctionGraph:
assert
fg
.
variables
==
{
var1
,
var2
,
var3
,
var4
}
assert
fg
.
get_clients
(
var1
)
==
[(
var3
.
owner
,
0
)]
assert
fg
.
get_clients
(
var2
)
==
[(
var4
.
owner
,
1
)]
assert
fg
.
get_clients
(
var3
)
==
[(
"output"
,
0
),
(
var4
.
owner
,
0
)]
assert
fg
.
get_clients
(
var4
)
==
[(
"output"
,
1
)]
var3_clients
=
fg
.
get_clients
(
var3
)
assert
len
(
var3_clients
)
==
2
assert
var3_clients
[
0
][
0
]
.
op
==
Output
(
0
)
assert
var3_clients
[
1
]
==
(
var4
.
owner
,
0
)
var4_clients
=
fg
.
get_clients
(
var4
)
assert
len
(
var4_clients
)
==
1
assert
var4_clients
[
0
][
0
]
.
op
==
Output
(
1
)
varC
=
MyConstant
(
"varC"
)
var5
=
op1
(
var1
,
varC
)
...
...
@@ -208,8 +213,11 @@ class TestFunctionGraph:
fg
=
FunctionGraph
([
var1
,
var2
],
[
var3
,
var5
],
clone
=
False
)
var6
=
MyVariable2
(
"var6"
)
[
out_client
]
=
[
cl
for
cl
,
_
in
fg
.
clients
[
fg
.
outputs
[
0
]]
if
isinstance
(
cl
.
op
,
Output
)
]
with
pytest
.
raises
(
TypeError
):
fg
.
change_node_input
(
"output"
,
1
,
var6
)
fg
.
change_node_input
(
out_client
,
0
,
var6
)
with
pytest
.
raises
(
TypeError
):
fg
.
change_node_input
(
var5
.
owner
,
1
,
var6
)
...
...
@@ -358,12 +366,13 @@ class TestFunctionGraph:
# TODO: What if the index value is greater than 1? It will throw an
# `IndexError`, but that doesn't sound like anything we'd want.
out_node
=
Output
(
idx
=
1
)
.
make_node
(
var4
)
with
pytest
.
raises
(
Exception
,
match
=
"Inconsistent clients list.*"
):
fg
.
add_client
(
var4
,
(
"output"
,
1
))
fg
.
add_client
(
var4
,
(
out_node
,
0
))
fg
.
check_integrity
()
fg
.
remove_client
(
var4
,
(
"output"
,
1
))
fg
.
remove_client
(
var4
,
(
out_node
,
0
))
with
pytest
.
raises
(
TypeError
,
match
=
"The first entry of.*"
):
fg
.
add_client
(
var4
,
(
None
,
0
))
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论