Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
1a3af4b2
提交
1a3af4b2
authored
11月 21, 2024
作者:
Ricardo Vieira
提交者:
Ricardo Vieira
11月 29, 2024
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Reduce overhead of Function call
上级
a0c64b5f
隐藏空白字符变更
内嵌
并排
正在显示
1 个修改的文件
包含
77 行增加
和
75 行删除
+77
-75
types.py
pytensor/compile/function/types.py
+77
-75
没有找到文件。
pytensor/compile/function/types.py
浏览文件 @
1a3af4b2
...
@@ -393,6 +393,8 @@ class Function:
...
@@ -393,6 +393,8 @@ class Function:
assert
len
(
self
.
input_storage
)
==
len
(
self
.
maker
.
fgraph
.
inputs
)
assert
len
(
self
.
input_storage
)
==
len
(
self
.
maker
.
fgraph
.
inputs
)
assert
len
(
self
.
output_storage
)
==
len
(
self
.
maker
.
fgraph
.
outputs
)
assert
len
(
self
.
output_storage
)
==
len
(
self
.
maker
.
fgraph
.
outputs
)
self
.
has_defaults
=
any
(
refeed
for
_
,
refeed
,
_
in
self
.
defaults
)
# Group indexes of inputs that are potentially aliased to each other
# Group indexes of inputs that are potentially aliased to each other
# Note: Historically, we only worried about aliasing inputs if they belonged to the same type,
# Note: Historically, we only worried about aliasing inputs if they belonged to the same type,
# even though there could be two distinct types that use the same kinds of underlying objects.
# even though there could be two distinct types that use the same kinds of underlying objects.
...
@@ -540,14 +542,40 @@ class Function:
...
@@ -540,14 +542,40 @@ class Function:
self
.
_value
=
ValueAttribute
()
self
.
_value
=
ValueAttribute
()
self
.
_container
=
ContainerAttribute
()
self
.
_container
=
ContainerAttribute
()
# TODO: Get rid of all this `expanded_inputs` nonsense
update_storage
=
[
assert
len
(
self
.
maker
.
expanded_inputs
)
==
len
(
self
.
input_storage
)
container
for
inp
,
container
in
zip
(
self
.
maker
.
expanded_inputs
,
input_storage
,
strict
=
True
)
if
inp
.
update
is
not
None
]
# Updates are the last inner outputs that are not returned by Function.__call__
self
.
n_returned_outputs
=
len
(
self
.
output_storage
)
-
len
(
update_storage
)
# Function.__call__ is responsible for updating the inputs, unless the vm promises to do it itself
self
.
update_input_storage
:
tuple
[
int
,
Container
]
=
()
if
getattr
(
vm
,
"need_update_inputs"
,
True
):
self
.
update_input_storage
=
tuple
(
zip
(
range
(
self
.
n_returned_outputs
,
len
(
output_storage
)),
update_storage
,
strict
=
True
,
)
)
# This is used only when `vm.need_update_inputs` is `False`, because
# In every function call we place inputs in the input_storage, and the vm places outputs in the output_storage
# we're using one of the VM objects and it is putting updates back into
# After the call, we want to erase (some of) these references, to allow Python to GC them if unused
# the input containers all by itself.
# Required input containers are the non-default inputs, must always be provided again, so we GC them
self
.
n_returned_outputs
=
len
(
self
.
output_storage
)
-
sum
(
self
.
clear_input_storage_data
=
tuple
(
inp
.
update
is
not
None
for
inp
in
self
.
maker
.
expanded_inputs
container
.
storage
for
container
in
input_storage
if
container
.
required
)
# This is only done when `vm.allow_gc` is True, which can change at runtime.
self
.
clear_output_storage_data
=
tuple
(
container
.
storage
for
container
,
variable
in
zip
(
self
.
output_storage
,
self
.
maker
.
fgraph
.
outputs
,
strict
=
True
)
if
variable
.
owner
is
not
None
# Not a constant output
)
)
for
node
in
self
.
maker
.
fgraph
.
apply_nodes
:
for
node
in
self
.
maker
.
fgraph
.
apply_nodes
:
...
@@ -747,7 +775,7 @@ class Function:
...
@@ -747,7 +775,7 @@ class Function:
elif
isinstance
(
profile
,
str
):
elif
isinstance
(
profile
,
str
):
profile
=
pytensor
.
compile
.
profiling
.
ProfileStats
(
message
=
profile
)
profile
=
pytensor
.
compile
.
profiling
.
ProfileStats
(
message
=
profile
)
f_cpy
=
maker
.
__class__
(
f_cpy
=
type
(
maker
)
(
inputs
=
ins
,
inputs
=
ins
,
outputs
=
outs
,
outputs
=
outs
,
fgraph
=
fg_cpy
,
fgraph
=
fg_cpy
,
...
@@ -765,6 +793,8 @@ class Function:
...
@@ -765,6 +793,8 @@ class Function:
# check that.
# check that.
accept_inplace
=
True
,
accept_inplace
=
True
,
no_fgraph_prep
=
True
,
no_fgraph_prep
=
True
,
output_keys
=
maker
.
output_keys
,
name
=
name
,
)
.
create
(
input_storage
,
storage_map
=
new_storage_map
)
)
.
create
(
input_storage
,
storage_map
=
new_storage_map
)
for
in_ori
,
in_cpy
,
ori
,
cpy
in
zip
(
for
in_ori
,
in_cpy
,
ori
,
cpy
in
zip
(
...
@@ -797,8 +827,6 @@ class Function:
...
@@ -797,8 +827,6 @@ class Function:
f_cpy
.
trust_input
=
self
.
trust_input
f_cpy
.
trust_input
=
self
.
trust_input
f_cpy
.
unpack_single
=
self
.
unpack_single
f_cpy
.
unpack_single
=
self
.
unpack_single
f_cpy
.
name
=
name
f_cpy
.
maker
.
fgraph
.
name
=
name
return
f_cpy
return
f_cpy
def
_restore_defaults
(
self
):
def
_restore_defaults
(
self
):
...
@@ -808,7 +836,7 @@ class Function:
...
@@ -808,7 +836,7 @@ class Function:
value
=
value
.
storage
[
0
]
value
=
value
.
storage
[
0
]
self
[
i
]
=
value
self
[
i
]
=
value
def
__call__
(
self
,
*
args
,
**
kwargs
):
def
__call__
(
self
,
*
args
,
output_subset
=
None
,
**
kwargs
):
"""
"""
Evaluates value of a function on given arguments.
Evaluates value of a function on given arguments.
...
@@ -836,20 +864,21 @@ class Function:
...
@@ -836,20 +864,21 @@ class Function:
List of outputs on indices/keys from ``output_subset`` or all of them,
List of outputs on indices/keys from ``output_subset`` or all of them,
if ``output_subset`` is not passed.
if ``output_subset`` is not passed.
"""
"""
trust_input
=
self
.
trust_input
input_storage
=
self
.
input_storage
input_storage
=
self
.
input_storage
vm
=
self
.
vm
profile
=
self
.
profile
profile
=
self
.
profile
if
profile
:
if
profile
:
t0
=
time
.
perf_counter
()
t0
=
time
.
perf_counter
()
output_subset
=
kwargs
.
pop
(
"output_subset"
,
None
)
if
output_subset
is
not
None
:
if
output_subset
is
not
None
:
warnings
.
warn
(
"output_subset is deprecated."
,
FutureWarning
)
warnings
.
warn
(
"output_subset is deprecated."
,
FutureWarning
)
if
self
.
output_keys
is
not
None
:
if
self
.
output_keys
is
not
None
:
output_subset
=
[
self
.
output_keys
.
index
(
key
)
for
key
in
output_subset
]
output_subset
=
[
self
.
output_keys
.
index
(
key
)
for
key
in
output_subset
]
# Reinitialize each container's 'provided' counter
# Reinitialize each container's 'provided' counter
if
self
.
trust_input
:
if
trust_input
:
for
arg_container
,
arg
in
zip
(
input_storage
,
args
,
strict
=
False
):
for
arg_container
,
arg
in
zip
(
input_storage
,
args
,
strict
=
False
):
arg_container
.
storage
[
0
]
=
arg
arg_container
.
storage
[
0
]
=
arg
else
:
else
:
...
@@ -908,7 +937,7 @@ class Function:
...
@@ -908,7 +937,7 @@ class Function:
for
k
,
arg
in
kwargs
.
items
():
for
k
,
arg
in
kwargs
.
items
():
self
[
k
]
=
arg
self
[
k
]
=
arg
if
not
self
.
trust_input
:
if
not
trust_input
:
# Collect aliased inputs among the storage space
# Collect aliased inputs among the storage space
for
potential_group
in
self
.
_potential_aliased_input_groups
:
for
potential_group
in
self
.
_potential_aliased_input_groups
:
args_share_memory
:
list
[
list
[
int
]]
=
[]
args_share_memory
:
list
[
list
[
int
]]
=
[]
...
@@ -960,11 +989,7 @@ class Function:
...
@@ -960,11 +989,7 @@ class Function:
if
profile
:
if
profile
:
t0_fn
=
time
.
perf_counter
()
t0_fn
=
time
.
perf_counter
()
try
:
try
:
outputs
=
(
outputs
=
vm
()
if
output_subset
is
None
else
vm
(
output_subset
=
output_subset
)
self
.
vm
()
if
output_subset
is
None
else
self
.
vm
(
output_subset
=
output_subset
)
)
except
Exception
:
except
Exception
:
self
.
_restore_defaults
()
self
.
_restore_defaults
()
if
hasattr
(
self
.
vm
,
"position_of_error"
):
if
hasattr
(
self
.
vm
,
"position_of_error"
):
...
@@ -991,39 +1016,23 @@ class Function:
...
@@ -991,39 +1016,23 @@ class Function:
# Retrieve the values that were computed
# Retrieve the values that were computed
if
outputs
is
None
:
if
outputs
is
None
:
outputs
=
[
x
.
data
for
x
in
self
.
output_storage
]
outputs
=
[
x
.
storage
[
0
]
for
x
in
self
.
output_storage
]
# Remove internal references to required inputs.
# Set updates and filter them out from the returned outputs
# These cannot be re-used anyway.
for
i
,
input_storage
in
self
.
update_input_storage
:
for
arg_container
in
input_storage
:
input_storage
.
storage
[
0
]
=
outputs
[
i
]
if
arg_container
.
required
:
outputs
=
outputs
[:
self
.
n_returned_outputs
]
arg_container
.
storage
[
0
]
=
None
# Remove input and output values from storage data
# if we are allowing garbage collection, remove the
for
storage_data
in
self
.
clear_input_storage_data
:
# output reference from the internal storage cells
storage_data
[
0
]
=
None
if
getattr
(
self
.
vm
,
"allow_gc"
,
False
):
if
getattr
(
vm
,
"allow_gc"
,
False
):
# strict=False because we are in a hot loop
for
storage_data
in
self
.
clear_output_storage_data
:
for
o_container
,
o_variable
in
zip
(
storage_data
[
0
]
=
None
self
.
output_storage
,
self
.
maker
.
fgraph
.
outputs
,
strict
=
False
):
if
o_variable
.
owner
is
not
None
:
# this node is the variable of computation
# WARNING: This circumvents the 'readonly' attribute in x
o_container
.
storage
[
0
]
=
None
if
getattr
(
self
.
vm
,
"need_update_inputs"
,
True
):
# Update the inputs that have an update function
# strict=False because we are in a hot loop
for
input
,
storage
in
reversed
(
list
(
zip
(
self
.
maker
.
expanded_inputs
,
input_storage
,
strict
=
False
))
):
if
input
.
update
is
not
None
:
storage
.
data
=
outputs
.
pop
()
else
:
outputs
=
outputs
[:
self
.
n_returned_outputs
]
# Put default values back in the storage
# Put default values back in the storage
self
.
_restore_defaults
()
if
self
.
has_defaults
:
self
.
_restore_defaults
()
if
profile
:
if
profile
:
dt_call
=
time
.
perf_counter
()
-
t0
dt_call
=
time
.
perf_counter
()
-
t0
...
@@ -1031,33 +1040,29 @@ class Function:
...
@@ -1031,33 +1040,29 @@ class Function:
self
.
maker
.
mode
.
call_time
+=
dt_call
self
.
maker
.
mode
.
call_time
+=
dt_call
profile
.
fct_callcount
+=
1
profile
.
fct_callcount
+=
1
profile
.
fct_call_time
+=
dt_call
profile
.
fct_call_time
+=
dt_call
if
hasattr
(
self
.
vm
,
"update_profile"
):
if
hasattr
(
vm
,
"update_profile"
):
self
.
vm
.
update_profile
(
profile
)
vm
.
update_profile
(
profile
)
if
profile
.
ignore_first_call
:
if
profile
.
ignore_first_call
:
profile
.
reset
()
profile
.
reset
()
profile
.
ignore_first_call
=
False
profile
.
ignore_first_call
=
False
if
self
.
return_none
:
if
self
.
return_none
:
return
None
return
None
elif
self
.
unpack_single
and
len
(
outputs
)
==
1
and
output_subset
is
None
:
return
outputs
[
0
]
else
:
if
self
.
output_keys
is
not
None
:
assert
len
(
self
.
output_keys
)
==
len
(
outputs
)
if
output_subset
is
None
:
if
output_subset
is
not
None
:
# strict=False because we are in a hot loop
outputs
=
[
outputs
[
i
]
for
i
in
output_subset
]
return
dict
(
zip
(
self
.
output_keys
,
outputs
,
strict
=
False
))
else
:
return
{
self
.
output_keys
[
index
]:
outputs
[
index
]
for
index
in
output_subset
}
if
output_subset
is
None
:
if
self
.
output_keys
is
None
:
return
outputs
if
self
.
unpack_single
:
[
out
]
=
outputs
return
out
else
:
else
:
return
[
outputs
[
i
]
for
i
in
output_subset
]
return
outputs
else
:
output_keys
=
self
.
output_keys
if
output_subset
is
not
None
:
output_keys
=
[
output_keys
[
i
]
for
i
in
output_subset
]
return
dict
(
zip
(
output_keys
,
outputs
,
strict
=
True
))
value
=
property
(
value
=
property
(
lambda
self
:
self
.
_value
,
lambda
self
:
self
.
_value
,
...
@@ -1077,9 +1082,10 @@ class Function:
...
@@ -1077,9 +1082,10 @@ class Function:
# 1.no allow_gc return False
# 1.no allow_gc return False
# 2.has allow_gc, if allow_gc is False, return True
# 2.has allow_gc, if allow_gc is False, return True
if
not
getattr
(
self
.
vm
,
"allow_gc"
,
True
):
if
not
getattr
(
self
.
vm
,
"allow_gc"
,
True
):
for
key
in
self
.
vm
.
storage_map
:
storage_map
=
self
.
vm
.
storage_map
if
not
isinstance
(
key
,
Constant
):
for
key
,
value
in
storage_map
.
items
():
self
.
vm
.
storage_map
[
key
][
0
]
=
None
if
key
.
owner
is
not
None
:
# Not a constant
value
[
0
]
=
None
for
node
in
self
.
nodes_with_inner_function
:
for
node
in
self
.
nodes_with_inner_function
:
if
hasattr
(
node
.
fn
,
"free"
):
if
hasattr
(
node
.
fn
,
"free"
):
...
@@ -1091,10 +1097,6 @@ class Function:
...
@@ -1091,10 +1097,6 @@ class Function:
"""
"""
return
[
i
.
variable
for
i
in
self
.
maker
.
inputs
if
i
.
implicit
]
return
[
i
.
variable
for
i
in
self
.
maker
.
inputs
if
i
.
implicit
]
def
sync_shared
(
self
):
# NOTE: sync was needed on old gpu backend
pass
def
dprint
(
self
,
**
kwargs
):
def
dprint
(
self
,
**
kwargs
):
"""Debug print itself
"""Debug print itself
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论