Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
9ba6d99f
提交
9ba6d99f
authored
5月 30, 2024
作者:
Ricardo Vieira
提交者:
Ricardo Vieira
7月 08, 2024
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Replace str "output" by a dummy Op in the clients of the FunctionGraph
上级
7f623fef
全部展开
显示空白字符变更
内嵌
并排
正在显示
18 个修改的文件
包含
81 行增加
和
88 行删除
+81
-88
debugmode.py
pytensor/compile/debugmode.py
+6
-7
types.py
pytensor/compile/function/types.py
+15
-18
profiling.py
pytensor/compile/profiling.py
+4
-7
destroyhandler.py
pytensor/graph/destroyhandler.py
+2
-3
fg.py
pytensor/graph/fg.py
+0
-0
basic.py
pytensor/graph/rewriting/basic.py
+3
-5
utils.py
pytensor/graph/rewriting/utils.py
+4
-6
basic.py
pytensor/link/c/basic.py
+1
-3
vm.py
pytensor/link/vm.py
+0
-1
printing.py
pytensor/printing.py
+9
-4
rewriting.py
pytensor/scan/rewriting.py
+4
-4
basic.py
pytensor/tensor/basic.py
+2
-2
basic.py
pytensor/tensor/random/rewriting/basic.py
+5
-8
elemwise.py
pytensor/tensor/rewriting/elemwise.py
+3
-4
linalg.py
pytensor/tensor/rewriting/linalg.py
+0
-2
math.py
pytensor/tensor/rewriting/math.py
+2
-5
shape.py
pytensor/tensor/rewriting/shape.py
+6
-3
test_fg.py
tests/graph/test_fg.py
+15
-6
没有找到文件。
pytensor/compile/debugmode.py
浏览文件 @
9ba6d99f
...
@@ -30,6 +30,7 @@ from pytensor.configdefaults import config
...
@@ -30,6 +30,7 @@ from pytensor.configdefaults import config
from
pytensor.graph.basic
import
Variable
,
io_toposort
from
pytensor.graph.basic
import
Variable
,
io_toposort
from
pytensor.graph.destroyhandler
import
DestroyHandler
from
pytensor.graph.destroyhandler
import
DestroyHandler
from
pytensor.graph.features
import
AlreadyThere
,
BadOptimization
from
pytensor.graph.features
import
AlreadyThere
,
BadOptimization
from
pytensor.graph.fg
import
Output
from
pytensor.graph.op
import
HasInnerGraph
,
Op
from
pytensor.graph.op
import
HasInnerGraph
,
Op
from
pytensor.graph.utils
import
InconsistencyError
,
MethodNotDefined
from
pytensor.graph.utils
import
InconsistencyError
,
MethodNotDefined
from
pytensor.link.basic
import
Container
,
LocalLinker
from
pytensor.link.basic
import
Container
,
LocalLinker
...
@@ -628,7 +629,9 @@ def _is_used_in_graph(fgraph, var):
...
@@ -628,7 +629,9 @@ def _is_used_in_graph(fgraph, var):
True if `var` is used by another node in the graph.
True if `var` is used by another node in the graph.
"""
"""
return
not
(
fgraph
.
clients
[
var
]
==
[(
"output"
,
1
)]
or
fgraph
.
clients
[
var
]
==
[])
return
any
(
client
for
client
,
_
in
fgraph
.
clients
[
var
]
if
not
isinstance
(
client
.
op
,
Output
)
)
def
_check_strides_match
(
a
,
b
,
warn_err
,
op
):
def
_check_strides_match
(
a
,
b
,
warn_err
,
op
):
...
@@ -977,7 +980,7 @@ def _check_preallocated_output(
...
@@ -977,7 +980,7 @@ def _check_preallocated_output(
# disable memory checks in that mode, since they were already run.
# disable memory checks in that mode, since they were already run.
try
:
try
:
changed_inner_mode
=
False
changed_inner_mode
=
False
if
isinstance
(
getattr
(
node
,
"op"
,
None
)
,
HasInnerGraph
):
if
isinstance
(
node
.
op
,
HasInnerGraph
):
fn
=
node
.
op
.
fn
fn
=
node
.
op
.
fn
if
not
(
fn
and
hasattr
(
fn
,
"maker"
)
and
hasattr
(
fn
.
maker
,
"mode"
)):
if
not
(
fn
and
hasattr
(
fn
,
"maker"
)
and
hasattr
(
fn
.
maker
,
"mode"
)):
_logger
.
warning
(
f
"Expected pytensor function not found in {node.op}.fn"
)
_logger
.
warning
(
f
"Expected pytensor function not found in {node.op}.fn"
)
...
@@ -1132,10 +1135,6 @@ class _FunctionGraphEvent:
...
@@ -1132,10 +1135,6 @@ class _FunctionGraphEvent:
def
__init__
(
self
,
kind
,
node
,
idx
=
None
,
reason
=
None
):
def
__init__
(
self
,
kind
,
node
,
idx
=
None
,
reason
=
None
):
self
.
kind
=
kind
self
.
kind
=
kind
if
node
==
"output"
:
self
.
node
=
"output"
self
.
op
=
"output"
else
:
self
.
node
=
node
self
.
node
=
node
self
.
op
=
node
.
op
self
.
op
=
node
.
op
self
.
idx
=
idx
self
.
idx
=
idx
...
@@ -1143,7 +1142,7 @@ class _FunctionGraphEvent:
...
@@ -1143,7 +1142,7 @@ class _FunctionGraphEvent:
def
__str__
(
self
):
def
__str__
(
self
):
if
self
.
kind
==
"change"
:
if
self
.
kind
==
"change"
:
if
self
.
op
!=
"output"
:
if
not
isinstance
(
self
.
op
,
Output
)
:
msg
=
str
(
len
(
self
.
node
.
inputs
))
msg
=
str
(
len
(
self
.
node
.
inputs
))
else
:
else
:
msg
=
""
msg
=
""
...
...
pytensor/compile/function/types.py
浏览文件 @
9ba6d99f
...
@@ -78,8 +78,6 @@ def view_tree_set(fgraph, v, treeset):
...
@@ -78,8 +78,6 @@ def view_tree_set(fgraph, v, treeset):
"""
"""
treeset
.
add
(
v
)
treeset
.
add
(
v
)
for
cl
,
v_input_pos_to_cl
in
fgraph
.
clients
[
v
]:
for
cl
,
v_input_pos_to_cl
in
fgraph
.
clients
[
v
]:
if
cl
==
"output"
:
continue
vmap
=
cl
.
op
.
view_map
vmap
=
cl
.
op
.
view_map
dmap
=
cl
.
op
.
destroy_map
dmap
=
cl
.
op
.
destroy_map
for
opos
,
iposlist
in
chain
(
vmap
.
items
(),
dmap
.
items
()):
for
opos
,
iposlist
in
chain
(
vmap
.
items
(),
dmap
.
items
()):
...
@@ -1202,8 +1200,11 @@ def insert_deepcopy(fgraph, wrapped_inputs, wrapped_outputs):
...
@@ -1202,8 +1200,11 @@ def insert_deepcopy(fgraph, wrapped_inputs, wrapped_outputs):
has_destroyers_attr
=
hasattr
(
fgraph
,
"has_destroyers"
)
has_destroyers_attr
=
hasattr
(
fgraph
,
"has_destroyers"
)
for
i
in
range
(
len
(
fgraph
.
outputs
)):
for
i
in
range
(
len
(
fgraph
.
outputs
)):
original_out
=
fgraph
.
outputs
[
i
]
output_client
=
fgraph
.
get_output_client
(
i
)
views_of_output_i
=
set
()
views_of_output_i
=
set
()
view_tree_set
(
fgraph
,
alias_root
(
fgraph
.
outputs
[
i
]
),
views_of_output_i
)
view_tree_set
(
fgraph
,
alias_root
(
original_out
),
views_of_output_i
)
copied
=
False
copied
=
False
# do not allow outputs to be aliased
# do not allow outputs to be aliased
for
j
in
range
(
i
+
1
,
len
(
fgraph
.
outputs
)):
for
j
in
range
(
i
+
1
,
len
(
fgraph
.
outputs
)):
...
@@ -1212,16 +1213,16 @@ def insert_deepcopy(fgraph, wrapped_inputs, wrapped_outputs):
...
@@ -1212,16 +1213,16 @@ def insert_deepcopy(fgraph, wrapped_inputs, wrapped_outputs):
if
fgraph
.
outputs
[
j
]
in
views_of_output_i
:
if
fgraph
.
outputs
[
j
]
in
views_of_output_i
:
if
wrapped_outputs
[
i
]
.
borrow
and
wrapped_outputs
[
j
]
.
borrow
:
if
wrapped_outputs
[
i
]
.
borrow
and
wrapped_outputs
[
j
]
.
borrow
:
fgraph
.
change_node_input
(
fgraph
.
change_node_input
(
"output"
,
i
,
view_op
(
fgraph
.
outputs
[
i
]
),
reason
=
reason
*
output_client
,
view_op
(
original_out
),
reason
=
reason
)
)
else
:
else
:
fgraph
.
change_node_input
(
fgraph
.
change_node_input
(
"output"
,
i
,
deep_copy_op
(
fgraph
.
outputs
[
i
]
),
reason
=
reason
*
output_client
,
deep_copy_op
(
original_out
),
reason
=
reason
)
)
copied
=
True
copied
=
True
break
break
if
not
copied
:
if
not
copied
:
# no-break
for
input_j
in
all_graph_inputs
:
for
input_j
in
all_graph_inputs
:
# do not allow outputs to be aliased to an inputs (j), unless
# do not allow outputs to be aliased to an inputs (j), unless
# a) that j'th input has been 'destroyed' by
# a) that j'th input has been 'destroyed' by
...
@@ -1239,33 +1240,29 @@ def insert_deepcopy(fgraph, wrapped_inputs, wrapped_outputs):
...
@@ -1239,33 +1240,29 @@ def insert_deepcopy(fgraph, wrapped_inputs, wrapped_outputs):
j
=
fgraph
.
inputs
.
index
(
input_j
)
j
=
fgraph
.
inputs
.
index
(
input_j
)
if
wrapped_outputs
[
i
]
.
borrow
and
wrapped_inputs
[
j
]
.
borrow
:
if
wrapped_outputs
[
i
]
.
borrow
and
wrapped_inputs
[
j
]
.
borrow
:
fgraph
.
change_node_input
(
fgraph
.
change_node_input
(
"output"
,
*
output_client
,
i
,
view_op
(
original_out
),
view_op
(
fgraph
.
outputs
[
i
]),
reason
=
reason
,
reason
=
reason
,
)
)
break
break
else
:
else
:
fgraph
.
change_node_input
(
fgraph
.
change_node_input
(
"output"
,
*
output_client
,
i
,
deep_copy_op
(
original_out
),
deep_copy_op
(
fgraph
.
outputs
[
i
]),
reason
=
reason
,
reason
=
reason
,
)
)
break
break
elif
wrapped_outputs
[
i
]
.
borrow
:
elif
wrapped_outputs
[
i
]
.
borrow
:
fgraph
.
change_node_input
(
fgraph
.
change_node_input
(
"output"
,
*
output_client
,
i
,
view_op
(
original_out
),
view_op
(
fgraph
.
outputs
[
i
]),
reason
=
reason
,
reason
=
reason
,
)
)
break
break
else
:
else
:
fgraph
.
change_node_input
(
fgraph
.
change_node_input
(
"output"
,
*
output_client
,
i
,
deep_copy_op
(
original_out
),
deep_copy_op
(
fgraph
.
outputs
[
i
]),
reason
=
reason
,
reason
=
reason
,
)
)
break
break
...
...
pytensor/compile/profiling.py
浏览文件 @
9ba6d99f
...
@@ -16,20 +16,17 @@ import sys
...
@@ -16,20 +16,17 @@ import sys
import
time
import
time
from
collections
import
Counter
,
defaultdict
from
collections
import
Counter
,
defaultdict
from
contextlib
import
contextmanager
from
contextlib
import
contextmanager
from
typing
import
TYPE_CHECKING
,
Any
from
typing
import
Any
import
numpy
as
np
import
numpy
as
np
import
pytensor
import
pytensor
from
pytensor.configdefaults
import
config
from
pytensor.configdefaults
import
config
from
pytensor.graph.basic
import
Apply
,
Constant
,
Variable
from
pytensor.graph.basic
import
Apply
,
Constant
,
Variable
from
pytensor.graph.fg
import
FunctionGraph
,
Output
from
pytensor.link.utils
import
get_destroy_dependencies
from
pytensor.link.utils
import
get_destroy_dependencies
if
TYPE_CHECKING
:
from
pytensor.graph.fg
import
FunctionGraph
@contextmanager
@contextmanager
def
extended_open
(
filename
,
mode
=
"r"
):
def
extended_open
(
filename
,
mode
=
"r"
):
if
filename
==
"<stdout>"
:
if
filename
==
"<stdout>"
:
...
@@ -1038,7 +1035,7 @@ class ProfileStats:
...
@@ -1038,7 +1035,7 @@ class ProfileStats:
executable_nodes
=
set
()
executable_nodes
=
set
()
for
var
in
fgraph
.
inputs
:
for
var
in
fgraph
.
inputs
:
for
c
,
_
in
fgraph
.
clients
[
var
]:
for
c
,
_
in
fgraph
.
clients
[
var
]:
if
c
!=
"output"
:
if
not
isinstance
(
c
.
op
,
Output
)
:
deps
=
c
.
inputs
+
destroy_dependencies
[
c
]
deps
=
c
.
inputs
+
destroy_dependencies
[
c
]
if
all
(
compute_map
[
v
][
0
]
for
v
in
deps
):
if
all
(
compute_map
[
v
][
0
]
for
v
in
deps
):
executable_nodes
.
add
(
c
)
executable_nodes
.
add
(
c
)
...
@@ -1166,7 +1163,7 @@ class ProfileStats:
...
@@ -1166,7 +1163,7 @@ class ProfileStats:
for
var
in
node
.
outputs
:
for
var
in
node
.
outputs
:
for
c
,
_
in
fgraph
.
clients
[
var
]:
for
c
,
_
in
fgraph
.
clients
[
var
]:
if
c
!=
"output"
:
if
not
isinstance
(
c
.
op
,
Output
)
:
deps
=
c
.
inputs
+
destroy_dependencies
[
c
]
deps
=
c
.
inputs
+
destroy_dependencies
[
c
]
if
all
(
compute_map
[
v
][
0
]
for
v
in
deps
):
if
all
(
compute_map
[
v
][
0
]
for
v
in
deps
):
new_exec_nodes
.
add
(
c
)
new_exec_nodes
.
add
(
c
)
...
...
pytensor/graph/destroyhandler.py
浏览文件 @
9ba6d99f
...
@@ -11,6 +11,7 @@ import pytensor
...
@@ -11,6 +11,7 @@ import pytensor
from
pytensor.configdefaults
import
config
from
pytensor.configdefaults
import
config
from
pytensor.graph.basic
import
Constant
from
pytensor.graph.basic
import
Constant
from
pytensor.graph.features
import
AlreadyThere
,
Bookkeeper
from
pytensor.graph.features
import
AlreadyThere
,
Bookkeeper
from
pytensor.graph.fg
import
Output
from
pytensor.graph.utils
import
InconsistencyError
from
pytensor.graph.utils
import
InconsistencyError
from
pytensor.misc.ordered_set
import
OrderedSet
from
pytensor.misc.ordered_set
import
OrderedSet
...
@@ -401,8 +402,6 @@ class DestroyHandler(Bookkeeper):
...
@@ -401,8 +402,6 @@ class DestroyHandler(Bookkeeper):
def
recursive_destroys_finder
(
protected_var
):
def
recursive_destroys_finder
(
protected_var
):
# protected_var is the idx'th input of app.
# protected_var is the idx'th input of app.
for
app
,
idx
in
fgraph
.
clients
[
protected_var
]:
for
app
,
idx
in
fgraph
.
clients
[
protected_var
]:
if
app
==
"output"
:
continue
destroy_maps
=
app
.
op
.
destroy_map
.
values
()
destroy_maps
=
app
.
op
.
destroy_map
.
values
()
# If True means that the apply node, destroys the protected_var.
# If True means that the apply node, destroys the protected_var.
if
idx
in
[
dmap
for
sublist
in
destroy_maps
for
dmap
in
sublist
]:
if
idx
in
[
dmap
for
sublist
in
destroy_maps
for
dmap
in
sublist
]:
...
@@ -578,7 +577,7 @@ class DestroyHandler(Bookkeeper):
...
@@ -578,7 +577,7 @@ class DestroyHandler(Bookkeeper):
app.inputs[i] changed from old_r to new_r.
app.inputs[i] changed from old_r to new_r.
"""
"""
if
app
==
"output"
:
if
isinstance
(
app
.
op
,
Output
)
:
# app == 'output' is special key that means FunctionGraph is redefining which nodes are being
# app == 'output' is special key that means FunctionGraph is redefining which nodes are being
# considered 'outputs' of the graph.
# considered 'outputs' of the graph.
pass
pass
...
...
pytensor/graph/fg.py
浏览文件 @
9ba6d99f
差异被折叠。
点击展开。
pytensor/graph/rewriting/basic.py
浏览文件 @
9ba6d99f
...
@@ -30,7 +30,7 @@ from pytensor.graph.basic import (
...
@@ -30,7 +30,7 @@ from pytensor.graph.basic import (
vars_between
,
vars_between
,
)
)
from
pytensor.graph.features
import
AlreadyThere
,
Feature
,
NodeFinder
from
pytensor.graph.features
import
AlreadyThere
,
Feature
,
NodeFinder
from
pytensor.graph.fg
import
FunctionGraph
from
pytensor.graph.fg
import
FunctionGraph
,
Output
from
pytensor.graph.op
import
Op
from
pytensor.graph.op
import
Op
from
pytensor.graph.utils
import
AssocList
,
InconsistencyError
from
pytensor.graph.utils
import
AssocList
,
InconsistencyError
from
pytensor.misc.ordered_set
import
OrderedSet
from
pytensor.misc.ordered_set
import
OrderedSet
...
@@ -738,7 +738,7 @@ class MergeOptimizer(GraphRewriter):
...
@@ -738,7 +738,7 @@ class MergeOptimizer(GraphRewriter):
if
any
(
if
any
(
i
in
flatten
(
c
.
op
.
destroy_map
.
values
())
i
in
flatten
(
c
.
op
.
destroy_map
.
values
())
for
c
,
i
in
clients
for
c
,
i
in
clients
if
c
!=
"output"
and
c
.
op
.
destroy_map
if
c
.
op
.
destroy_map
):
):
continue
continue
...
@@ -1612,8 +1612,6 @@ class PatternNodeRewriter(NodeRewriter):
...
@@ -1612,8 +1612,6 @@ class PatternNodeRewriter(NodeRewriter):
if
get_nodes
and
self
.
get_nodes
is
not
None
:
if
get_nodes
and
self
.
get_nodes
is
not
None
:
for
real_node
in
self
.
get_nodes
(
fgraph
,
node
):
for
real_node
in
self
.
get_nodes
(
fgraph
,
node
):
if
real_node
==
"output"
:
continue
ret
=
self
.
transform
(
fgraph
,
real_node
,
get_nodes
=
False
)
ret
=
self
.
transform
(
fgraph
,
real_node
,
get_nodes
=
False
)
if
ret
is
not
False
and
ret
is
not
None
:
if
ret
is
not
False
and
ret
is
not
None
:
return
dict
(
zip
(
real_node
.
outputs
,
ret
))
return
dict
(
zip
(
real_node
.
outputs
,
ret
))
...
@@ -2399,7 +2397,7 @@ class EquilibriumGraphRewriter(NodeProcessingGraphRewriter):
...
@@ -2399,7 +2397,7 @@ class EquilibriumGraphRewriter(NodeProcessingGraphRewriter):
if
self
.
tracks_on_change_inputs
:
if
self
.
tracks_on_change_inputs
:
def
chin_
(
node
,
i
,
r
,
new_r
,
reason
):
def
chin_
(
node
,
i
,
r
,
new_r
,
reason
):
if
node
is
not
current_node
and
not
isinstance
(
node
,
str
):
if
node
is
not
current_node
and
not
isinstance
(
node
.
op
,
Output
):
q
.
append
(
node
)
q
.
append
(
node
)
chin
=
chin_
chin
=
chin_
...
...
pytensor/graph/rewriting/utils.py
浏览文件 @
9ba6d99f
import
copy
import
copy
from
collections.abc
import
Generator
,
Sequence
from
collections.abc
import
Generator
,
Sequence
from
typing
import
TYPE_CHECKING
,
Optional
,
cast
from
typing
import
TYPE_CHECKING
,
Optional
import
pytensor
import
pytensor
from
pytensor.graph.basic
import
(
from
pytensor.graph.basic
import
(
...
@@ -10,7 +10,7 @@ from pytensor.graph.basic import (
...
@@ -10,7 +10,7 @@ from pytensor.graph.basic import (
graph_inputs
,
graph_inputs
,
vars_between
,
vars_between
,
)
)
from
pytensor.graph.fg
import
FunctionGraph
from
pytensor.graph.fg
import
FunctionGraph
,
Output
from
pytensor.graph.rewriting.db
import
RewriteDatabaseQuery
from
pytensor.graph.rewriting.db
import
RewriteDatabaseQuery
...
@@ -230,11 +230,9 @@ def get_clients_at_depth(
...
@@ -230,11 +230,9 @@ def get_clients_at_depth(
for
var
in
node
.
outputs
:
for
var
in
node
.
outputs
:
if
depth
>
0
:
if
depth
>
0
:
for
out_node
,
_
in
fgraph
.
clients
[
var
]:
for
out_node
,
_
in
fgraph
.
clients
[
var
]:
if
out_node
==
"output"
:
if
isinstance
(
out_node
.
op
,
Output
)
:
continue
continue
yield
from
get_clients_at_depth
(
yield
from
get_clients_at_depth
(
fgraph
,
out_node
,
depth
-
1
)
fgraph
,
cast
(
Apply
,
out_node
),
depth
-
1
)
else
:
else
:
assert
var
.
owner
is
not
None
assert
var
.
owner
is
not
None
yield
var
.
owner
yield
var
.
owner
pytensor/link/c/basic.py
浏览文件 @
9ba6d99f
...
@@ -354,9 +354,7 @@ def get_c_declare(fgraph, r, name, sub):
...
@@ -354,9 +354,7 @@ def get_c_declare(fgraph, r, name, sub):
# it means they need `r`'s dtype to be declared, so
# it means they need `r`'s dtype to be declared, so
# we have to pass `check_input=True` to `c_declare`.
# we have to pass `check_input=True` to `c_declare`.
if
any
(
if
any
(
getattr
(
c
.
op
,
"check_input"
,
config
.
check_input
)
getattr
(
c
.
op
,
"check_input"
,
config
.
check_input
)
for
(
c
,
_
)
in
fgraph
.
clients
[
r
]
for
(
c
,
_
)
in
fgraph
.
clients
[
r
]
if
not
isinstance
(
c
,
str
)
)
or
(
r
.
owner
and
getattr
(
r
.
owner
.
op
,
"check_input"
,
config
.
check_input
)):
)
or
(
r
.
owner
and
getattr
(
r
.
owner
.
op
,
"check_input"
,
config
.
check_input
)):
c_declare
=
r
.
type
.
c_declare
(
name
,
sub
,
True
)
c_declare
=
r
.
type
.
c_declare
(
name
,
sub
,
True
)
else
:
else
:
...
...
pytensor/link/vm.py
浏览文件 @
9ba6d99f
...
@@ -954,7 +954,6 @@ class VMLinker(LocalLinker):
...
@@ -954,7 +954,6 @@ class VMLinker(LocalLinker):
if
k
.
owner
and
self
.
fgraph
.
clients
[
k
]:
if
k
.
owner
and
self
.
fgraph
.
clients
[
k
]:
ls
=
[]
ls
=
[]
for
cl
in
self
.
fgraph
.
clients
[
k
]:
for
cl
in
self
.
fgraph
.
clients
[
k
]:
if
cl
[
0
]
!=
"output"
:
ls
+=
cl
[
0
]
.
outputs
ls
+=
cl
[
0
]
.
outputs
dependencies
[
k
]
+=
ls
dependencies
[
k
]
+=
ls
return
dependencies
return
dependencies
...
...
pytensor/printing.py
浏览文件 @
9ba6d99f
...
@@ -437,10 +437,15 @@ N.B.:
...
@@ -437,10 +437,15 @@ N.B.:
for
out
in
inner_outputs
:
for
out
in
inner_outputs
:
if
(
if
(
isinstance
(
getattr
(
out
.
owner
,
"op"
,
None
),
HasInnerGraph
)
out
.
owner
is
not
None
or
hasattr
(
getattr
(
out
.
owner
,
"op"
,
None
),
"scalar_op"
)
and
(
and
isinstance
(
out
.
owner
.
op
.
scalar_op
,
HasInnerGraph
)
isinstance
(
out
.
owner
.
op
,
HasInnerGraph
)
)
and
out
not
in
inner_graph_vars
:
or
isinstance
(
getattr
(
out
.
owner
.
op
,
"scalar_op"
,
None
),
HasInnerGraph
)
)
and
out
not
in
inner_graph_vars
):
inner_graph_vars
.
append
(
out
)
inner_graph_vars
.
append
(
out
)
_debugprint
(
_debugprint
(
...
...
pytensor/scan/rewriting.py
浏览文件 @
9ba6d99f
...
@@ -27,7 +27,7 @@ from pytensor.graph.basic import (
...
@@ -27,7 +27,7 @@ from pytensor.graph.basic import (
)
)
from
pytensor.graph.destroyhandler
import
DestroyHandler
from
pytensor.graph.destroyhandler
import
DestroyHandler
from
pytensor.graph.features
import
ReplaceValidate
from
pytensor.graph.features
import
ReplaceValidate
from
pytensor.graph.fg
import
FunctionGraph
from
pytensor.graph.fg
import
FunctionGraph
,
Output
from
pytensor.graph.op
import
compute_test_value
from
pytensor.graph.op
import
compute_test_value
from
pytensor.graph.replace
import
clone_replace
from
pytensor.graph.replace
import
clone_replace
from
pytensor.graph.rewriting.basic
import
(
from
pytensor.graph.rewriting.basic
import
(
...
@@ -1303,7 +1303,7 @@ def scan_save_mem(fgraph, node):
...
@@ -1303,7 +1303,7 @@ def scan_save_mem(fgraph, node):
for
cl
,
_
in
fgraph
.
clients
[
out
]:
for
cl
,
_
in
fgraph
.
clients
[
out
]:
# 2.1 outputs of the function
# 2.1 outputs of the function
# => output needs all its intermediate values
# => output needs all its intermediate values
if
isinstance
(
cl
,
str
):
if
isinstance
(
cl
.
op
,
Output
):
# if the node is actually an output, then
# if the node is actually an output, then
# we need to store the entire thing
# we need to store the entire thing
global_nsteps
=
None
global_nsteps
=
None
...
@@ -1412,7 +1412,7 @@ def scan_save_mem(fgraph, node):
...
@@ -1412,7 +1412,7 @@ def scan_save_mem(fgraph, node):
for
i
,
out
in
enumerate
(
node
.
outputs
[:
c_outs
]):
for
i
,
out
in
enumerate
(
node
.
outputs
[:
c_outs
]):
# look at all its clients
# look at all its clients
for
cl
,
_
in
fgraph
.
clients
[
out
]:
for
cl
,
_
in
fgraph
.
clients
[
out
]:
if
isinstance
(
cl
,
str
):
if
isinstance
(
cl
.
op
,
Output
):
store_steps
[
i
]
=
0
store_steps
[
i
]
=
0
break
break
elif
not
isinstance
(
cl
.
op
,
Subtensor
):
elif
not
isinstance
(
cl
.
op
,
Subtensor
):
...
@@ -2309,7 +2309,7 @@ def scan_push_out_dot1(fgraph, node):
...
@@ -2309,7 +2309,7 @@ def scan_push_out_dot1(fgraph, node):
and
isinstance
(
out
.
owner
.
op
.
scalar_op
,
ps
.
Add
)
and
isinstance
(
out
.
owner
.
op
.
scalar_op
,
ps
.
Add
)
and
inp
in
out
.
owner
.
inputs
and
inp
in
out
.
owner
.
inputs
and
len
(
fgraph
.
clients
[
outer_out
])
==
1
and
len
(
fgraph
.
clients
[
outer_out
])
==
1
and
not
isinstance
(
fgraph
.
clients
[
outer_out
][
0
][
0
],
str
)
and
not
isinstance
(
fgraph
.
clients
[
outer_out
][
0
][
0
],
Output
)
and
isinstance
(
fgraph
.
clients
[
outer_out
][
0
][
0
]
.
op
,
Subtensor
)
and
isinstance
(
fgraph
.
clients
[
outer_out
][
0
][
0
]
.
op
,
Subtensor
)
and
fgraph
.
clients
[
outer_out
][
0
][
0
]
.
op
.
idx_list
==
(
-
1
,)
and
fgraph
.
clients
[
outer_out
][
0
][
0
]
.
op
.
idx_list
==
(
-
1
,)
):
):
...
...
pytensor/tensor/basic.py
浏览文件 @
9ba6d99f
...
@@ -24,7 +24,7 @@ from pytensor import scalar as ps
...
@@ -24,7 +24,7 @@ from pytensor import scalar as ps
from
pytensor.gradient
import
DisconnectedType
,
grad_undefined
from
pytensor.gradient
import
DisconnectedType
,
grad_undefined
from
pytensor.graph
import
RewriteDatabaseQuery
from
pytensor.graph
import
RewriteDatabaseQuery
from
pytensor.graph.basic
import
Apply
,
Constant
,
Variable
,
equal_computations
from
pytensor.graph.basic
import
Apply
,
Constant
,
Variable
,
equal_computations
from
pytensor.graph.fg
import
FunctionGraph
from
pytensor.graph.fg
import
FunctionGraph
,
Output
from
pytensor.graph.op
import
Op
from
pytensor.graph.op
import
Op
from
pytensor.graph.replace
import
_vectorize_node
from
pytensor.graph.replace
import
_vectorize_node
from
pytensor.graph.rewriting.db
import
EquilibriumDB
from
pytensor.graph.rewriting.db
import
EquilibriumDB
...
@@ -1681,7 +1681,7 @@ class Alloc(COp):
...
@@ -1681,7 +1681,7 @@ class Alloc(COp):
return
False
return
False
for
client
,
idx
in
clients
:
for
client
,
idx
in
clients
:
if
client
==
"output"
:
if
isinstance
(
client
.
op
,
Output
)
:
# If the output is a constant, it will have to be deepcopied
# If the output is a constant, it will have to be deepcopied
# each time the function is called. So we do not fold.
# each time the function is called. So we do not fold.
return
False
return
False
...
...
pytensor/tensor/random/rewriting/basic.py
浏览文件 @
9ba6d99f
...
@@ -31,15 +31,12 @@ def is_rv_used_in_graph(base_rv, node, fgraph):
...
@@ -31,15 +31,12 @@ def is_rv_used_in_graph(base_rv, node, fgraph):
TODO: We should apply all the shape rewrites before these rewrites, since
TODO: We should apply all the shape rewrites before these rewrites, since
that would properly remove the unnecessary dependencies on `base_rv` (when
that would properly remove the unnecessary dependencies on `base_rv` (when
possible).
possible).
"""
"""
return
any
(
def
_node_check
(
n
,
i
):
n
if
n
==
"output"
:
for
n
,
i
in
fgraph
.
clients
.
get
(
base_rv
,
())
n
=
fgraph
.
outputs
[
i
]
.
owner
if
not
(
n
is
node
or
isinstance
(
n
.
op
,
Shape
|
Shape_i
))
return
n
==
node
or
isinstance
(
n
.
op
,
Shape
|
Shape_i
)
)
return
not
all
(
_node_check
(
n
,
i
)
for
n
,
i
in
fgraph
.
clients
.
get
(
base_rv
,
()))
@node_rewriter
([
RandomVariable
],
inplace
=
True
)
@node_rewriter
([
RandomVariable
],
inplace
=
True
)
...
...
pytensor/tensor/rewriting/elemwise.py
浏览文件 @
9ba6d99f
...
@@ -14,7 +14,7 @@ from pytensor.configdefaults import config
...
@@ -14,7 +14,7 @@ from pytensor.configdefaults import config
from
pytensor.graph
import
FunctionGraph
from
pytensor.graph
import
FunctionGraph
from
pytensor.graph.basic
import
Apply
,
Constant
,
Variable
,
ancestors
,
io_toposort
from
pytensor.graph.basic
import
Apply
,
Constant
,
Variable
,
ancestors
,
io_toposort
from
pytensor.graph.features
import
ReplaceValidate
from
pytensor.graph.features
import
ReplaceValidate
from
pytensor.graph.fg
import
ApplyOr
Output
from
pytensor.graph.fg
import
Output
from
pytensor.graph.rewriting.basic
import
(
from
pytensor.graph.rewriting.basic
import
(
EquilibriumGraphRewriter
,
EquilibriumGraphRewriter
,
GraphRewriter
,
GraphRewriter
,
...
@@ -688,7 +688,7 @@ class FusionOptimizer(GraphRewriter):
...
@@ -688,7 +688,7 @@ class FusionOptimizer(GraphRewriter):
"""
"""
FUSEABLE_MAPPING
=
defaultdict
[
Variable
,
list
[
Apply
]]
FUSEABLE_MAPPING
=
defaultdict
[
Variable
,
list
[
Apply
]]
UNFUSEABLE_MAPPING
=
defaultdict
[
Variable
,
set
[
Apply
OrOutput
]]
UNFUSEABLE_MAPPING
=
defaultdict
[
Variable
,
set
[
Apply
]]
def
initialize_fuseable_mappings
(
def
initialize_fuseable_mappings
(
*
,
fg
:
FunctionGraph
*
,
fg
:
FunctionGraph
...
@@ -727,7 +727,6 @@ class FusionOptimizer(GraphRewriter):
...
@@ -727,7 +727,6 @@ class FusionOptimizer(GraphRewriter):
for
client
,
_
in
clients
:
for
client
,
_
in
clients
:
if
(
if
(
out_maybe_fuseable
out_maybe_fuseable
and
not
isinstance
(
client
,
str
)
# "output"
and
isinstance
(
client
.
op
,
Elemwise
)
and
isinstance
(
client
.
op
,
Elemwise
)
# and not isinstance(client.op.scalar_op, ps.Composite)
# and not isinstance(client.op.scalar_op, ps.Composite)
and
len
(
client
.
outputs
)
==
1
and
len
(
client
.
outputs
)
==
1
...
@@ -841,7 +840,7 @@ class FusionOptimizer(GraphRewriter):
...
@@ -841,7 +840,7 @@ class FusionOptimizer(GraphRewriter):
implied_unfuseable_clients
=
{
implied_unfuseable_clients
=
{
c
c
for
client
in
unfuseable_clients_clone
.
get
(
next_out
,
())
for
client
in
unfuseable_clients_clone
.
get
(
next_out
,
())
if
not
isinstance
(
client
,
str
)
# "output"
if
not
isinstance
(
client
.
op
,
Output
)
for
c
in
client
.
outputs
for
c
in
client
.
outputs
}
}
...
...
pytensor/tensor/rewriting/linalg.py
浏览文件 @
9ba6d99f
...
@@ -299,8 +299,6 @@ def local_det_chol(fgraph, node):
...
@@ -299,8 +299,6 @@ def local_det_chol(fgraph, node):
"""
"""
(
x
,)
=
node
.
inputs
(
x
,)
=
node
.
inputs
for
cl
,
xpos
in
fgraph
.
clients
[
x
]:
for
cl
,
xpos
in
fgraph
.
clients
[
x
]:
if
cl
==
"output"
:
continue
if
isinstance
(
cl
.
op
,
Blockwise
)
and
isinstance
(
cl
.
op
.
core_op
,
Cholesky
):
if
isinstance
(
cl
.
op
,
Blockwise
)
and
isinstance
(
cl
.
op
.
core_op
,
Cholesky
):
L
=
cl
.
outputs
[
0
]
L
=
cl
.
outputs
[
0
]
return
[
prod
(
diagonal
(
L
,
axis1
=-
2
,
axis2
=-
1
)
**
2
,
axis
=-
1
)]
return
[
prod
(
diagonal
(
L
,
axis1
=-
2
,
axis2
=-
1
)
**
2
,
axis
=-
1
)]
...
...
pytensor/tensor/rewriting/math.py
浏览文件 @
9ba6d99f
...
@@ -1126,14 +1126,11 @@ class AlgebraicCanonizer(NodeRewriter):
...
@@ -1126,14 +1126,11 @@ class AlgebraicCanonizer(NodeRewriter):
# this canonized graph... if so, we do nothing and wait for
# this canonized graph... if so, we do nothing and wait for
# them to be transformed.
# them to be transformed.
for
c
,
c_idx
in
out_clients
:
for
c
,
c_idx
in
out_clients
:
if
c
==
"output"
:
continue
while
(
while
(
isinstance
(
getattr
(
c
,
"op"
,
None
),
DimShuffle
)
isinstance
(
c
.
op
,
DimShuffle
)
and
len
(
fgraph
.
clients
[
c
.
outputs
[
0
]])
<=
1
and
len
(
fgraph
.
clients
[
c
.
outputs
[
0
]])
<=
1
):
):
c
=
fgraph
.
clients
[
c
.
outputs
[
0
]][
0
][
0
]
c
=
fgraph
.
clients
[
c
.
outputs
[
0
]][
0
][
0
]
if
getattr
(
c
,
"op"
,
""
)
in
[
self
.
main
,
self
.
inverse
,
self
.
reciprocal
]:
if
c
.
op
in
[
self
.
main
,
self
.
inverse
,
self
.
reciprocal
]:
return
False
return
False
# Here we make the canonical version of the graph around this node
# Here we make the canonical version of the graph around this node
...
...
pytensor/tensor/rewriting/shape.py
浏览文件 @
9ba6d99f
...
@@ -401,7 +401,7 @@ class ShapeFeature(Feature):
...
@@ -401,7 +401,7 @@ class ShapeFeature(Feature):
merged_shape
.
append
(
other_shape
[
i
])
merged_shape
.
append
(
other_shape
[
i
])
elif
(
elif
(
ps
.
owner
ps
.
owner
and
isinstance
(
getattr
(
ps
.
owner
,
"op"
,
None
)
,
Shape_i
)
and
isinstance
(
ps
.
owner
.
op
,
Shape_i
)
and
ps
.
owner
.
op
.
i
==
i
and
ps
.
owner
.
op
.
i
==
i
and
ps
.
owner
.
inputs
[
0
]
in
(
r
,
other_r
)
and
ps
.
owner
.
inputs
[
0
]
in
(
r
,
other_r
)
):
):
...
@@ -602,7 +602,7 @@ class ShapeFeature(Feature):
...
@@ -602,7 +602,7 @@ class ShapeFeature(Feature):
# r is *scheduled*.
# r is *scheduled*.
# At that point, node is no longer a client of r, but of new_r
# At that point, node is no longer a client of r, but of new_r
for
shpnode
,
idx
in
fgraph
.
clients
[
r
]
+
[(
node
,
i
)]:
for
shpnode
,
idx
in
fgraph
.
clients
[
r
]
+
[(
node
,
i
)]:
if
isinstance
(
getattr
(
shpnode
,
"op"
,
None
)
,
Shape_i
):
if
isinstance
(
shpnode
.
op
,
Shape_i
):
idx
=
shpnode
.
op
.
i
idx
=
shpnode
.
op
.
i
repl
=
self
.
shape_of
[
new_r
][
idx
]
repl
=
self
.
shape_of
[
new_r
][
idx
]
if
repl
.
owner
is
shpnode
:
if
repl
.
owner
is
shpnode
:
...
@@ -1057,7 +1057,10 @@ def local_Shape_of_SpecifyShape(fgraph, node):
...
@@ -1057,7 +1057,10 @@ def local_Shape_of_SpecifyShape(fgraph, node):
specified_shape
=
node
.
inputs
[
0
]
specified_shape
=
node
.
inputs
[
0
]
if
not
isinstance
(
getattr
(
specified_shape
.
owner
,
"op"
,
None
),
SpecifyShape
):
if
not
(
specified_shape
.
owner
is
not
None
and
isinstance
(
specified_shape
.
owner
.
op
,
SpecifyShape
)
):
return
False
return
False
x
,
*
shape
=
specified_shape
.
owner
.
inputs
x
,
*
shape
=
specified_shape
.
owner
.
inputs
...
...
tests/graph/test_fg.py
浏览文件 @
9ba6d99f
...
@@ -6,7 +6,7 @@ import pytest
...
@@ -6,7 +6,7 @@ import pytest
from
pytensor.configdefaults
import
config
from
pytensor.configdefaults
import
config
from
pytensor.graph.basic
import
NominalVariable
from
pytensor.graph.basic
import
NominalVariable
from
pytensor.graph.fg
import
FunctionGraph
from
pytensor.graph.fg
import
FunctionGraph
,
Output
from
pytensor.graph.utils
import
MissingInputError
from
pytensor.graph.utils
import
MissingInputError
from
pytensor.printing
import
debugprint
from
pytensor.printing
import
debugprint
from
tests.graph.utils
import
(
from
tests.graph.utils
import
(
...
@@ -78,8 +78,13 @@ class TestFunctionGraph:
...
@@ -78,8 +78,13 @@ class TestFunctionGraph:
assert
fg
.
variables
==
{
var1
,
var2
,
var3
,
var4
}
assert
fg
.
variables
==
{
var1
,
var2
,
var3
,
var4
}
assert
fg
.
get_clients
(
var1
)
==
[(
var3
.
owner
,
0
)]
assert
fg
.
get_clients
(
var1
)
==
[(
var3
.
owner
,
0
)]
assert
fg
.
get_clients
(
var2
)
==
[(
var4
.
owner
,
1
)]
assert
fg
.
get_clients
(
var2
)
==
[(
var4
.
owner
,
1
)]
assert
fg
.
get_clients
(
var3
)
==
[(
"output"
,
0
),
(
var4
.
owner
,
0
)]
var3_clients
=
fg
.
get_clients
(
var3
)
assert
fg
.
get_clients
(
var4
)
==
[(
"output"
,
1
)]
assert
len
(
var3_clients
)
==
2
assert
var3_clients
[
0
][
0
]
.
op
==
Output
(
0
)
assert
var3_clients
[
1
]
==
(
var4
.
owner
,
0
)
var4_clients
=
fg
.
get_clients
(
var4
)
assert
len
(
var4_clients
)
==
1
assert
var4_clients
[
0
][
0
]
.
op
==
Output
(
1
)
varC
=
MyConstant
(
"varC"
)
varC
=
MyConstant
(
"varC"
)
var5
=
op1
(
var1
,
varC
)
var5
=
op1
(
var1
,
varC
)
...
@@ -208,8 +213,11 @@ class TestFunctionGraph:
...
@@ -208,8 +213,11 @@ class TestFunctionGraph:
fg
=
FunctionGraph
([
var1
,
var2
],
[
var3
,
var5
],
clone
=
False
)
fg
=
FunctionGraph
([
var1
,
var2
],
[
var3
,
var5
],
clone
=
False
)
var6
=
MyVariable2
(
"var6"
)
var6
=
MyVariable2
(
"var6"
)
[
out_client
]
=
[
cl
for
cl
,
_
in
fg
.
clients
[
fg
.
outputs
[
0
]]
if
isinstance
(
cl
.
op
,
Output
)
]
with
pytest
.
raises
(
TypeError
):
with
pytest
.
raises
(
TypeError
):
fg
.
change_node_input
(
"output"
,
1
,
var6
)
fg
.
change_node_input
(
out_client
,
0
,
var6
)
with
pytest
.
raises
(
TypeError
):
with
pytest
.
raises
(
TypeError
):
fg
.
change_node_input
(
var5
.
owner
,
1
,
var6
)
fg
.
change_node_input
(
var5
.
owner
,
1
,
var6
)
...
@@ -358,12 +366,13 @@ class TestFunctionGraph:
...
@@ -358,12 +366,13 @@ class TestFunctionGraph:
# TODO: What if the index value is greater than 1? It will throw an
# TODO: What if the index value is greater than 1? It will throw an
# `IndexError`, but that doesn't sound like anything we'd want.
# `IndexError`, but that doesn't sound like anything we'd want.
out_node
=
Output
(
idx
=
1
)
.
make_node
(
var4
)
with
pytest
.
raises
(
Exception
,
match
=
"Inconsistent clients list.*"
):
with
pytest
.
raises
(
Exception
,
match
=
"Inconsistent clients list.*"
):
fg
.
add_client
(
var4
,
(
"output"
,
1
))
fg
.
add_client
(
var4
,
(
out_node
,
0
))
fg
.
check_integrity
()
fg
.
check_integrity
()
fg
.
remove_client
(
var4
,
(
"output"
,
1
))
fg
.
remove_client
(
var4
,
(
out_node
,
0
))
with
pytest
.
raises
(
TypeError
,
match
=
"The first entry of.*"
):
with
pytest
.
raises
(
TypeError
,
match
=
"The first entry of.*"
):
fg
.
add_client
(
var4
,
(
None
,
0
))
fg
.
add_client
(
var4
,
(
None
,
0
))
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论