Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
9ba6d99f
提交
9ba6d99f
authored
5月 30, 2024
作者:
Ricardo Vieira
提交者:
Ricardo Vieira
7月 08, 2024
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Replace str "output" by a dummy Op in the clients of the FunctionGraph
上级
7f623fef
隐藏空白字符变更
内嵌
并排
正在显示
18 个修改的文件
包含
172 行增加
和
180 行删除
+172
-180
debugmode.py
pytensor/compile/debugmode.py
+8
-9
types.py
pytensor/compile/function/types.py
+15
-18
profiling.py
pytensor/compile/profiling.py
+4
-7
destroyhandler.py
pytensor/graph/destroyhandler.py
+2
-3
fg.py
pytensor/graph/fg.py
+88
-89
basic.py
pytensor/graph/rewriting/basic.py
+3
-5
utils.py
pytensor/graph/rewriting/utils.py
+4
-6
basic.py
pytensor/link/c/basic.py
+1
-3
vm.py
pytensor/link/vm.py
+1
-2
printing.py
pytensor/printing.py
+9
-4
rewriting.py
pytensor/scan/rewriting.py
+4
-4
basic.py
pytensor/tensor/basic.py
+2
-2
basic.py
pytensor/tensor/random/rewriting/basic.py
+5
-8
elemwise.py
pytensor/tensor/rewriting/elemwise.py
+3
-4
linalg.py
pytensor/tensor/rewriting/linalg.py
+0
-2
math.py
pytensor/tensor/rewriting/math.py
+2
-5
shape.py
pytensor/tensor/rewriting/shape.py
+6
-3
test_fg.py
tests/graph/test_fg.py
+15
-6
没有找到文件。
pytensor/compile/debugmode.py
浏览文件 @
9ba6d99f
...
@@ -30,6 +30,7 @@ from pytensor.configdefaults import config
...
@@ -30,6 +30,7 @@ from pytensor.configdefaults import config
from
pytensor.graph.basic
import
Variable
,
io_toposort
from
pytensor.graph.basic
import
Variable
,
io_toposort
from
pytensor.graph.destroyhandler
import
DestroyHandler
from
pytensor.graph.destroyhandler
import
DestroyHandler
from
pytensor.graph.features
import
AlreadyThere
,
BadOptimization
from
pytensor.graph.features
import
AlreadyThere
,
BadOptimization
from
pytensor.graph.fg
import
Output
from
pytensor.graph.op
import
HasInnerGraph
,
Op
from
pytensor.graph.op
import
HasInnerGraph
,
Op
from
pytensor.graph.utils
import
InconsistencyError
,
MethodNotDefined
from
pytensor.graph.utils
import
InconsistencyError
,
MethodNotDefined
from
pytensor.link.basic
import
Container
,
LocalLinker
from
pytensor.link.basic
import
Container
,
LocalLinker
...
@@ -628,7 +629,9 @@ def _is_used_in_graph(fgraph, var):
...
@@ -628,7 +629,9 @@ def _is_used_in_graph(fgraph, var):
True if `var` is used by another node in the graph.
True if `var` is used by another node in the graph.
"""
"""
return
not
(
fgraph
.
clients
[
var
]
==
[(
"output"
,
1
)]
or
fgraph
.
clients
[
var
]
==
[])
return
any
(
client
for
client
,
_
in
fgraph
.
clients
[
var
]
if
not
isinstance
(
client
.
op
,
Output
)
)
def
_check_strides_match
(
a
,
b
,
warn_err
,
op
):
def
_check_strides_match
(
a
,
b
,
warn_err
,
op
):
...
@@ -977,7 +980,7 @@ def _check_preallocated_output(
...
@@ -977,7 +980,7 @@ def _check_preallocated_output(
# disable memory checks in that mode, since they were already run.
# disable memory checks in that mode, since they were already run.
try
:
try
:
changed_inner_mode
=
False
changed_inner_mode
=
False
if
isinstance
(
getattr
(
node
,
"op"
,
None
)
,
HasInnerGraph
):
if
isinstance
(
node
.
op
,
HasInnerGraph
):
fn
=
node
.
op
.
fn
fn
=
node
.
op
.
fn
if
not
(
fn
and
hasattr
(
fn
,
"maker"
)
and
hasattr
(
fn
.
maker
,
"mode"
)):
if
not
(
fn
and
hasattr
(
fn
,
"maker"
)
and
hasattr
(
fn
.
maker
,
"mode"
)):
_logger
.
warning
(
f
"Expected pytensor function not found in {node.op}.fn"
)
_logger
.
warning
(
f
"Expected pytensor function not found in {node.op}.fn"
)
...
@@ -1132,18 +1135,14 @@ class _FunctionGraphEvent:
...
@@ -1132,18 +1135,14 @@ class _FunctionGraphEvent:
def
__init__
(
self
,
kind
,
node
,
idx
=
None
,
reason
=
None
):
def
__init__
(
self
,
kind
,
node
,
idx
=
None
,
reason
=
None
):
self
.
kind
=
kind
self
.
kind
=
kind
if
node
==
"output"
:
self
.
node
=
node
self
.
node
=
"output"
self
.
op
=
node
.
op
self
.
op
=
"output"
else
:
self
.
node
=
node
self
.
op
=
node
.
op
self
.
idx
=
idx
self
.
idx
=
idx
self
.
reason
=
str
(
reason
)
self
.
reason
=
str
(
reason
)
def
__str__
(
self
):
def
__str__
(
self
):
if
self
.
kind
==
"change"
:
if
self
.
kind
==
"change"
:
if
self
.
op
!=
"output"
:
if
not
isinstance
(
self
.
op
,
Output
)
:
msg
=
str
(
len
(
self
.
node
.
inputs
))
msg
=
str
(
len
(
self
.
node
.
inputs
))
else
:
else
:
msg
=
""
msg
=
""
...
...
pytensor/compile/function/types.py
浏览文件 @
9ba6d99f
...
@@ -78,8 +78,6 @@ def view_tree_set(fgraph, v, treeset):
...
@@ -78,8 +78,6 @@ def view_tree_set(fgraph, v, treeset):
"""
"""
treeset
.
add
(
v
)
treeset
.
add
(
v
)
for
cl
,
v_input_pos_to_cl
in
fgraph
.
clients
[
v
]:
for
cl
,
v_input_pos_to_cl
in
fgraph
.
clients
[
v
]:
if
cl
==
"output"
:
continue
vmap
=
cl
.
op
.
view_map
vmap
=
cl
.
op
.
view_map
dmap
=
cl
.
op
.
destroy_map
dmap
=
cl
.
op
.
destroy_map
for
opos
,
iposlist
in
chain
(
vmap
.
items
(),
dmap
.
items
()):
for
opos
,
iposlist
in
chain
(
vmap
.
items
(),
dmap
.
items
()):
...
@@ -1202,8 +1200,11 @@ def insert_deepcopy(fgraph, wrapped_inputs, wrapped_outputs):
...
@@ -1202,8 +1200,11 @@ def insert_deepcopy(fgraph, wrapped_inputs, wrapped_outputs):
has_destroyers_attr
=
hasattr
(
fgraph
,
"has_destroyers"
)
has_destroyers_attr
=
hasattr
(
fgraph
,
"has_destroyers"
)
for
i
in
range
(
len
(
fgraph
.
outputs
)):
for
i
in
range
(
len
(
fgraph
.
outputs
)):
original_out
=
fgraph
.
outputs
[
i
]
output_client
=
fgraph
.
get_output_client
(
i
)
views_of_output_i
=
set
()
views_of_output_i
=
set
()
view_tree_set
(
fgraph
,
alias_root
(
fgraph
.
outputs
[
i
]
),
views_of_output_i
)
view_tree_set
(
fgraph
,
alias_root
(
original_out
),
views_of_output_i
)
copied
=
False
copied
=
False
# do not allow outputs to be aliased
# do not allow outputs to be aliased
for
j
in
range
(
i
+
1
,
len
(
fgraph
.
outputs
)):
for
j
in
range
(
i
+
1
,
len
(
fgraph
.
outputs
)):
...
@@ -1212,16 +1213,16 @@ def insert_deepcopy(fgraph, wrapped_inputs, wrapped_outputs):
...
@@ -1212,16 +1213,16 @@ def insert_deepcopy(fgraph, wrapped_inputs, wrapped_outputs):
if
fgraph
.
outputs
[
j
]
in
views_of_output_i
:
if
fgraph
.
outputs
[
j
]
in
views_of_output_i
:
if
wrapped_outputs
[
i
]
.
borrow
and
wrapped_outputs
[
j
]
.
borrow
:
if
wrapped_outputs
[
i
]
.
borrow
and
wrapped_outputs
[
j
]
.
borrow
:
fgraph
.
change_node_input
(
fgraph
.
change_node_input
(
"output"
,
i
,
view_op
(
fgraph
.
outputs
[
i
]
),
reason
=
reason
*
output_client
,
view_op
(
original_out
),
reason
=
reason
)
)
else
:
else
:
fgraph
.
change_node_input
(
fgraph
.
change_node_input
(
"output"
,
i
,
deep_copy_op
(
fgraph
.
outputs
[
i
]
),
reason
=
reason
*
output_client
,
deep_copy_op
(
original_out
),
reason
=
reason
)
)
copied
=
True
copied
=
True
break
break
if
not
copied
:
if
not
copied
:
# no-break
for
input_j
in
all_graph_inputs
:
for
input_j
in
all_graph_inputs
:
# do not allow outputs to be aliased to an inputs (j), unless
# do not allow outputs to be aliased to an inputs (j), unless
# a) that j'th input has been 'destroyed' by
# a) that j'th input has been 'destroyed' by
...
@@ -1239,33 +1240,29 @@ def insert_deepcopy(fgraph, wrapped_inputs, wrapped_outputs):
...
@@ -1239,33 +1240,29 @@ def insert_deepcopy(fgraph, wrapped_inputs, wrapped_outputs):
j
=
fgraph
.
inputs
.
index
(
input_j
)
j
=
fgraph
.
inputs
.
index
(
input_j
)
if
wrapped_outputs
[
i
]
.
borrow
and
wrapped_inputs
[
j
]
.
borrow
:
if
wrapped_outputs
[
i
]
.
borrow
and
wrapped_inputs
[
j
]
.
borrow
:
fgraph
.
change_node_input
(
fgraph
.
change_node_input
(
"output"
,
*
output_client
,
i
,
view_op
(
original_out
),
view_op
(
fgraph
.
outputs
[
i
]),
reason
=
reason
,
reason
=
reason
,
)
)
break
break
else
:
else
:
fgraph
.
change_node_input
(
fgraph
.
change_node_input
(
"output"
,
*
output_client
,
i
,
deep_copy_op
(
original_out
),
deep_copy_op
(
fgraph
.
outputs
[
i
]),
reason
=
reason
,
reason
=
reason
,
)
)
break
break
elif
wrapped_outputs
[
i
]
.
borrow
:
elif
wrapped_outputs
[
i
]
.
borrow
:
fgraph
.
change_node_input
(
fgraph
.
change_node_input
(
"output"
,
*
output_client
,
i
,
view_op
(
original_out
),
view_op
(
fgraph
.
outputs
[
i
]),
reason
=
reason
,
reason
=
reason
,
)
)
break
break
else
:
else
:
fgraph
.
change_node_input
(
fgraph
.
change_node_input
(
"output"
,
*
output_client
,
i
,
deep_copy_op
(
original_out
),
deep_copy_op
(
fgraph
.
outputs
[
i
]),
reason
=
reason
,
reason
=
reason
,
)
)
break
break
...
...
pytensor/compile/profiling.py
浏览文件 @
9ba6d99f
...
@@ -16,20 +16,17 @@ import sys
...
@@ -16,20 +16,17 @@ import sys
import
time
import
time
from
collections
import
Counter
,
defaultdict
from
collections
import
Counter
,
defaultdict
from
contextlib
import
contextmanager
from
contextlib
import
contextmanager
from
typing
import
TYPE_CHECKING
,
Any
from
typing
import
Any
import
numpy
as
np
import
numpy
as
np
import
pytensor
import
pytensor
from
pytensor.configdefaults
import
config
from
pytensor.configdefaults
import
config
from
pytensor.graph.basic
import
Apply
,
Constant
,
Variable
from
pytensor.graph.basic
import
Apply
,
Constant
,
Variable
from
pytensor.graph.fg
import
FunctionGraph
,
Output
from
pytensor.link.utils
import
get_destroy_dependencies
from
pytensor.link.utils
import
get_destroy_dependencies
if
TYPE_CHECKING
:
from
pytensor.graph.fg
import
FunctionGraph
@contextmanager
@contextmanager
def
extended_open
(
filename
,
mode
=
"r"
):
def
extended_open
(
filename
,
mode
=
"r"
):
if
filename
==
"<stdout>"
:
if
filename
==
"<stdout>"
:
...
@@ -1038,7 +1035,7 @@ class ProfileStats:
...
@@ -1038,7 +1035,7 @@ class ProfileStats:
executable_nodes
=
set
()
executable_nodes
=
set
()
for
var
in
fgraph
.
inputs
:
for
var
in
fgraph
.
inputs
:
for
c
,
_
in
fgraph
.
clients
[
var
]:
for
c
,
_
in
fgraph
.
clients
[
var
]:
if
c
!=
"output"
:
if
not
isinstance
(
c
.
op
,
Output
)
:
deps
=
c
.
inputs
+
destroy_dependencies
[
c
]
deps
=
c
.
inputs
+
destroy_dependencies
[
c
]
if
all
(
compute_map
[
v
][
0
]
for
v
in
deps
):
if
all
(
compute_map
[
v
][
0
]
for
v
in
deps
):
executable_nodes
.
add
(
c
)
executable_nodes
.
add
(
c
)
...
@@ -1166,7 +1163,7 @@ class ProfileStats:
...
@@ -1166,7 +1163,7 @@ class ProfileStats:
for
var
in
node
.
outputs
:
for
var
in
node
.
outputs
:
for
c
,
_
in
fgraph
.
clients
[
var
]:
for
c
,
_
in
fgraph
.
clients
[
var
]:
if
c
!=
"output"
:
if
not
isinstance
(
c
.
op
,
Output
)
:
deps
=
c
.
inputs
+
destroy_dependencies
[
c
]
deps
=
c
.
inputs
+
destroy_dependencies
[
c
]
if
all
(
compute_map
[
v
][
0
]
for
v
in
deps
):
if
all
(
compute_map
[
v
][
0
]
for
v
in
deps
):
new_exec_nodes
.
add
(
c
)
new_exec_nodes
.
add
(
c
)
...
...
pytensor/graph/destroyhandler.py
浏览文件 @
9ba6d99f
...
@@ -11,6 +11,7 @@ import pytensor
...
@@ -11,6 +11,7 @@ import pytensor
from
pytensor.configdefaults
import
config
from
pytensor.configdefaults
import
config
from
pytensor.graph.basic
import
Constant
from
pytensor.graph.basic
import
Constant
from
pytensor.graph.features
import
AlreadyThere
,
Bookkeeper
from
pytensor.graph.features
import
AlreadyThere
,
Bookkeeper
from
pytensor.graph.fg
import
Output
from
pytensor.graph.utils
import
InconsistencyError
from
pytensor.graph.utils
import
InconsistencyError
from
pytensor.misc.ordered_set
import
OrderedSet
from
pytensor.misc.ordered_set
import
OrderedSet
...
@@ -401,8 +402,6 @@ class DestroyHandler(Bookkeeper):
...
@@ -401,8 +402,6 @@ class DestroyHandler(Bookkeeper):
def
recursive_destroys_finder
(
protected_var
):
def
recursive_destroys_finder
(
protected_var
):
# protected_var is the idx'th input of app.
# protected_var is the idx'th input of app.
for
app
,
idx
in
fgraph
.
clients
[
protected_var
]:
for
app
,
idx
in
fgraph
.
clients
[
protected_var
]:
if
app
==
"output"
:
continue
destroy_maps
=
app
.
op
.
destroy_map
.
values
()
destroy_maps
=
app
.
op
.
destroy_map
.
values
()
# If True means that the apply node, destroys the protected_var.
# If True means that the apply node, destroys the protected_var.
if
idx
in
[
dmap
for
sublist
in
destroy_maps
for
dmap
in
sublist
]:
if
idx
in
[
dmap
for
sublist
in
destroy_maps
for
dmap
in
sublist
]:
...
@@ -578,7 +577,7 @@ class DestroyHandler(Bookkeeper):
...
@@ -578,7 +577,7 @@ class DestroyHandler(Bookkeeper):
app.inputs[i] changed from old_r to new_r.
app.inputs[i] changed from old_r to new_r.
"""
"""
if
app
==
"output"
:
if
isinstance
(
app
.
op
,
Output
)
:
# app == 'output' is special key that means FunctionGraph is redefining which nodes are being
# app == 'output' is special key that means FunctionGraph is redefining which nodes are being
# considered 'outputs' of the graph.
# considered 'outputs' of the graph.
pass
pass
...
...
pytensor/graph/fg.py
浏览文件 @
9ba6d99f
...
@@ -3,7 +3,7 @@
...
@@ -3,7 +3,7 @@
import
time
import
time
from
collections
import
defaultdict
from
collections
import
defaultdict
from
collections.abc
import
Iterable
,
Sequence
from
collections.abc
import
Iterable
,
Sequence
from
typing
import
TYPE_CHECKING
,
Any
,
Literal
,
Union
,
cast
from
typing
import
Any
,
Union
,
cast
import
pytensor
import
pytensor
from
pytensor.configdefaults
import
config
from
pytensor.configdefaults
import
config
...
@@ -19,15 +19,30 @@ from pytensor.graph.basic import (
...
@@ -19,15 +19,30 @@ from pytensor.graph.basic import (
)
)
from
pytensor.graph.basic
import
as_string
as
graph_as_string
from
pytensor.graph.basic
import
as_string
as
graph_as_string
from
pytensor.graph.features
import
AlreadyThere
,
Feature
,
ReplaceValidate
from
pytensor.graph.features
import
AlreadyThere
,
Feature
,
ReplaceValidate
from
pytensor.graph.op
import
Op
from
pytensor.graph.utils
import
MetaObject
,
MissingInputError
,
TestValueError
from
pytensor.graph.utils
import
MetaObject
,
MissingInputError
,
TestValueError
from
pytensor.misc.ordered_set
import
OrderedSet
from
pytensor.misc.ordered_set
import
OrderedSet
if
TYPE_CHECKING
:
ClientType
=
tuple
[
Apply
,
int
]
from
pytensor.graph.op
import
Op
ApplyOrOutput
=
Apply
|
Literal
[
"output"
]
ClientType
=
tuple
[
ApplyOrOutput
,
int
]
class
Output
(
Op
):
"""A dummy `Op` that represents an output variable in a `FunctionGraph`."""
__props__
=
(
"idx"
,)
def
__init__
(
self
,
idx
):
self
.
idx
=
idx
def
make_node
(
self
,
inp
):
return
Apply
(
self
,
[
inp
],
[])
def
perform
(
self
,
node
,
inputs
,
outputs
):
raise
RuntimeError
(
"Output Ops should never be evaluated"
)
def
__str__
(
self
):
return
f
"output[{self.idx}]"
class
FunctionGraph
(
MetaObject
):
class
FunctionGraph
(
MetaObject
):
...
@@ -157,7 +172,7 @@ class FunctionGraph(MetaObject):
...
@@ -157,7 +172,7 @@ class FunctionGraph(MetaObject):
"""Add a new variable as an output to this `FunctionGraph`."""
"""Add a new variable as an output to this `FunctionGraph`."""
self
.
outputs
.
append
(
var
)
self
.
outputs
.
append
(
var
)
self
.
import_var
(
var
,
reason
=
reason
,
import_missing
=
import_missing
)
self
.
import_var
(
var
,
reason
=
reason
,
import_missing
=
import_missing
)
self
.
clients
[
var
]
.
append
((
"output"
,
len
(
self
.
outputs
)
-
1
))
self
.
clients
[
var
]
.
append
((
Output
(
len
(
self
.
outputs
)
-
1
)
.
make_node
(
var
),
0
))
def
add_input
(
self
,
var
:
Variable
,
check
:
bool
=
True
)
->
None
:
def
add_input
(
self
,
var
:
Variable
,
check
:
bool
=
True
)
->
None
:
"""Add a new variable as an input to this `FunctionGraph`.
"""Add a new variable as an input to this `FunctionGraph`.
...
@@ -198,10 +213,8 @@ class FunctionGraph(MetaObject):
...
@@ -198,10 +213,8 @@ class FunctionGraph(MetaObject):
A ``(node, i)`` pair such that ``node.inputs[i]`` is `var`.
A ``(node, i)`` pair such that ``node.inputs[i]`` is `var`.
"""
"""
if
not
isinstance
(
new_client
[
0
],
Apply
)
and
new_client
[
0
]
!=
"output"
:
if
not
isinstance
(
new_client
[
0
],
Apply
):
raise
TypeError
(
raise
TypeError
(
"The first entry of `new_client` must be an `Apply` node"
)
'The first entry of `new_client` must be an `Apply` node or the string `"output"`'
)
self
.
clients
[
var
]
.
append
(
new_client
)
self
.
clients
[
var
]
.
append
(
new_client
)
def
remove_client
(
def
remove_client
(
...
@@ -278,6 +291,16 @@ class FunctionGraph(MetaObject):
...
@@ -278,6 +291,16 @@ class FunctionGraph(MetaObject):
if
remove_if_empty
:
if
remove_if_empty
:
del
clients
[
var
]
del
clients
[
var
]
def
get_output_client
(
self
,
i
:
int
)
->
ClientType
:
"""Get the dummy Output Op client to output i.
Raises lookup error if not found
"""
for
client
in
self
.
clients
[
self
.
outputs
[
i
]]:
if
isinstance
(
client
[
0
]
.
op
,
Output
)
and
client
[
0
]
.
op
.
idx
==
i
:
return
client
raise
LookupError
def
import_var
(
def
import_var
(
self
,
var
:
Variable
,
reason
:
str
|
None
=
None
,
import_missing
:
bool
=
False
self
,
var
:
Variable
,
reason
:
str
|
None
=
None
,
import_missing
:
bool
=
False
)
->
None
:
)
->
None
:
...
@@ -382,7 +405,7 @@ class FunctionGraph(MetaObject):
...
@@ -382,7 +405,7 @@ class FunctionGraph(MetaObject):
def
change_node_input
(
def
change_node_input
(
self
,
self
,
node
:
Apply
OrOutput
,
node
:
Apply
,
i
:
int
,
i
:
int
,
new_var
:
Variable
,
new_var
:
Variable
,
reason
:
str
|
None
=
None
,
reason
:
str
|
None
=
None
,
...
@@ -401,9 +424,7 @@ class FunctionGraph(MetaObject):
...
@@ -401,9 +424,7 @@ class FunctionGraph(MetaObject):
Parameters
Parameters
----------
----------
node
node
The node for which an input is to be changed. If the value is
The node for which an input is to be changed.
the string ``"output"`` then the ``self.outputs`` will be used
instead of ``node.inputs``.
i
i
The index in `node.inputs` that we want to change.
The index in `node.inputs` that we want to change.
new_var
new_var
...
@@ -417,27 +438,21 @@ class FunctionGraph(MetaObject):
...
@@ -417,27 +438,21 @@ class FunctionGraph(MetaObject):
narrowed and would otherwise fail this check.
narrowed and would otherwise fail this check.
"""
"""
# TODO: ERROR HANDLING FOR LISTENERS (should it complete the change or revert it?)
# TODO: ERROR HANDLING FOR LISTENERS (should it complete the change or revert it?)
if
node
==
"output"
:
r
=
node
.
inputs
[
i
]
r
=
self
.
outputs
[
i
]
if
check
and
not
r
.
type
.
is_super
(
new_var
.
type
):
raise
TypeError
(
f
"The type of the replacement ({new_var.type}) must be "
f
"compatible with the type of the original Variable ({r.type})."
)
self
.
outputs
[
i
]
=
new_var
else
:
assert
isinstance
(
node
,
Apply
)
r
=
node
.
inputs
[
i
]
if
check
and
not
r
.
type
.
is_super
(
new_var
.
type
):
raise
TypeError
(
f
"The type of the replacement ({new_var.type}) must be "
f
"compatible with the type of the original Variable ({r.type})."
)
node
.
inputs
[
i
]
=
new_var
if
r
is
new_var
:
if
r
is
new_var
:
return
return
if
check
and
not
r
.
type
.
is_super
(
new_var
.
type
):
raise
TypeError
(
f
"The type of the replacement ({new_var.type}) must be "
f
"compatible with the type of the original Variable ({r.type})."
)
node
.
inputs
[
i
]
=
new_var
if
isinstance
(
node
.
op
,
Output
):
self
.
outputs
[
node
.
op
.
idx
]
=
new_var
self
.
import_var
(
new_var
,
reason
=
reason
,
import_missing
=
import_missing
)
self
.
import_var
(
new_var
,
reason
=
reason
,
import_missing
=
import_missing
)
self
.
add_client
(
new_var
,
(
node
,
i
))
self
.
add_client
(
new_var
,
(
node
,
i
))
self
.
remove_client
(
r
,
(
node
,
i
),
reason
=
reason
)
self
.
remove_client
(
r
,
(
node
,
i
),
reason
=
reason
)
...
@@ -518,33 +533,6 @@ class FunctionGraph(MetaObject):
...
@@ -518,33 +533,6 @@ class FunctionGraph(MetaObject):
for
var
,
new_var
in
pairs
:
for
var
,
new_var
in
pairs
:
self
.
replace
(
var
,
new_var
,
**
kwargs
)
self
.
replace
(
var
,
new_var
,
**
kwargs
)
def
_remove_output
(
self
,
idx
:
int
):
"""Remove the output at index `idx` and update the indices in the clients entries.
`FunctionGraph.clients` contains entries like ``("output", i)`` under
each output variable in `FunctionGraph.outputs`. The ``i`` values
correspond to each output's location within the `FunctionGraph.outputs`
list, so, when an output is removed from the graph, all these entries
need to be updated. This method performs those updates.
TODO: We could track these entries in a new instance attribute and make
them lists, then each could be updated in-place very easily. This
seems fine, because the `FunctionGraph.clients` ``dict`` and list in
which they're contained are already being updated in-place.
"""
old_idx_mappings
=
tuple
((
out
,
i
)
for
i
,
out
in
enumerate
(
self
.
outputs
))
self
.
outputs
.
pop
(
idx
)
new_idx
=
0
for
out
,
old_idx
in
old_idx_mappings
:
if
old_idx
==
idx
:
continue
out_clients
=
self
.
clients
[
out
]
arrow
:
ClientType
=
(
"output"
,
old_idx
)
arrow_idx
=
out_clients
.
index
(
arrow
)
out_clients
[
arrow_idx
]
=
(
"output"
,
new_idx
)
new_idx
+=
1
def
remove_node
(
self
,
node
:
Apply
,
reason
:
str
|
None
=
None
):
def
remove_node
(
self
,
node
:
Apply
,
reason
:
str
|
None
=
None
):
"""Remove an `Apply` node from the `FunctionGraph`.
"""Remove an `Apply` node from the `FunctionGraph`.
...
@@ -571,8 +559,8 @@ class FunctionGraph(MetaObject):
...
@@ -571,8 +559,8 @@ class FunctionGraph(MetaObject):
while
out_clients
:
while
out_clients
:
out_client
,
out_idx
=
out_clients
.
pop
()
out_client
,
out_idx
=
out_clients
.
pop
()
if
out_client
==
"output"
:
if
isinstance
(
out_client
.
op
,
Output
)
:
self
.
_remove_output
(
out_idx
)
self
.
remove_output
(
out_client
.
op
.
idx
,
remove_client
=
False
)
# TODO: We could short-circuit all of the graph walking and
# TODO: We could short-circuit all of the graph walking and
# clear everything at once when all the outputs are gone.
# clear everything at once when all the outputs are gone.
...
@@ -588,7 +576,6 @@ class FunctionGraph(MetaObject):
...
@@ -588,7 +576,6 @@ class FunctionGraph(MetaObject):
#
#
# self.execute_callbacks("on_prune", node, reason)
# self.execute_callbacks("on_prune", node, reason)
else
:
else
:
assert
isinstance
(
out_client
,
Apply
)
self
.
remove_node
(
out_client
,
reason
=
reason
)
self
.
remove_node
(
out_client
,
reason
=
reason
)
clients
.
pop
(
out
,
None
)
clients
.
pop
(
out
,
None
)
...
@@ -630,32 +617,46 @@ class FunctionGraph(MetaObject):
...
@@ -630,32 +617,46 @@ class FunctionGraph(MetaObject):
self
.
execute_callbacks
(
"on_prune"
,
node
,
reason
)
self
.
execute_callbacks
(
"on_prune"
,
node
,
reason
)
def
remove_input
(
self
,
input_idx
:
int
,
reason
:
str
|
None
=
None
):
def
remove_input
(
self
,
input_idx
:
int
,
reason
:
str
|
None
=
None
):
"""Remove the input at index `input_idx`."""
"""Remove the input at index `input_idx`.
Any node that depended on such input will also be removed.
"""
var
=
self
.
inputs
.
pop
(
input_idx
)
var
=
self
.
inputs
.
pop
(
input_idx
)
for
client
,
idx
in
list
(
self
.
clients
[
var
]):
for
client
,
idx
in
list
(
self
.
clients
[
var
]):
if
client
==
"output"
:
self
.
remove_node
(
client
,
reason
=
reason
)
out_var
=
self
.
outputs
[
idx
]
out_node
=
out_var
.
owner
def
remove_output
(
if
out_node
is
None
:
self
,
output_idx
:
int
,
reason
:
str
|
None
=
None
,
remove_client
:
bool
=
True
assert
out_var
in
self
.
inputs
):
self
.
outputs
.
pop
(
idx
)
"""Remove the output at index `output_idx` and update the indices in the clients entries.
continue
client_node
=
out_node
`FunctionGraph.clients` contains entries like ``(output(i)(var), 0)`` under
else
:
each output variable in `FunctionGraph.outputs`. The ``i`` values
assert
isinstance
(
client
,
Apply
)
correspond to each output's location within the `FunctionGraph.outputs`
client_node
=
client
list, so, when an output is removed from the graph, all these entries
need to be updated. This method performs those updates.
self
.
remove_node
(
client_node
,
reason
=
reason
)
"""
outputs
=
self
.
outputs
def
remove_output
(
self
,
output_idx
:
int
,
reason
:
str
|
None
=
None
):
# We have to update all the output indexes to the right of the removed index
"""Remove the output at index `input_idx`."""
for
old_idx
,
out
in
enumerate
(
outputs
[
output_idx
+
1
:],
output_idx
+
1
):
old_client
=
self
.
get_output_client
(
old_idx
)
out_clients
=
self
.
clients
[
out
]
out_clients
[
out_clients
.
index
(
old_client
,
0
)]
=
(
Output
(
old_idx
-
1
)
.
make_node
(
out
),
0
,
)
var
=
self
.
outputs
[
output_idx
]
# Remove the Output Op client from the clients list
self
.
_remove_output
(
output_idx
)
# This is false when called from `remove_node` which removes the clients ahead of time
self
.
remove_client
(
if
remove_client
:
var
,
(
"output"
,
output_idx
),
reason
=
reason
,
remove_if_empty
=
True
output_client
=
self
.
get_output_client
(
output_idx
)
)
self
.
remove_client
(
outputs
[
output_idx
],
output_client
,
reason
=
reason
,
remove_if_empty
=
True
)
outputs
.
pop
(
output_idx
)
def
attach_feature
(
self
,
feature
:
Feature
)
->
None
:
def
attach_feature
(
self
,
feature
:
Feature
)
->
None
:
"""Add a ``graph.features.Feature`` to this function graph and trigger its ``on_attach`` callback."""
"""Add a ``graph.features.Feature`` to this function graph and trigger its ``on_attach`` callback."""
...
@@ -832,19 +833,17 @@ class FunctionGraph(MetaObject):
...
@@ -832,19 +833,17 @@ class FunctionGraph(MetaObject):
):
):
raise
Exception
(
f
"Undeclared input: {variable}"
)
raise
Exception
(
f
"Undeclared input: {variable}"
)
for
cl_node
,
i
in
clients
[
variable
]:
for
cl_node
,
i
in
clients
[
variable
]:
if
cl_node
==
"output"
:
if
isinstance
(
cl_node
.
op
,
Output
):
if
self
.
outputs
[
i
]
is
not
variable
:
out_idx
=
cl_node
.
op
.
idx
if
self
.
outputs
[
out_idx
]
is
not
variable
:
raise
Exception
(
raise
Exception
(
f
"Inconsistent clients list: {variable}, {self.outputs[
i
]}"
f
"Inconsistent clients list: {variable}, {self.outputs[
out_idx
]}"
)
)
continue
elif
cl_node
not
in
nodes
:
assert
isinstance
(
cl_node
,
Apply
)
if
cl_node
not
in
nodes
:
raise
Exception
(
raise
Exception
(
f
"Client not in FunctionGraph: {variable}, {(cl_node, i)}"
f
"Client not in FunctionGraph: {variable}, {(cl_node, i)}"
)
)
if
cl_node
.
inputs
[
i
]
is
not
variable
:
if
cl_node
.
inputs
[
i
]
is
not
variable
:
raise
Exception
(
raise
Exception
(
f
"Inconsistent clients list: {variable}, {cl_node.inputs[i]}"
f
"Inconsistent clients list: {variable}, {cl_node.inputs[i]}"
...
...
pytensor/graph/rewriting/basic.py
浏览文件 @
9ba6d99f
...
@@ -30,7 +30,7 @@ from pytensor.graph.basic import (
...
@@ -30,7 +30,7 @@ from pytensor.graph.basic import (
vars_between
,
vars_between
,
)
)
from
pytensor.graph.features
import
AlreadyThere
,
Feature
,
NodeFinder
from
pytensor.graph.features
import
AlreadyThere
,
Feature
,
NodeFinder
from
pytensor.graph.fg
import
FunctionGraph
from
pytensor.graph.fg
import
FunctionGraph
,
Output
from
pytensor.graph.op
import
Op
from
pytensor.graph.op
import
Op
from
pytensor.graph.utils
import
AssocList
,
InconsistencyError
from
pytensor.graph.utils
import
AssocList
,
InconsistencyError
from
pytensor.misc.ordered_set
import
OrderedSet
from
pytensor.misc.ordered_set
import
OrderedSet
...
@@ -738,7 +738,7 @@ class MergeOptimizer(GraphRewriter):
...
@@ -738,7 +738,7 @@ class MergeOptimizer(GraphRewriter):
if
any
(
if
any
(
i
in
flatten
(
c
.
op
.
destroy_map
.
values
())
i
in
flatten
(
c
.
op
.
destroy_map
.
values
())
for
c
,
i
in
clients
for
c
,
i
in
clients
if
c
!=
"output"
and
c
.
op
.
destroy_map
if
c
.
op
.
destroy_map
):
):
continue
continue
...
@@ -1612,8 +1612,6 @@ class PatternNodeRewriter(NodeRewriter):
...
@@ -1612,8 +1612,6 @@ class PatternNodeRewriter(NodeRewriter):
if
get_nodes
and
self
.
get_nodes
is
not
None
:
if
get_nodes
and
self
.
get_nodes
is
not
None
:
for
real_node
in
self
.
get_nodes
(
fgraph
,
node
):
for
real_node
in
self
.
get_nodes
(
fgraph
,
node
):
if
real_node
==
"output"
:
continue
ret
=
self
.
transform
(
fgraph
,
real_node
,
get_nodes
=
False
)
ret
=
self
.
transform
(
fgraph
,
real_node
,
get_nodes
=
False
)
if
ret
is
not
False
and
ret
is
not
None
:
if
ret
is
not
False
and
ret
is
not
None
:
return
dict
(
zip
(
real_node
.
outputs
,
ret
))
return
dict
(
zip
(
real_node
.
outputs
,
ret
))
...
@@ -2399,7 +2397,7 @@ class EquilibriumGraphRewriter(NodeProcessingGraphRewriter):
...
@@ -2399,7 +2397,7 @@ class EquilibriumGraphRewriter(NodeProcessingGraphRewriter):
if
self
.
tracks_on_change_inputs
:
if
self
.
tracks_on_change_inputs
:
def
chin_
(
node
,
i
,
r
,
new_r
,
reason
):
def
chin_
(
node
,
i
,
r
,
new_r
,
reason
):
if
node
is
not
current_node
and
not
isinstance
(
node
,
str
):
if
node
is
not
current_node
and
not
isinstance
(
node
.
op
,
Output
):
q
.
append
(
node
)
q
.
append
(
node
)
chin
=
chin_
chin
=
chin_
...
...
pytensor/graph/rewriting/utils.py
浏览文件 @
9ba6d99f
import
copy
import
copy
from
collections.abc
import
Generator
,
Sequence
from
collections.abc
import
Generator
,
Sequence
from
typing
import
TYPE_CHECKING
,
Optional
,
cast
from
typing
import
TYPE_CHECKING
,
Optional
import
pytensor
import
pytensor
from
pytensor.graph.basic
import
(
from
pytensor.graph.basic
import
(
...
@@ -10,7 +10,7 @@ from pytensor.graph.basic import (
...
@@ -10,7 +10,7 @@ from pytensor.graph.basic import (
graph_inputs
,
graph_inputs
,
vars_between
,
vars_between
,
)
)
from
pytensor.graph.fg
import
FunctionGraph
from
pytensor.graph.fg
import
FunctionGraph
,
Output
from
pytensor.graph.rewriting.db
import
RewriteDatabaseQuery
from
pytensor.graph.rewriting.db
import
RewriteDatabaseQuery
...
@@ -230,11 +230,9 @@ def get_clients_at_depth(
...
@@ -230,11 +230,9 @@ def get_clients_at_depth(
for
var
in
node
.
outputs
:
for
var
in
node
.
outputs
:
if
depth
>
0
:
if
depth
>
0
:
for
out_node
,
_
in
fgraph
.
clients
[
var
]:
for
out_node
,
_
in
fgraph
.
clients
[
var
]:
if
out_node
==
"output"
:
if
isinstance
(
out_node
.
op
,
Output
)
:
continue
continue
yield
from
get_clients_at_depth
(
yield
from
get_clients_at_depth
(
fgraph
,
out_node
,
depth
-
1
)
fgraph
,
cast
(
Apply
,
out_node
),
depth
-
1
)
else
:
else
:
assert
var
.
owner
is
not
None
assert
var
.
owner
is
not
None
yield
var
.
owner
yield
var
.
owner
pytensor/link/c/basic.py
浏览文件 @
9ba6d99f
...
@@ -354,9 +354,7 @@ def get_c_declare(fgraph, r, name, sub):
...
@@ -354,9 +354,7 @@ def get_c_declare(fgraph, r, name, sub):
# it means they need `r`'s dtype to be declared, so
# it means they need `r`'s dtype to be declared, so
# we have to pass `check_input=True` to `c_declare`.
# we have to pass `check_input=True` to `c_declare`.
if
any
(
if
any
(
getattr
(
c
.
op
,
"check_input"
,
config
.
check_input
)
getattr
(
c
.
op
,
"check_input"
,
config
.
check_input
)
for
(
c
,
_
)
in
fgraph
.
clients
[
r
]
for
(
c
,
_
)
in
fgraph
.
clients
[
r
]
if
not
isinstance
(
c
,
str
)
)
or
(
r
.
owner
and
getattr
(
r
.
owner
.
op
,
"check_input"
,
config
.
check_input
)):
)
or
(
r
.
owner
and
getattr
(
r
.
owner
.
op
,
"check_input"
,
config
.
check_input
)):
c_declare
=
r
.
type
.
c_declare
(
name
,
sub
,
True
)
c_declare
=
r
.
type
.
c_declare
(
name
,
sub
,
True
)
else
:
else
:
...
...
pytensor/link/vm.py
浏览文件 @
9ba6d99f
...
@@ -954,8 +954,7 @@ class VMLinker(LocalLinker):
...
@@ -954,8 +954,7 @@ class VMLinker(LocalLinker):
if
k
.
owner
and
self
.
fgraph
.
clients
[
k
]:
if
k
.
owner
and
self
.
fgraph
.
clients
[
k
]:
ls
=
[]
ls
=
[]
for
cl
in
self
.
fgraph
.
clients
[
k
]:
for
cl
in
self
.
fgraph
.
clients
[
k
]:
if
cl
[
0
]
!=
"output"
:
ls
+=
cl
[
0
]
.
outputs
ls
+=
cl
[
0
]
.
outputs
dependencies
[
k
]
+=
ls
dependencies
[
k
]
+=
ls
return
dependencies
return
dependencies
...
...
pytensor/printing.py
浏览文件 @
9ba6d99f
...
@@ -437,10 +437,15 @@ N.B.:
...
@@ -437,10 +437,15 @@ N.B.:
for
out
in
inner_outputs
:
for
out
in
inner_outputs
:
if
(
if
(
isinstance
(
getattr
(
out
.
owner
,
"op"
,
None
),
HasInnerGraph
)
out
.
owner
is
not
None
or
hasattr
(
getattr
(
out
.
owner
,
"op"
,
None
),
"scalar_op"
)
and
(
and
isinstance
(
out
.
owner
.
op
.
scalar_op
,
HasInnerGraph
)
isinstance
(
out
.
owner
.
op
,
HasInnerGraph
)
)
and
out
not
in
inner_graph_vars
:
or
isinstance
(
getattr
(
out
.
owner
.
op
,
"scalar_op"
,
None
),
HasInnerGraph
)
)
and
out
not
in
inner_graph_vars
):
inner_graph_vars
.
append
(
out
)
inner_graph_vars
.
append
(
out
)
_debugprint
(
_debugprint
(
...
...
pytensor/scan/rewriting.py
浏览文件 @
9ba6d99f
...
@@ -27,7 +27,7 @@ from pytensor.graph.basic import (
...
@@ -27,7 +27,7 @@ from pytensor.graph.basic import (
)
)
from
pytensor.graph.destroyhandler
import
DestroyHandler
from
pytensor.graph.destroyhandler
import
DestroyHandler
from
pytensor.graph.features
import
ReplaceValidate
from
pytensor.graph.features
import
ReplaceValidate
from
pytensor.graph.fg
import
FunctionGraph
from
pytensor.graph.fg
import
FunctionGraph
,
Output
from
pytensor.graph.op
import
compute_test_value
from
pytensor.graph.op
import
compute_test_value
from
pytensor.graph.replace
import
clone_replace
from
pytensor.graph.replace
import
clone_replace
from
pytensor.graph.rewriting.basic
import
(
from
pytensor.graph.rewriting.basic
import
(
...
@@ -1303,7 +1303,7 @@ def scan_save_mem(fgraph, node):
...
@@ -1303,7 +1303,7 @@ def scan_save_mem(fgraph, node):
for
cl
,
_
in
fgraph
.
clients
[
out
]:
for
cl
,
_
in
fgraph
.
clients
[
out
]:
# 2.1 outputs of the function
# 2.1 outputs of the function
# => output needs all its intermediate values
# => output needs all its intermediate values
if
isinstance
(
cl
,
str
):
if
isinstance
(
cl
.
op
,
Output
):
# if the node is actually an output, then
# if the node is actually an output, then
# we need to store the entire thing
# we need to store the entire thing
global_nsteps
=
None
global_nsteps
=
None
...
@@ -1412,7 +1412,7 @@ def scan_save_mem(fgraph, node):
...
@@ -1412,7 +1412,7 @@ def scan_save_mem(fgraph, node):
for
i
,
out
in
enumerate
(
node
.
outputs
[:
c_outs
]):
for
i
,
out
in
enumerate
(
node
.
outputs
[:
c_outs
]):
# look at all its clients
# look at all its clients
for
cl
,
_
in
fgraph
.
clients
[
out
]:
for
cl
,
_
in
fgraph
.
clients
[
out
]:
if
isinstance
(
cl
,
str
):
if
isinstance
(
cl
.
op
,
Output
):
store_steps
[
i
]
=
0
store_steps
[
i
]
=
0
break
break
elif
not
isinstance
(
cl
.
op
,
Subtensor
):
elif
not
isinstance
(
cl
.
op
,
Subtensor
):
...
@@ -2309,7 +2309,7 @@ def scan_push_out_dot1(fgraph, node):
...
@@ -2309,7 +2309,7 @@ def scan_push_out_dot1(fgraph, node):
and
isinstance
(
out
.
owner
.
op
.
scalar_op
,
ps
.
Add
)
and
isinstance
(
out
.
owner
.
op
.
scalar_op
,
ps
.
Add
)
and
inp
in
out
.
owner
.
inputs
and
inp
in
out
.
owner
.
inputs
and
len
(
fgraph
.
clients
[
outer_out
])
==
1
and
len
(
fgraph
.
clients
[
outer_out
])
==
1
and
not
isinstance
(
fgraph
.
clients
[
outer_out
][
0
][
0
],
str
)
and
not
isinstance
(
fgraph
.
clients
[
outer_out
][
0
][
0
],
Output
)
and
isinstance
(
fgraph
.
clients
[
outer_out
][
0
][
0
]
.
op
,
Subtensor
)
and
isinstance
(
fgraph
.
clients
[
outer_out
][
0
][
0
]
.
op
,
Subtensor
)
and
fgraph
.
clients
[
outer_out
][
0
][
0
]
.
op
.
idx_list
==
(
-
1
,)
and
fgraph
.
clients
[
outer_out
][
0
][
0
]
.
op
.
idx_list
==
(
-
1
,)
):
):
...
...
pytensor/tensor/basic.py
浏览文件 @
9ba6d99f
...
@@ -24,7 +24,7 @@ from pytensor import scalar as ps
...
@@ -24,7 +24,7 @@ from pytensor import scalar as ps
from
pytensor.gradient
import
DisconnectedType
,
grad_undefined
from
pytensor.gradient
import
DisconnectedType
,
grad_undefined
from
pytensor.graph
import
RewriteDatabaseQuery
from
pytensor.graph
import
RewriteDatabaseQuery
from
pytensor.graph.basic
import
Apply
,
Constant
,
Variable
,
equal_computations
from
pytensor.graph.basic
import
Apply
,
Constant
,
Variable
,
equal_computations
from
pytensor.graph.fg
import
FunctionGraph
from
pytensor.graph.fg
import
FunctionGraph
,
Output
from
pytensor.graph.op
import
Op
from
pytensor.graph.op
import
Op
from
pytensor.graph.replace
import
_vectorize_node
from
pytensor.graph.replace
import
_vectorize_node
from
pytensor.graph.rewriting.db
import
EquilibriumDB
from
pytensor.graph.rewriting.db
import
EquilibriumDB
...
@@ -1681,7 +1681,7 @@ class Alloc(COp):
...
@@ -1681,7 +1681,7 @@ class Alloc(COp):
return
False
return
False
for
client
,
idx
in
clients
:
for
client
,
idx
in
clients
:
if
client
==
"output"
:
if
isinstance
(
client
.
op
,
Output
)
:
# If the output is a constant, it will have to be deepcopied
# If the output is a constant, it will have to be deepcopied
# each time the function is called. So we do not fold.
# each time the function is called. So we do not fold.
return
False
return
False
...
...
pytensor/tensor/random/rewriting/basic.py
浏览文件 @
9ba6d99f
...
@@ -31,15 +31,12 @@ def is_rv_used_in_graph(base_rv, node, fgraph):
...
@@ -31,15 +31,12 @@ def is_rv_used_in_graph(base_rv, node, fgraph):
TODO: We should apply all the shape rewrites before these rewrites, since
TODO: We should apply all the shape rewrites before these rewrites, since
that would properly remove the unnecessary dependencies on `base_rv` (when
that would properly remove the unnecessary dependencies on `base_rv` (when
possible).
possible).
"""
"""
return
any
(
def
_node_check
(
n
,
i
):
n
if
n
==
"output"
:
for
n
,
i
in
fgraph
.
clients
.
get
(
base_rv
,
())
n
=
fgraph
.
outputs
[
i
]
.
owner
if
not
(
n
is
node
or
isinstance
(
n
.
op
,
Shape
|
Shape_i
))
return
n
==
node
or
isinstance
(
n
.
op
,
Shape
|
Shape_i
)
)
return
not
all
(
_node_check
(
n
,
i
)
for
n
,
i
in
fgraph
.
clients
.
get
(
base_rv
,
()))
@node_rewriter
([
RandomVariable
],
inplace
=
True
)
@node_rewriter
([
RandomVariable
],
inplace
=
True
)
...
...
pytensor/tensor/rewriting/elemwise.py
浏览文件 @
9ba6d99f
...
@@ -14,7 +14,7 @@ from pytensor.configdefaults import config
...
@@ -14,7 +14,7 @@ from pytensor.configdefaults import config
from
pytensor.graph
import
FunctionGraph
from
pytensor.graph
import
FunctionGraph
from
pytensor.graph.basic
import
Apply
,
Constant
,
Variable
,
ancestors
,
io_toposort
from
pytensor.graph.basic
import
Apply
,
Constant
,
Variable
,
ancestors
,
io_toposort
from
pytensor.graph.features
import
ReplaceValidate
from
pytensor.graph.features
import
ReplaceValidate
from
pytensor.graph.fg
import
ApplyOr
Output
from
pytensor.graph.fg
import
Output
from
pytensor.graph.rewriting.basic
import
(
from
pytensor.graph.rewriting.basic
import
(
EquilibriumGraphRewriter
,
EquilibriumGraphRewriter
,
GraphRewriter
,
GraphRewriter
,
...
@@ -688,7 +688,7 @@ class FusionOptimizer(GraphRewriter):
...
@@ -688,7 +688,7 @@ class FusionOptimizer(GraphRewriter):
"""
"""
FUSEABLE_MAPPING
=
defaultdict
[
Variable
,
list
[
Apply
]]
FUSEABLE_MAPPING
=
defaultdict
[
Variable
,
list
[
Apply
]]
UNFUSEABLE_MAPPING
=
defaultdict
[
Variable
,
set
[
Apply
OrOutput
]]
UNFUSEABLE_MAPPING
=
defaultdict
[
Variable
,
set
[
Apply
]]
def
initialize_fuseable_mappings
(
def
initialize_fuseable_mappings
(
*
,
fg
:
FunctionGraph
*
,
fg
:
FunctionGraph
...
@@ -727,7 +727,6 @@ class FusionOptimizer(GraphRewriter):
...
@@ -727,7 +727,6 @@ class FusionOptimizer(GraphRewriter):
for
client
,
_
in
clients
:
for
client
,
_
in
clients
:
if
(
if
(
out_maybe_fuseable
out_maybe_fuseable
and
not
isinstance
(
client
,
str
)
# "output"
and
isinstance
(
client
.
op
,
Elemwise
)
and
isinstance
(
client
.
op
,
Elemwise
)
# and not isinstance(client.op.scalar_op, ps.Composite)
# and not isinstance(client.op.scalar_op, ps.Composite)
and
len
(
client
.
outputs
)
==
1
and
len
(
client
.
outputs
)
==
1
...
@@ -841,7 +840,7 @@ class FusionOptimizer(GraphRewriter):
...
@@ -841,7 +840,7 @@ class FusionOptimizer(GraphRewriter):
implied_unfuseable_clients
=
{
implied_unfuseable_clients
=
{
c
c
for
client
in
unfuseable_clients_clone
.
get
(
next_out
,
())
for
client
in
unfuseable_clients_clone
.
get
(
next_out
,
())
if
not
isinstance
(
client
,
str
)
# "output"
if
not
isinstance
(
client
.
op
,
Output
)
for
c
in
client
.
outputs
for
c
in
client
.
outputs
}
}
...
...
pytensor/tensor/rewriting/linalg.py
浏览文件 @
9ba6d99f
...
@@ -299,8 +299,6 @@ def local_det_chol(fgraph, node):
...
@@ -299,8 +299,6 @@ def local_det_chol(fgraph, node):
"""
"""
(
x
,)
=
node
.
inputs
(
x
,)
=
node
.
inputs
for
cl
,
xpos
in
fgraph
.
clients
[
x
]:
for
cl
,
xpos
in
fgraph
.
clients
[
x
]:
if
cl
==
"output"
:
continue
if
isinstance
(
cl
.
op
,
Blockwise
)
and
isinstance
(
cl
.
op
.
core_op
,
Cholesky
):
if
isinstance
(
cl
.
op
,
Blockwise
)
and
isinstance
(
cl
.
op
.
core_op
,
Cholesky
):
L
=
cl
.
outputs
[
0
]
L
=
cl
.
outputs
[
0
]
return
[
prod
(
diagonal
(
L
,
axis1
=-
2
,
axis2
=-
1
)
**
2
,
axis
=-
1
)]
return
[
prod
(
diagonal
(
L
,
axis1
=-
2
,
axis2
=-
1
)
**
2
,
axis
=-
1
)]
...
...
pytensor/tensor/rewriting/math.py
浏览文件 @
9ba6d99f
...
@@ -1126,14 +1126,11 @@ class AlgebraicCanonizer(NodeRewriter):
...
@@ -1126,14 +1126,11 @@ class AlgebraicCanonizer(NodeRewriter):
# this canonized graph... if so, we do nothing and wait for
# this canonized graph... if so, we do nothing and wait for
# them to be transformed.
# them to be transformed.
for
c
,
c_idx
in
out_clients
:
for
c
,
c_idx
in
out_clients
:
if
c
==
"output"
:
continue
while
(
while
(
isinstance
(
getattr
(
c
,
"op"
,
None
),
DimShuffle
)
isinstance
(
c
.
op
,
DimShuffle
)
and
len
(
fgraph
.
clients
[
c
.
outputs
[
0
]])
<=
1
and
len
(
fgraph
.
clients
[
c
.
outputs
[
0
]])
<=
1
):
):
c
=
fgraph
.
clients
[
c
.
outputs
[
0
]][
0
][
0
]
c
=
fgraph
.
clients
[
c
.
outputs
[
0
]][
0
][
0
]
if
getattr
(
c
,
"op"
,
""
)
in
[
self
.
main
,
self
.
inverse
,
self
.
reciprocal
]:
if
c
.
op
in
[
self
.
main
,
self
.
inverse
,
self
.
reciprocal
]:
return
False
return
False
# Here we make the canonical version of the graph around this node
# Here we make the canonical version of the graph around this node
...
...
pytensor/tensor/rewriting/shape.py
浏览文件 @
9ba6d99f
...
@@ -401,7 +401,7 @@ class ShapeFeature(Feature):
...
@@ -401,7 +401,7 @@ class ShapeFeature(Feature):
merged_shape
.
append
(
other_shape
[
i
])
merged_shape
.
append
(
other_shape
[
i
])
elif
(
elif
(
ps
.
owner
ps
.
owner
and
isinstance
(
getattr
(
ps
.
owner
,
"op"
,
None
)
,
Shape_i
)
and
isinstance
(
ps
.
owner
.
op
,
Shape_i
)
and
ps
.
owner
.
op
.
i
==
i
and
ps
.
owner
.
op
.
i
==
i
and
ps
.
owner
.
inputs
[
0
]
in
(
r
,
other_r
)
and
ps
.
owner
.
inputs
[
0
]
in
(
r
,
other_r
)
):
):
...
@@ -602,7 +602,7 @@ class ShapeFeature(Feature):
...
@@ -602,7 +602,7 @@ class ShapeFeature(Feature):
# r is *scheduled*.
# r is *scheduled*.
# At that point, node is no longer a client of r, but of new_r
# At that point, node is no longer a client of r, but of new_r
for
shpnode
,
idx
in
fgraph
.
clients
[
r
]
+
[(
node
,
i
)]:
for
shpnode
,
idx
in
fgraph
.
clients
[
r
]
+
[(
node
,
i
)]:
if
isinstance
(
getattr
(
shpnode
,
"op"
,
None
)
,
Shape_i
):
if
isinstance
(
shpnode
.
op
,
Shape_i
):
idx
=
shpnode
.
op
.
i
idx
=
shpnode
.
op
.
i
repl
=
self
.
shape_of
[
new_r
][
idx
]
repl
=
self
.
shape_of
[
new_r
][
idx
]
if
repl
.
owner
is
shpnode
:
if
repl
.
owner
is
shpnode
:
...
@@ -1057,7 +1057,10 @@ def local_Shape_of_SpecifyShape(fgraph, node):
...
@@ -1057,7 +1057,10 @@ def local_Shape_of_SpecifyShape(fgraph, node):
specified_shape
=
node
.
inputs
[
0
]
specified_shape
=
node
.
inputs
[
0
]
if
not
isinstance
(
getattr
(
specified_shape
.
owner
,
"op"
,
None
),
SpecifyShape
):
if
not
(
specified_shape
.
owner
is
not
None
and
isinstance
(
specified_shape
.
owner
.
op
,
SpecifyShape
)
):
return
False
return
False
x
,
*
shape
=
specified_shape
.
owner
.
inputs
x
,
*
shape
=
specified_shape
.
owner
.
inputs
...
...
tests/graph/test_fg.py
浏览文件 @
9ba6d99f
...
@@ -6,7 +6,7 @@ import pytest
...
@@ -6,7 +6,7 @@ import pytest
from
pytensor.configdefaults
import
config
from
pytensor.configdefaults
import
config
from
pytensor.graph.basic
import
NominalVariable
from
pytensor.graph.basic
import
NominalVariable
from
pytensor.graph.fg
import
FunctionGraph
from
pytensor.graph.fg
import
FunctionGraph
,
Output
from
pytensor.graph.utils
import
MissingInputError
from
pytensor.graph.utils
import
MissingInputError
from
pytensor.printing
import
debugprint
from
pytensor.printing
import
debugprint
from
tests.graph.utils
import
(
from
tests.graph.utils
import
(
...
@@ -78,8 +78,13 @@ class TestFunctionGraph:
...
@@ -78,8 +78,13 @@ class TestFunctionGraph:
assert
fg
.
variables
==
{
var1
,
var2
,
var3
,
var4
}
assert
fg
.
variables
==
{
var1
,
var2
,
var3
,
var4
}
assert
fg
.
get_clients
(
var1
)
==
[(
var3
.
owner
,
0
)]
assert
fg
.
get_clients
(
var1
)
==
[(
var3
.
owner
,
0
)]
assert
fg
.
get_clients
(
var2
)
==
[(
var4
.
owner
,
1
)]
assert
fg
.
get_clients
(
var2
)
==
[(
var4
.
owner
,
1
)]
assert
fg
.
get_clients
(
var3
)
==
[(
"output"
,
0
),
(
var4
.
owner
,
0
)]
var3_clients
=
fg
.
get_clients
(
var3
)
assert
fg
.
get_clients
(
var4
)
==
[(
"output"
,
1
)]
assert
len
(
var3_clients
)
==
2
assert
var3_clients
[
0
][
0
]
.
op
==
Output
(
0
)
assert
var3_clients
[
1
]
==
(
var4
.
owner
,
0
)
var4_clients
=
fg
.
get_clients
(
var4
)
assert
len
(
var4_clients
)
==
1
assert
var4_clients
[
0
][
0
]
.
op
==
Output
(
1
)
varC
=
MyConstant
(
"varC"
)
varC
=
MyConstant
(
"varC"
)
var5
=
op1
(
var1
,
varC
)
var5
=
op1
(
var1
,
varC
)
...
@@ -208,8 +213,11 @@ class TestFunctionGraph:
...
@@ -208,8 +213,11 @@ class TestFunctionGraph:
fg
=
FunctionGraph
([
var1
,
var2
],
[
var3
,
var5
],
clone
=
False
)
fg
=
FunctionGraph
([
var1
,
var2
],
[
var3
,
var5
],
clone
=
False
)
var6
=
MyVariable2
(
"var6"
)
var6
=
MyVariable2
(
"var6"
)
[
out_client
]
=
[
cl
for
cl
,
_
in
fg
.
clients
[
fg
.
outputs
[
0
]]
if
isinstance
(
cl
.
op
,
Output
)
]
with
pytest
.
raises
(
TypeError
):
with
pytest
.
raises
(
TypeError
):
fg
.
change_node_input
(
"output"
,
1
,
var6
)
fg
.
change_node_input
(
out_client
,
0
,
var6
)
with
pytest
.
raises
(
TypeError
):
with
pytest
.
raises
(
TypeError
):
fg
.
change_node_input
(
var5
.
owner
,
1
,
var6
)
fg
.
change_node_input
(
var5
.
owner
,
1
,
var6
)
...
@@ -358,12 +366,13 @@ class TestFunctionGraph:
...
@@ -358,12 +366,13 @@ class TestFunctionGraph:
# TODO: What if the index value is greater than 1? It will throw an
# TODO: What if the index value is greater than 1? It will throw an
# `IndexError`, but that doesn't sound like anything we'd want.
# `IndexError`, but that doesn't sound like anything we'd want.
out_node
=
Output
(
idx
=
1
)
.
make_node
(
var4
)
with
pytest
.
raises
(
Exception
,
match
=
"Inconsistent clients list.*"
):
with
pytest
.
raises
(
Exception
,
match
=
"Inconsistent clients list.*"
):
fg
.
add_client
(
var4
,
(
"output"
,
1
))
fg
.
add_client
(
var4
,
(
out_node
,
0
))
fg
.
check_integrity
()
fg
.
check_integrity
()
fg
.
remove_client
(
var4
,
(
"output"
,
1
))
fg
.
remove_client
(
var4
,
(
out_node
,
0
))
with
pytest
.
raises
(
TypeError
,
match
=
"The first entry of.*"
):
with
pytest
.
raises
(
TypeError
,
match
=
"The first entry of.*"
):
fg
.
add_client
(
var4
,
(
None
,
0
))
fg
.
add_client
(
var4
,
(
None
,
0
))
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论