Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
82f6a14f
提交
82f6a14f
authored
10月 09, 2024
作者:
ricardoV94
提交者:
Ricardo Vieira
10月 29, 2024
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Cleanup Function.__call__
上级
f0a9ec25
隐藏空白字符变更
内嵌
并排
正在显示
7 个修改的文件
包含
136 行增加
和
161 行删除
+136
-161
types.py
pytensor/compile/function/types.py
+113
-125
gradient.py
pytensor/gradient.py
+0
-3
null_type.py
pytensor/graph/null_type.py
+0
-3
type.py
pytensor/graph/type.py
+1
-4
basic.py
pytensor/scalar/basic.py
+0
-7
type_other.py
pytensor/tensor/type_other.py
+0
-6
test_types.py
tests/compile/function/test_types.py
+22
-13
没有找到文件。
pytensor/compile/function/types.py
浏览文件 @
82f6a14f
...
...
@@ -326,8 +326,8 @@ class Function:
def
__init__
(
self
,
vm
:
"VM"
,
input_storage
,
output_storage
,
input_storage
:
list
[
Container
]
,
output_storage
:
list
[
Container
]
,
indices
,
outputs
,
defaults
,
...
...
@@ -372,7 +372,6 @@ class Function:
name
A string name.
"""
# TODO: Rename to `vm`
self
.
vm
=
vm
self
.
input_storage
=
input_storage
self
.
output_storage
=
output_storage
...
...
@@ -388,31 +387,49 @@ class Function:
self
.
nodes_with_inner_function
=
[]
self
.
output_keys
=
output_keys
# See if we have any mutable / borrow inputs
# TODO: this only need to be set if there is more than one input
self
.
_check_for_aliased_inputs
=
False
for
i
in
maker
.
inputs
:
# If the input is a shared variable, the memory region is
# under PyTensor control and so we don't need to check if it
# is aliased as we never do that.
if
(
isinstance
(
i
,
In
)
and
not
i
.
shared
and
(
getattr
(
i
,
"borrow"
,
False
)
or
getattr
(
i
,
"mutable"
,
False
))
assert
len
(
self
.
input_storage
)
==
len
(
self
.
maker
.
fgraph
.
inputs
)
assert
len
(
self
.
output_storage
)
==
len
(
self
.
maker
.
fgraph
.
outputs
)
# Group indexes of inputs that are potentially aliased to each other
# Note: Historically, we only worried about aliasing inputs if they belonged to the same type,
# even though there could be two distinct types that use the same kinds of underlying objects.
potential_aliased_input_groups
=
[]
for
inp
in
maker
.
inputs
:
# If the input is a shared variable, the memory region is under PyTensor control
# and can't be aliased.
if
not
(
isinstance
(
inp
,
In
)
and
inp
.
borrow
and
not
inp
.
shared
and
hasattr
(
inp
.
variable
.
type
,
"may_share_memory"
)
):
self
.
_check_for_aliased_inputs
=
True
break
continue
for
group
in
potential_aliased_input_groups
:
# If one is super of the other, that means one could be replaced by the other
if
any
(
inp
.
variable
.
type
.
is_super
(
other_inp
.
variable
.
type
)
or
other_inp
.
variable
.
type
.
is_super
(
inp
.
variable
.
type
)
for
other_inp
in
group
):
group
.
append
(
inp
)
break
else
:
# no break
# Input makes a new group
potential_aliased_input_groups
.
append
([
inp
])
# Potential aliased inputs are those that belong to the same group
self
.
_potential_aliased_input_groups
:
tuple
[
tuple
[
int
,
...
],
...
]
=
tuple
(
tuple
(
maker
.
inputs
.
index
(
inp
)
for
inp
in
group
)
for
group
in
potential_aliased_input_groups
if
len
(
group
)
>
1
)
# We will be popping stuff off this `containers` object. It is a copy.
containers
=
list
(
self
.
input_storage
)
finder
=
{}
inv_finder
=
{}
def
distribute
(
indices
,
cs
,
value
):
input
.
distribute
(
value
,
indices
,
cs
)
for
c
in
cs
:
c
.
provided
+=
1
# Store the list of names of named inputs.
named_inputs
=
[]
# Count the number of un-named inputs.
...
...
@@ -777,6 +794,13 @@ class Function:
f_cpy
.
maker
.
fgraph
.
name
=
name
return
f_cpy
def
_restore_defaults
(
self
):
for
i
,
(
required
,
refeed
,
value
)
in
enumerate
(
self
.
defaults
):
if
refeed
:
if
isinstance
(
value
,
Container
):
value
=
value
.
storage
[
0
]
self
[
i
]
=
value
def
__call__
(
self
,
*
args
,
**
kwargs
):
"""
Evaluates value of a function on given arguments.
...
...
@@ -805,16 +829,11 @@ class Function:
List of outputs on indices/keys from ``output_subset`` or all of them,
if ``output_subset`` is not passed.
"""
def
restore_defaults
():
for
i
,
(
required
,
refeed
,
value
)
in
enumerate
(
self
.
defaults
):
if
refeed
:
if
isinstance
(
value
,
Container
):
value
=
value
.
storage
[
0
]
self
[
i
]
=
value
input_storage
=
self
.
input_storage
profile
=
self
.
profile
t0
=
time
.
perf_counter
()
if
profile
:
t0
=
time
.
perf_counter
()
output_subset
=
kwargs
.
pop
(
"output_subset"
,
None
)
if
output_subset
is
not
None
and
self
.
output_keys
is
not
None
:
...
...
@@ -822,35 +841,31 @@ class Function:
# Reinitialize each container's 'provided' counter
if
self
.
trust_input
:
i
=
0
for
arg
in
args
:
s
=
self
.
input_storage
[
i
]
s
.
storage
[
0
]
=
arg
i
+=
1
for
arg_container
,
arg
in
zip
(
input_storage
,
args
,
strict
=
False
):
arg_container
.
storage
[
0
]
=
arg
else
:
for
c
in
self
.
input_storage
:
c
.
provided
=
0
for
arg_container
in
input_storage
:
arg_container
.
provided
=
0
if
len
(
args
)
+
len
(
kwargs
)
>
len
(
self
.
input_storage
):
if
len
(
args
)
+
len
(
kwargs
)
>
len
(
input_storage
):
raise
TypeError
(
"Too many parameter passed to pytensor function"
)
# Set positional arguments
i
=
0
for
arg
in
args
:
# TODO: provide a option for skipping the filter if we really
# want speed.
s
=
self
.
input_storage
[
i
]
# see this emails for a discuation about None as input
for
arg_container
,
arg
in
zip
(
input_storage
,
args
,
strict
=
False
):
# See discussion about None as input
# https://groups.google.com/group/theano-dev/browse_thread/thread/920a5e904e8a8525/4f1b311a28fc27e5
if
arg
is
None
:
s
.
storage
[
0
]
=
arg
arg_container
.
storage
[
0
]
=
arg
else
:
try
:
s
.
storage
[
0
]
=
s
.
type
.
filter
(
arg
,
strict
=
s
.
strict
,
allow_downcast
=
s
.
allow_downcast
arg_container
.
storage
[
0
]
=
arg_container
.
type
.
filter
(
arg
,
strict
=
arg_container
.
strict
,
allow_downcast
=
arg_container
.
allow_downcast
,
)
except
Exception
as
e
:
i
=
input_storage
.
index
(
arg_container
)
function_name
=
"pytensor function"
argument_name
=
"argument"
if
self
.
name
:
...
...
@@ -875,85 +890,66 @@ class Function:
+
function_name
+
f
" at index {int(i)} (0-based). {where}"
)
+
e
.
args
restore_defaults
()
self
.
_
restore_defaults
()
raise
s
.
provided
+=
1
i
+=
1
arg_container
.
provided
+=
1
# Set keyword arguments
if
kwargs
:
# for speed, skip the items for empty kwargs
for
k
,
arg
in
kwargs
.
items
():
self
[
k
]
=
arg
if
(
not
self
.
trust_input
and
# The getattr is only needed for old pickle
getattr
(
self
,
"_check_for_aliased_inputs"
,
True
)
):
if
not
self
.
trust_input
:
# Collect aliased inputs among the storage space
args_share_memory
=
[]
for
i
in
range
(
len
(
self
.
input_storage
)):
i_var
=
self
.
maker
.
inputs
[
i
]
.
variable
i_val
=
self
.
input_storage
[
i
]
.
storage
[
0
]
if
hasattr
(
i_var
.
type
,
"may_share_memory"
):
is_aliased
=
False
for
j
in
range
(
len
(
args_share_memory
)):
group_j
=
zip
(
[
self
.
maker
.
inputs
[
k
]
.
variable
for
k
in
args_share_memory
[
j
]
],
[
self
.
input_storage
[
k
]
.
storage
[
0
]
for
k
in
args_share_memory
[
j
]
],
)
for
potential_group
in
self
.
_potential_aliased_input_groups
:
args_share_memory
:
list
[
list
[
int
]]
=
[]
for
i
in
potential_group
:
i_type
=
self
.
maker
.
inputs
[
i
]
.
variable
.
type
i_val
=
input_storage
[
i
]
.
storage
[
0
]
# Check if value is aliased with any of the values in one of the groups
for
j_group
in
args_share_memory
:
if
any
(
(
var
.
type
is
i_var
.
type
and
var
.
type
.
may_share_memory
(
val
,
i_val
)
)
for
(
var
,
val
)
in
group_j
i_type
.
may_share_memory
(
input_storage
[
j
]
.
storage
[
0
],
i_val
)
for
j
in
j_group
):
is_aliased
=
True
args_share_memory
[
j
]
.
append
(
i
)
j_group
.
append
(
i
)
break
if
not
is_aliased
:
else
:
# no break
# Create a new group
args_share_memory
.
append
([
i
])
# Check for groups of more than one argument that share memory
for
group
in
args_share_memory
:
if
len
(
group
)
>
1
:
# copy all but the first
for
j
in
group
[
1
:]:
self
.
input_storage
[
j
]
.
storage
[
0
]
=
copy
.
copy
(
self
.
input_storage
[
j
]
.
storage
[
0
]
)
# Check for groups of more than one argument that share memory
for
group
in
args_share_memory
:
if
len
(
group
)
>
1
:
# copy all but the first
for
i
in
group
[
1
:]:
input_storage
[
i
]
.
storage
[
0
]
=
copy
.
copy
(
input_storage
[
i
]
.
storage
[
0
]
)
# Check if inputs are missing, or if inputs were set more than once, or
# if we tried to provide inputs that are supposed to be implicit.
if
not
self
.
trust_input
:
for
c
in
self
.
input_storage
:
if
c
.
required
and
not
c
.
provided
:
restore_defaults
()
# Check if inputs are missing, or if inputs were set more than once, or
# if we tried to provide inputs that are supposed to be implicit.
for
arg_container
in
input_storage
:
if
arg_container
.
required
and
not
arg_container
.
provided
:
self
.
_restore_defaults
()
raise
TypeError
(
f
"Missing required input: {getattr(self.inv_finder[
c], 'variable', self.inv_finder[c
])}"
f
"Missing required input: {getattr(self.inv_finder[
arg_container], 'variable', self.inv_finder[arg_container
])}"
)
if
c
.
provided
>
1
:
restore_defaults
()
if
arg_container
.
provided
>
1
:
self
.
_
restore_defaults
()
raise
TypeError
(
f
"Multiple values for input: {getattr(self.inv_finder[
c], 'variable', self.inv_finder[c
])}"
f
"Multiple values for input: {getattr(self.inv_finder[
arg_container], 'variable', self.inv_finder[arg_container
])}"
)
if
c
.
implicit
and
c
.
provided
>
0
:
restore_defaults
()
if
arg_container
.
implicit
and
arg_container
.
provided
>
0
:
self
.
_
restore_defaults
()
raise
TypeError
(
f
"Tried to provide value for implicit input: {getattr(self.inv_finder[
c], 'variable', self.inv_finder[c
])}"
f
"Tried to provide value for implicit input: {getattr(self.inv_finder[
arg_container], 'variable', self.inv_finder[arg_container
])}"
)
# Do the actual work
t0_fn
=
time
.
perf_counter
()
if
profile
:
t0_fn
=
time
.
perf_counter
()
try
:
outputs
=
(
self
.
vm
()
...
...
@@ -961,7 +957,7 @@ class Function:
else
self
.
vm
(
output_subset
=
output_subset
)
)
except
Exception
:
restore_defaults
()
self
.
_
restore_defaults
()
if
hasattr
(
self
.
vm
,
"position_of_error"
):
# this is a new vm-provided function or c linker
# they need this because the exception manipulation
...
...
@@ -979,26 +975,24 @@ class Function:
# old-style linkers raise their own exceptions
raise
dt_fn
=
time
.
perf_counter
()
-
t0_fn
self
.
maker
.
mode
.
fn_time
+=
dt_fn
if
profile
:
dt_fn
=
time
.
perf_counter
()
-
t0_fn
self
.
maker
.
mode
.
fn_time
+=
dt_fn
profile
.
vm_call_time
+=
dt_fn
# Retrieve the values that were computed
if
outputs
is
None
:
outputs
=
[
x
.
data
for
x
in
self
.
output_storage
]
assert
len
(
outputs
)
==
len
(
self
.
output_storage
)
# Remove internal references to required inputs.
# These cannot be re-used anyway.
for
c
in
self
.
input_storage
:
if
c
.
required
:
c
.
storage
[
0
]
=
None
for
arg_container
in
input_storage
:
if
arg_container
.
required
:
arg_container
.
storage
[
0
]
=
None
# if we are allowing garbage collection, remove the
# output reference from the internal storage cells
if
getattr
(
self
.
vm
,
"allow_gc"
,
False
):
assert
len
(
self
.
output_storage
)
==
len
(
self
.
maker
.
fgraph
.
outputs
)
for
o_container
,
o_variable
in
zip
(
self
.
output_storage
,
self
.
maker
.
fgraph
.
outputs
):
...
...
@@ -1007,12 +1001,10 @@ class Function:
# WARNING: This circumvents the 'readonly' attribute in x
o_container
.
storage
[
0
]
=
None
# TODO: Get rid of this and `expanded_inputs`, since all the VMs now
# perform the updates themselves
if
getattr
(
self
.
vm
,
"need_update_inputs"
,
True
):
# Update the inputs that have an update function
for
input
,
storage
in
reversed
(
list
(
zip
(
self
.
maker
.
expanded_inputs
,
self
.
input_storage
))
list
(
zip
(
self
.
maker
.
expanded_inputs
,
input_storage
))
):
if
input
.
update
is
not
None
:
storage
.
data
=
outputs
.
pop
()
...
...
@@ -1020,17 +1012,12 @@ class Function:
outputs
=
outputs
[:
self
.
n_returned_outputs
]
# Put default values back in the storage
restore_defaults
()
#
# NOTE: This logic needs to be replicated in
# scan.
# grep for 'PROFILE_CODE'
#
dt_call
=
time
.
perf_counter
()
-
t0
pytensor
.
compile
.
profiling
.
total_fct_exec_time
+=
dt_call
self
.
maker
.
mode
.
call_time
+=
dt_call
self
.
_restore_defaults
()
if
profile
:
dt_call
=
time
.
perf_counter
()
-
t0
pytensor
.
compile
.
profiling
.
total_fct_exec_time
+=
dt_call
self
.
maker
.
mode
.
call_time
+=
dt_call
profile
.
fct_callcount
+=
1
profile
.
fct_call_time
+=
dt_call
if
hasattr
(
self
.
vm
,
"update_profile"
):
...
...
@@ -1038,6 +1025,7 @@ class Function:
if
profile
.
ignore_first_call
:
profile
.
reset
()
profile
.
ignore_first_call
=
False
if
self
.
return_none
:
return
None
elif
self
.
unpack_single
and
len
(
outputs
)
==
1
and
output_subset
is
None
:
...
...
pytensor/gradient.py
浏览文件 @
82f6a14f
...
...
@@ -128,9 +128,6 @@ class DisconnectedType(Type):
" a symbolic placeholder."
)
def
may_share_memory
(
a
,
b
):
return
False
def
value_eq
(
a
,
b
,
force_same_dtype
=
True
):
raise
AssertionError
(
"If you're assigning to a DisconnectedType you're"
...
...
pytensor/graph/null_type.py
浏览文件 @
82f6a14f
...
...
@@ -26,9 +26,6 @@ class NullType(Type):
def
filter_variable
(
self
,
other
,
allow_convert
=
True
):
raise
ValueError
(
"No values may be assigned to a NullType"
)
def
may_share_memory
(
a
,
b
):
return
False
def
values_eq
(
self
,
a
,
b
,
force_same_dtype
=
True
):
raise
ValueError
(
"NullType has no values to compare"
)
...
...
pytensor/graph/type.py
浏览文件 @
82f6a14f
...
...
@@ -48,10 +48,7 @@ class Type(MetaObject, Generic[D]):
unique element (i.e. it uses `self.__eq__`).
"""
if
self
==
otype
:
return
True
return
False
return
self
==
otype
def
is_super
(
self
,
otype
:
"Type"
)
->
bool
|
None
:
"""Determine if `self` is a supertype of `otype`.
...
...
pytensor/scalar/basic.py
浏览文件 @
82f6a14f
...
...
@@ -303,13 +303,6 @@ class ScalarType(CType, HasDataType, HasShape):
dtype
=
self
.
dtype
return
type
(
self
)(
dtype
)
@staticmethod
def
may_share_memory
(
a
,
b
):
# This class represent basic c type, represented in python
# with numpy.scalar. They are read only. So from python, they
# can never share memory.
return
False
def
filter
(
self
,
data
,
strict
=
False
,
allow_downcast
=
None
):
py_type
=
self
.
dtype_specs
()[
0
]
if
strict
and
not
isinstance
(
data
,
py_type
):
...
...
pytensor/tensor/type_other.py
浏览文件 @
82f6a14f
...
...
@@ -126,12 +126,6 @@ class NoneTypeT(Generic):
else
:
raise
TypeError
(
"Expected None!"
)
@staticmethod
def
may_share_memory
(
a
,
b
):
# None never share memory between object, in the sense of DebugMode.
# Python None are singleton
return
False
none_type_t
=
NoneTypeT
()
...
...
tests/compile/function/test_types.py
浏览文件 @
82f6a14f
...
...
@@ -730,6 +730,8 @@ class TestFunction:
s1
=
shared
(
b
)
s2
=
shared
(
b
)
x1
=
vector
()
x2
=
vector
(
shape
=
(
3
,))
x3
=
vector
(
shape
=
(
1
,))
# Assert cases we should not check for aliased inputs
for
d
in
[
...
...
@@ -737,27 +739,29 @@ class TestFunction:
dict
(
outputs
=
[
s1
+
1
,
s2
+
3
]),
dict
(
outputs
=
[
s1
+
1
],
updates
=
[(
s2
,
s2
+
3
)]),
dict
(
inputs
=
[
x1
],
outputs
=
[
x1
+
1
],
updates
=
[(
s2
,
s2
+
3
)]),
dict
(
inputs
=
[
In
(
x1
,
mutable
=
True
)],
outputs
=
[
x1
+
1
],
updates
=
[(
s2
,
s2
+
3
)]
),
dict
(
inputs
=
[
In
(
x2
,
mutable
=
True
),
In
(
x3
,
mutable
=
True
)],
outputs
=
[
x2
+
2
,
x3
+
3
],
),
]:
if
"inputs"
not
in
d
:
d
[
"inputs"
]
=
[]
f
=
function
(
**
d
)
assert
not
f
.
_
check_for_aliased_input
s
,
d
assert
not
f
.
_
potential_aliased_input_group
s
,
d
# Assert cases we should check for aliased inputs
for
d
in
[
dict
(
inputs
=
[
In
(
x1
,
borrow
=
True
)],
outputs
=
[
x1
+
1
],
updates
=
[(
s2
,
s2
+
3
)],
),
dict
(
inputs
=
[
In
(
x1
,
borrow
=
True
,
mutable
=
True
)],
outputs
=
[
x1
+
1
],
inputs
=
[
In
(
x1
,
mutable
=
True
),
In
(
x2
,
mutable
=
True
)],
outputs
=
[
x1
+
1
,
x2
+
2
],
updates
=
[(
s2
,
s2
+
3
)],
),
dict
(
inputs
=
[
In
(
x1
,
mutable
=
True
)],
outputs
=
[
x1
+
1
],
inputs
=
[
In
(
x1
,
mutable
=
True
)
,
In
(
x3
,
mutable
=
True
)
],
outputs
=
[
x1
+
1
,
x3
+
3
],
updates
=
[(
s2
,
s2
+
3
)],
),
]:
...
...
@@ -765,7 +769,7 @@ class TestFunction:
d
[
"inputs"
]
=
[]
f
=
function
(
**
d
)
assert
f
.
_
check_for_aliased_input
s
,
d
assert
f
.
_
potential_aliased_input_group
s
,
d
def
test_output_dictionary
(
self
):
# Tests that function works when outputs is a dictionary
...
...
@@ -879,7 +883,7 @@ class TestPicklefunction:
f
=
function
(
[
x
,
In
(
a
,
value
=
1.0
,
name
=
"a"
),
In
(
a
,
value
=
1.0
,
name
=
"a"
,
mutable
=
True
),
In
(
s
,
value
=
0.0
,
update
=
s
+
a
*
x
,
mutable
=
True
),
],
s
+
a
*
x
,
...
...
@@ -901,7 +905,12 @@ class TestPicklefunction:
assert
x
not
in
g
.
container
assert
x
not
in
g
.
value
assert
len
(
f
.
defaults
)
==
len
(
g
.
defaults
)
assert
f
.
_check_for_aliased_inputs
is
g
.
_check_for_aliased_inputs
# Shared variable is the first input
assert
(
f
.
_potential_aliased_input_groups
==
g
.
_potential_aliased_input_groups
==
((
1
,
2
),)
)
assert
f
.
name
==
g
.
name
assert
f
.
maker
.
fgraph
.
name
==
g
.
maker
.
fgraph
.
name
# print(f"{f.defaults = }")
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论