Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
1a3af4b2
提交
1a3af4b2
authored
11月 21, 2024
作者:
Ricardo Vieira
提交者:
Ricardo Vieira
11月 29, 2024
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Reduce overhead of Function call
上级
a0c64b5f
隐藏空白字符变更
内嵌
并排
正在显示
1 个修改的文件
包含
77 行增加
和
75 行删除
+77
-75
types.py
pytensor/compile/function/types.py
+77
-75
没有找到文件。
pytensor/compile/function/types.py
浏览文件 @
1a3af4b2
...
...
@@ -393,6 +393,8 @@ class Function:
assert
len
(
self
.
input_storage
)
==
len
(
self
.
maker
.
fgraph
.
inputs
)
assert
len
(
self
.
output_storage
)
==
len
(
self
.
maker
.
fgraph
.
outputs
)
self
.
has_defaults
=
any
(
refeed
for
_
,
refeed
,
_
in
self
.
defaults
)
# Group indexes of inputs that are potentially aliased to each other
# Note: Historically, we only worried about aliasing inputs if they belonged to the same type,
# even though there could be two distinct types that use the same kinds of underlying objects.
...
...
@@ -540,14 +542,40 @@ class Function:
self
.
_value
=
ValueAttribute
()
self
.
_container
=
ContainerAttribute
()
# TODO: Get rid of all this `expanded_inputs` nonsense
assert
len
(
self
.
maker
.
expanded_inputs
)
==
len
(
self
.
input_storage
)
update_storage
=
[
container
for
inp
,
container
in
zip
(
self
.
maker
.
expanded_inputs
,
input_storage
,
strict
=
True
)
if
inp
.
update
is
not
None
]
# Updates are the last inner outputs that are not returned by Function.__call__
self
.
n_returned_outputs
=
len
(
self
.
output_storage
)
-
len
(
update_storage
)
# Function.__call__ is responsible for updating the inputs, unless the vm promises to do it itself
self
.
update_input_storage
:
tuple
[
int
,
Container
]
=
()
if
getattr
(
vm
,
"need_update_inputs"
,
True
):
self
.
update_input_storage
=
tuple
(
zip
(
range
(
self
.
n_returned_outputs
,
len
(
output_storage
)),
update_storage
,
strict
=
True
,
)
)
# This is used only when `vm.need_update_inputs` is `False`, because
# we're using one of the VM objects and it is putting updates back into
# the input containers all by itself.
self
.
n_returned_outputs
=
len
(
self
.
output_storage
)
-
sum
(
inp
.
update
is
not
None
for
inp
in
self
.
maker
.
expanded_inputs
# In every function call we place inputs in the input_storage, and the vm places outputs in the output_storage
# After the call, we want to erase (some of) these references, to allow Python to GC them if unused
# Required input containers are the non-default inputs, must always be provided again, so we GC them
self
.
clear_input_storage_data
=
tuple
(
container
.
storage
for
container
in
input_storage
if
container
.
required
)
# This is only done when `vm.allow_gc` is True, which can change at runtime.
self
.
clear_output_storage_data
=
tuple
(
container
.
storage
for
container
,
variable
in
zip
(
self
.
output_storage
,
self
.
maker
.
fgraph
.
outputs
,
strict
=
True
)
if
variable
.
owner
is
not
None
# Not a constant output
)
for
node
in
self
.
maker
.
fgraph
.
apply_nodes
:
...
...
@@ -747,7 +775,7 @@ class Function:
elif
isinstance
(
profile
,
str
):
profile
=
pytensor
.
compile
.
profiling
.
ProfileStats
(
message
=
profile
)
f_cpy
=
maker
.
__class__
(
f_cpy
=
type
(
maker
)
(
inputs
=
ins
,
outputs
=
outs
,
fgraph
=
fg_cpy
,
...
...
@@ -765,6 +793,8 @@ class Function:
# check that.
accept_inplace
=
True
,
no_fgraph_prep
=
True
,
output_keys
=
maker
.
output_keys
,
name
=
name
,
)
.
create
(
input_storage
,
storage_map
=
new_storage_map
)
for
in_ori
,
in_cpy
,
ori
,
cpy
in
zip
(
...
...
@@ -797,8 +827,6 @@ class Function:
f_cpy
.
trust_input
=
self
.
trust_input
f_cpy
.
unpack_single
=
self
.
unpack_single
f_cpy
.
name
=
name
f_cpy
.
maker
.
fgraph
.
name
=
name
return
f_cpy
def
_restore_defaults
(
self
):
...
...
@@ -808,7 +836,7 @@ class Function:
value
=
value
.
storage
[
0
]
self
[
i
]
=
value
def
__call__
(
self
,
*
args
,
**
kwargs
):
def
__call__
(
self
,
*
args
,
output_subset
=
None
,
**
kwargs
):
"""
Evaluates value of a function on given arguments.
...
...
@@ -836,20 +864,21 @@ class Function:
List of outputs on indices/keys from ``output_subset`` or all of them,
if ``output_subset`` is not passed.
"""
trust_input
=
self
.
trust_input
input_storage
=
self
.
input_storage
vm
=
self
.
vm
profile
=
self
.
profile
if
profile
:
t0
=
time
.
perf_counter
()
output_subset
=
kwargs
.
pop
(
"output_subset"
,
None
)
if
output_subset
is
not
None
:
warnings
.
warn
(
"output_subset is deprecated."
,
FutureWarning
)
if
self
.
output_keys
is
not
None
:
output_subset
=
[
self
.
output_keys
.
index
(
key
)
for
key
in
output_subset
]
# Reinitialize each container's 'provided' counter
if
self
.
trust_input
:
if
trust_input
:
for
arg_container
,
arg
in
zip
(
input_storage
,
args
,
strict
=
False
):
arg_container
.
storage
[
0
]
=
arg
else
:
...
...
@@ -908,7 +937,7 @@ class Function:
for
k
,
arg
in
kwargs
.
items
():
self
[
k
]
=
arg
if
not
self
.
trust_input
:
if
not
trust_input
:
# Collect aliased inputs among the storage space
for
potential_group
in
self
.
_potential_aliased_input_groups
:
args_share_memory
:
list
[
list
[
int
]]
=
[]
...
...
@@ -960,11 +989,7 @@ class Function:
if
profile
:
t0_fn
=
time
.
perf_counter
()
try
:
outputs
=
(
self
.
vm
()
if
output_subset
is
None
else
self
.
vm
(
output_subset
=
output_subset
)
)
outputs
=
vm
()
if
output_subset
is
None
else
vm
(
output_subset
=
output_subset
)
except
Exception
:
self
.
_restore_defaults
()
if
hasattr
(
self
.
vm
,
"position_of_error"
):
...
...
@@ -991,39 +1016,23 @@ class Function:
# Retrieve the values that were computed
if
outputs
is
None
:
outputs
=
[
x
.
data
for
x
in
self
.
output_storage
]
# Remove internal references to required inputs.
# These cannot be re-used anyway.
for
arg_container
in
input_storage
:
if
arg_container
.
required
:
arg_container
.
storage
[
0
]
=
None
# if we are allowing garbage collection, remove the
# output reference from the internal storage cells
if
getattr
(
self
.
vm
,
"allow_gc"
,
False
):
# strict=False because we are in a hot loop
for
o_container
,
o_variable
in
zip
(
self
.
output_storage
,
self
.
maker
.
fgraph
.
outputs
,
strict
=
False
):
if
o_variable
.
owner
is
not
None
:
# this node is the variable of computation
# WARNING: This circumvents the 'readonly' attribute in x
o_container
.
storage
[
0
]
=
None
if
getattr
(
self
.
vm
,
"need_update_inputs"
,
True
):
# Update the inputs that have an update function
# strict=False because we are in a hot loop
for
input
,
storage
in
reversed
(
list
(
zip
(
self
.
maker
.
expanded_inputs
,
input_storage
,
strict
=
False
))
):
if
input
.
update
is
not
None
:
storage
.
data
=
outputs
.
pop
()
else
:
outputs
=
outputs
[:
self
.
n_returned_outputs
]
outputs
=
[
x
.
storage
[
0
]
for
x
in
self
.
output_storage
]
# Set updates and filter them out from the returned outputs
for
i
,
input_storage
in
self
.
update_input_storage
:
input_storage
.
storage
[
0
]
=
outputs
[
i
]
outputs
=
outputs
[:
self
.
n_returned_outputs
]
# Remove input and output values from storage data
for
storage_data
in
self
.
clear_input_storage_data
:
storage_data
[
0
]
=
None
if
getattr
(
vm
,
"allow_gc"
,
False
):
for
storage_data
in
self
.
clear_output_storage_data
:
storage_data
[
0
]
=
None
# Put default values back in the storage
self
.
_restore_defaults
()
if
self
.
has_defaults
:
self
.
_restore_defaults
()
if
profile
:
dt_call
=
time
.
perf_counter
()
-
t0
...
...
@@ -1031,33 +1040,29 @@ class Function:
self
.
maker
.
mode
.
call_time
+=
dt_call
profile
.
fct_callcount
+=
1
profile
.
fct_call_time
+=
dt_call
if
hasattr
(
self
.
vm
,
"update_profile"
):
self
.
vm
.
update_profile
(
profile
)
if
hasattr
(
vm
,
"update_profile"
):
vm
.
update_profile
(
profile
)
if
profile
.
ignore_first_call
:
profile
.
reset
()
profile
.
ignore_first_call
=
False
if
self
.
return_none
:
return
None
elif
self
.
unpack_single
and
len
(
outputs
)
==
1
and
output_subset
is
None
:
return
outputs
[
0
]
else
:
if
self
.
output_keys
is
not
None
:
assert
len
(
self
.
output_keys
)
==
len
(
outputs
)
if
output_subset
is
None
:
# strict=False because we are in a hot loop
return
dict
(
zip
(
self
.
output_keys
,
outputs
,
strict
=
False
))
else
:
return
{
self
.
output_keys
[
index
]:
outputs
[
index
]
for
index
in
output_subset
}
if
output_subset
is
not
None
:
outputs
=
[
outputs
[
i
]
for
i
in
output_subset
]
if
output_subset
is
None
:
return
outputs
if
self
.
output_keys
is
None
:
if
self
.
unpack_single
:
[
out
]
=
outputs
return
out
else
:
return
[
outputs
[
i
]
for
i
in
output_subset
]
return
outputs
else
:
output_keys
=
self
.
output_keys
if
output_subset
is
not
None
:
output_keys
=
[
output_keys
[
i
]
for
i
in
output_subset
]
return
dict
(
zip
(
output_keys
,
outputs
,
strict
=
True
))
value
=
property
(
lambda
self
:
self
.
_value
,
...
...
@@ -1077,9 +1082,10 @@ class Function:
# 1.no allow_gc return False
# 2.has allow_gc, if allow_gc is False, return True
if
not
getattr
(
self
.
vm
,
"allow_gc"
,
True
):
for
key
in
self
.
vm
.
storage_map
:
if
not
isinstance
(
key
,
Constant
):
self
.
vm
.
storage_map
[
key
][
0
]
=
None
storage_map
=
self
.
vm
.
storage_map
for
key
,
value
in
storage_map
.
items
():
if
key
.
owner
is
not
None
:
# Not a constant
value
[
0
]
=
None
for
node
in
self
.
nodes_with_inner_function
:
if
hasattr
(
node
.
fn
,
"free"
):
...
...
@@ -1091,10 +1097,6 @@ class Function:
"""
return
[
i
.
variable
for
i
in
self
.
maker
.
inputs
if
i
.
implicit
]
def
sync_shared
(
self
):
# NOTE: sync was needed on old gpu backend
pass
def
dprint
(
self
,
**
kwargs
):
"""Debug print itself
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论