Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
82f6a14f
提交
82f6a14f
authored
10月 09, 2024
作者:
ricardoV94
提交者:
Ricardo Vieira
10月 29, 2024
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Cleanup Function.__call__
上级
f0a9ec25
显示空白字符变更
内嵌
并排
正在显示
7 个修改的文件
包含
122 行增加
和
147 行删除
+122
-147
types.py
pytensor/compile/function/types.py
+99
-111
gradient.py
pytensor/gradient.py
+0
-3
null_type.py
pytensor/graph/null_type.py
+0
-3
type.py
pytensor/graph/type.py
+1
-4
basic.py
pytensor/scalar/basic.py
+0
-7
type_other.py
pytensor/tensor/type_other.py
+0
-6
test_types.py
tests/compile/function/test_types.py
+22
-13
没有找到文件。
pytensor/compile/function/types.py
浏览文件 @
82f6a14f
...
@@ -326,8 +326,8 @@ class Function:
...
@@ -326,8 +326,8 @@ class Function:
def
__init__
(
def
__init__
(
self
,
self
,
vm
:
"VM"
,
vm
:
"VM"
,
input_storage
,
input_storage
:
list
[
Container
]
,
output_storage
,
output_storage
:
list
[
Container
]
,
indices
,
indices
,
outputs
,
outputs
,
defaults
,
defaults
,
...
@@ -372,7 +372,6 @@ class Function:
...
@@ -372,7 +372,6 @@ class Function:
name
name
A string name.
A string name.
"""
"""
# TODO: Rename to `vm`
self
.
vm
=
vm
self
.
vm
=
vm
self
.
input_storage
=
input_storage
self
.
input_storage
=
input_storage
self
.
output_storage
=
output_storage
self
.
output_storage
=
output_storage
...
@@ -388,31 +387,49 @@ class Function:
...
@@ -388,31 +387,49 @@ class Function:
self
.
nodes_with_inner_function
=
[]
self
.
nodes_with_inner_function
=
[]
self
.
output_keys
=
output_keys
self
.
output_keys
=
output_keys
# See if we have any mutable / borrow inputs
assert
len
(
self
.
input_storage
)
==
len
(
self
.
maker
.
fgraph
.
inputs
)
# TODO: this only need to be set if there is more than one input
assert
len
(
self
.
output_storage
)
==
len
(
self
.
maker
.
fgraph
.
outputs
)
self
.
_check_for_aliased_inputs
=
False
for
i
in
maker
.
inputs
:
# Group indexes of inputs that are potentially aliased to each other
# If the input is a shared variable, the memory region is
# Note: Historically, we only worried about aliasing inputs if they belonged to the same type,
# under PyTensor control and so we don't need to check if it
# even though there could be two distinct types that use the same kinds of underlying objects.
# is aliased as we never do that.
potential_aliased_input_groups
=
[]
if
(
for
inp
in
maker
.
inputs
:
isinstance
(
i
,
In
)
# If the input is a shared variable, the memory region is under PyTensor control
and
not
i
.
shared
# and can't be aliased.
and
(
getattr
(
i
,
"borrow"
,
False
)
or
getattr
(
i
,
"mutable"
,
False
))
if
not
(
isinstance
(
inp
,
In
)
and
inp
.
borrow
and
not
inp
.
shared
and
hasattr
(
inp
.
variable
.
type
,
"may_share_memory"
)
):
):
self
.
_check_for_aliased_inputs
=
True
continue
for
group
in
potential_aliased_input_groups
:
# If one is super of the other, that means one could be replaced by the other
if
any
(
inp
.
variable
.
type
.
is_super
(
other_inp
.
variable
.
type
)
or
other_inp
.
variable
.
type
.
is_super
(
inp
.
variable
.
type
)
for
other_inp
in
group
):
group
.
append
(
inp
)
break
break
else
:
# no break
# Input makes a new group
potential_aliased_input_groups
.
append
([
inp
])
# Potential aliased inputs are those that belong to the same group
self
.
_potential_aliased_input_groups
:
tuple
[
tuple
[
int
,
...
],
...
]
=
tuple
(
tuple
(
maker
.
inputs
.
index
(
inp
)
for
inp
in
group
)
for
group
in
potential_aliased_input_groups
if
len
(
group
)
>
1
)
# We will be popping stuff off this `containers` object. It is a copy.
# We will be popping stuff off this `containers` object. It is a copy.
containers
=
list
(
self
.
input_storage
)
containers
=
list
(
self
.
input_storage
)
finder
=
{}
finder
=
{}
inv_finder
=
{}
inv_finder
=
{}
def
distribute
(
indices
,
cs
,
value
):
input
.
distribute
(
value
,
indices
,
cs
)
for
c
in
cs
:
c
.
provided
+=
1
# Store the list of names of named inputs.
# Store the list of names of named inputs.
named_inputs
=
[]
named_inputs
=
[]
# Count the number of un-named inputs.
# Count the number of un-named inputs.
...
@@ -777,6 +794,13 @@ class Function:
...
@@ -777,6 +794,13 @@ class Function:
f_cpy
.
maker
.
fgraph
.
name
=
name
f_cpy
.
maker
.
fgraph
.
name
=
name
return
f_cpy
return
f_cpy
def
_restore_defaults
(
self
):
for
i
,
(
required
,
refeed
,
value
)
in
enumerate
(
self
.
defaults
):
if
refeed
:
if
isinstance
(
value
,
Container
):
value
=
value
.
storage
[
0
]
self
[
i
]
=
value
def
__call__
(
self
,
*
args
,
**
kwargs
):
def
__call__
(
self
,
*
args
,
**
kwargs
):
"""
"""
Evaluates value of a function on given arguments.
Evaluates value of a function on given arguments.
...
@@ -805,15 +829,10 @@ class Function:
...
@@ -805,15 +829,10 @@ class Function:
List of outputs on indices/keys from ``output_subset`` or all of them,
List of outputs on indices/keys from ``output_subset`` or all of them,
if ``output_subset`` is not passed.
if ``output_subset`` is not passed.
"""
"""
input_storage
=
self
.
input_storage
def
restore_defaults
():
for
i
,
(
required
,
refeed
,
value
)
in
enumerate
(
self
.
defaults
):
if
refeed
:
if
isinstance
(
value
,
Container
):
value
=
value
.
storage
[
0
]
self
[
i
]
=
value
profile
=
self
.
profile
profile
=
self
.
profile
if
profile
:
t0
=
time
.
perf_counter
()
t0
=
time
.
perf_counter
()
output_subset
=
kwargs
.
pop
(
"output_subset"
,
None
)
output_subset
=
kwargs
.
pop
(
"output_subset"
,
None
)
...
@@ -822,35 +841,31 @@ class Function:
...
@@ -822,35 +841,31 @@ class Function:
# Reinitialize each container's 'provided' counter
# Reinitialize each container's 'provided' counter
if
self
.
trust_input
:
if
self
.
trust_input
:
i
=
0
for
arg_container
,
arg
in
zip
(
input_storage
,
args
,
strict
=
False
):
for
arg
in
args
:
arg_container
.
storage
[
0
]
=
arg
s
=
self
.
input_storage
[
i
]
s
.
storage
[
0
]
=
arg
i
+=
1
else
:
else
:
for
c
in
self
.
input_storage
:
for
arg_container
in
input_storage
:
c
.
provided
=
0
arg_container
.
provided
=
0
if
len
(
args
)
+
len
(
kwargs
)
>
len
(
self
.
input_storage
):
if
len
(
args
)
+
len
(
kwargs
)
>
len
(
input_storage
):
raise
TypeError
(
"Too many parameter passed to pytensor function"
)
raise
TypeError
(
"Too many parameter passed to pytensor function"
)
# Set positional arguments
# Set positional arguments
i
=
0
for
arg_container
,
arg
in
zip
(
input_storage
,
args
,
strict
=
False
):
for
arg
in
args
:
# See discussion about None as input
# TODO: provide a option for skipping the filter if we really
# want speed.
s
=
self
.
input_storage
[
i
]
# see this emails for a discuation about None as input
# https://groups.google.com/group/theano-dev/browse_thread/thread/920a5e904e8a8525/4f1b311a28fc27e5
# https://groups.google.com/group/theano-dev/browse_thread/thread/920a5e904e8a8525/4f1b311a28fc27e5
if
arg
is
None
:
if
arg
is
None
:
s
.
storage
[
0
]
=
arg
arg_container
.
storage
[
0
]
=
arg
else
:
else
:
try
:
try
:
s
.
storage
[
0
]
=
s
.
type
.
filter
(
arg_container
.
storage
[
0
]
=
arg_container
.
type
.
filter
(
arg
,
strict
=
s
.
strict
,
allow_downcast
=
s
.
allow_downcast
arg
,
strict
=
arg_container
.
strict
,
allow_downcast
=
arg_container
.
allow_downcast
,
)
)
except
Exception
as
e
:
except
Exception
as
e
:
i
=
input_storage
.
index
(
arg_container
)
function_name
=
"pytensor function"
function_name
=
"pytensor function"
argument_name
=
"argument"
argument_name
=
"argument"
if
self
.
name
:
if
self
.
name
:
...
@@ -875,84 +890,65 @@ class Function:
...
@@ -875,84 +890,65 @@ class Function:
+
function_name
+
function_name
+
f
" at index {int(i)} (0-based). {where}"
+
f
" at index {int(i)} (0-based). {where}"
)
+
e
.
args
)
+
e
.
args
restore_defaults
()
self
.
_
restore_defaults
()
raise
raise
s
.
provided
+=
1
arg_container
.
provided
+=
1
i
+=
1
# Set keyword arguments
# Set keyword arguments
if
kwargs
:
# for speed, skip the items for empty kwargs
if
kwargs
:
# for speed, skip the items for empty kwargs
for
k
,
arg
in
kwargs
.
items
():
for
k
,
arg
in
kwargs
.
items
():
self
[
k
]
=
arg
self
[
k
]
=
arg
if
(
if
not
self
.
trust_input
:
not
self
.
trust_input
and
# The getattr is only needed for old pickle
getattr
(
self
,
"_check_for_aliased_inputs"
,
True
)
):
# Collect aliased inputs among the storage space
# Collect aliased inputs among the storage space
args_share_memory
=
[]
for
potential_group
in
self
.
_potential_aliased_input_groups
:
for
i
in
range
(
len
(
self
.
input_storage
)):
args_share_memory
:
list
[
list
[
int
]]
=
[]
i_var
=
self
.
maker
.
inputs
[
i
]
.
variable
for
i
in
potential_group
:
i_val
=
self
.
input_storage
[
i
]
.
storage
[
0
]
i_type
=
self
.
maker
.
inputs
[
i
]
.
variable
.
type
if
hasattr
(
i_var
.
type
,
"may_share_memory"
):
i_val
=
input_storage
[
i
]
.
storage
[
0
]
is_aliased
=
False
for
j
in
range
(
len
(
args_share_memory
)):
# Check if value is aliased with any of the values in one of the groups
group_j
=
zip
(
for
j_group
in
args_share_memory
:
[
self
.
maker
.
inputs
[
k
]
.
variable
for
k
in
args_share_memory
[
j
]
],
[
self
.
input_storage
[
k
]
.
storage
[
0
]
for
k
in
args_share_memory
[
j
]
],
)
if
any
(
if
any
(
(
i_type
.
may_share_memory
(
input_storage
[
j
]
.
storage
[
0
],
i_val
)
var
.
type
is
i_var
.
type
for
j
in
j_group
and
var
.
type
.
may_share_memory
(
val
,
i_val
)
)
for
(
var
,
val
)
in
group_j
):
):
is_aliased
=
True
j_group
.
append
(
i
)
args_share_memory
[
j
]
.
append
(
i
)
break
break
else
:
# no break
if
not
is_aliased
:
# Create a new group
args_share_memory
.
append
([
i
])
args_share_memory
.
append
([
i
])
# Check for groups of more than one argument that share memory
# Check for groups of more than one argument that share memory
for
group
in
args_share_memory
:
for
group
in
args_share_memory
:
if
len
(
group
)
>
1
:
if
len
(
group
)
>
1
:
# copy all but the first
# copy all but the first
for
j
in
group
[
1
:]:
for
i
in
group
[
1
:]:
self
.
input_storage
[
j
]
.
storage
[
0
]
=
copy
.
copy
(
input_storage
[
i
]
.
storage
[
0
]
=
copy
.
copy
(
self
.
input_storage
[
j
]
.
storage
[
0
]
input_storage
[
i
]
.
storage
[
0
]
)
)
# Check if inputs are missing, or if inputs were set more than once, or
# Check if inputs are missing, or if inputs were set more than once, or
# if we tried to provide inputs that are supposed to be implicit.
# if we tried to provide inputs that are supposed to be implicit.
if
not
self
.
trust_input
:
for
arg_container
in
input_storage
:
for
c
in
self
.
input_storage
:
if
arg_container
.
required
and
not
arg_container
.
provided
:
if
c
.
required
and
not
c
.
provided
:
self
.
_restore_defaults
()
restore_defaults
()
raise
TypeError
(
raise
TypeError
(
f
"Missing required input: {getattr(self.inv_finder[
c], 'variable', self.inv_finder[c
])}"
f
"Missing required input: {getattr(self.inv_finder[
arg_container], 'variable', self.inv_finder[arg_container
])}"
)
)
if
c
.
provided
>
1
:
if
arg_container
.
provided
>
1
:
restore_defaults
()
self
.
_
restore_defaults
()
raise
TypeError
(
raise
TypeError
(
f
"Multiple values for input: {getattr(self.inv_finder[
c], 'variable', self.inv_finder[c
])}"
f
"Multiple values for input: {getattr(self.inv_finder[
arg_container], 'variable', self.inv_finder[arg_container
])}"
)
)
if
c
.
implicit
and
c
.
provided
>
0
:
if
arg_container
.
implicit
and
arg_container
.
provided
>
0
:
restore_defaults
()
self
.
_
restore_defaults
()
raise
TypeError
(
raise
TypeError
(
f
"Tried to provide value for implicit input: {getattr(self.inv_finder[
c], 'variable', self.inv_finder[c
])}"
f
"Tried to provide value for implicit input: {getattr(self.inv_finder[
arg_container], 'variable', self.inv_finder[arg_container
])}"
)
)
# Do the actual work
# Do the actual work
if
profile
:
t0_fn
=
time
.
perf_counter
()
t0_fn
=
time
.
perf_counter
()
try
:
try
:
outputs
=
(
outputs
=
(
...
@@ -961,7 +957,7 @@ class Function:
...
@@ -961,7 +957,7 @@ class Function:
else
self
.
vm
(
output_subset
=
output_subset
)
else
self
.
vm
(
output_subset
=
output_subset
)
)
)
except
Exception
:
except
Exception
:
restore_defaults
()
self
.
_
restore_defaults
()
if
hasattr
(
self
.
vm
,
"position_of_error"
):
if
hasattr
(
self
.
vm
,
"position_of_error"
):
# this is a new vm-provided function or c linker
# this is a new vm-provided function or c linker
# they need this because the exception manipulation
# they need this because the exception manipulation
...
@@ -979,26 +975,24 @@ class Function:
...
@@ -979,26 +975,24 @@ class Function:
# old-style linkers raise their own exceptions
# old-style linkers raise their own exceptions
raise
raise
if
profile
:
dt_fn
=
time
.
perf_counter
()
-
t0_fn
dt_fn
=
time
.
perf_counter
()
-
t0_fn
self
.
maker
.
mode
.
fn_time
+=
dt_fn
self
.
maker
.
mode
.
fn_time
+=
dt_fn
if
profile
:
profile
.
vm_call_time
+=
dt_fn
profile
.
vm_call_time
+=
dt_fn
# Retrieve the values that were computed
# Retrieve the values that were computed
if
outputs
is
None
:
if
outputs
is
None
:
outputs
=
[
x
.
data
for
x
in
self
.
output_storage
]
outputs
=
[
x
.
data
for
x
in
self
.
output_storage
]
assert
len
(
outputs
)
==
len
(
self
.
output_storage
)
# Remove internal references to required inputs.
# Remove internal references to required inputs.
# These cannot be re-used anyway.
# These cannot be re-used anyway.
for
c
in
self
.
input_storage
:
for
arg_container
in
input_storage
:
if
c
.
required
:
if
arg_container
.
required
:
c
.
storage
[
0
]
=
None
arg_container
.
storage
[
0
]
=
None
# if we are allowing garbage collection, remove the
# if we are allowing garbage collection, remove the
# output reference from the internal storage cells
# output reference from the internal storage cells
if
getattr
(
self
.
vm
,
"allow_gc"
,
False
):
if
getattr
(
self
.
vm
,
"allow_gc"
,
False
):
assert
len
(
self
.
output_storage
)
==
len
(
self
.
maker
.
fgraph
.
outputs
)
for
o_container
,
o_variable
in
zip
(
for
o_container
,
o_variable
in
zip
(
self
.
output_storage
,
self
.
maker
.
fgraph
.
outputs
self
.
output_storage
,
self
.
maker
.
fgraph
.
outputs
):
):
...
@@ -1007,12 +1001,10 @@ class Function:
...
@@ -1007,12 +1001,10 @@ class Function:
# WARNING: This circumvents the 'readonly' attribute in x
# WARNING: This circumvents the 'readonly' attribute in x
o_container
.
storage
[
0
]
=
None
o_container
.
storage
[
0
]
=
None
# TODO: Get rid of this and `expanded_inputs`, since all the VMs now
# perform the updates themselves
if
getattr
(
self
.
vm
,
"need_update_inputs"
,
True
):
if
getattr
(
self
.
vm
,
"need_update_inputs"
,
True
):
# Update the inputs that have an update function
# Update the inputs that have an update function
for
input
,
storage
in
reversed
(
for
input
,
storage
in
reversed
(
list
(
zip
(
self
.
maker
.
expanded_inputs
,
self
.
input_storage
))
list
(
zip
(
self
.
maker
.
expanded_inputs
,
input_storage
))
):
):
if
input
.
update
is
not
None
:
if
input
.
update
is
not
None
:
storage
.
data
=
outputs
.
pop
()
storage
.
data
=
outputs
.
pop
()
...
@@ -1020,17 +1012,12 @@ class Function:
...
@@ -1020,17 +1012,12 @@ class Function:
outputs
=
outputs
[:
self
.
n_returned_outputs
]
outputs
=
outputs
[:
self
.
n_returned_outputs
]
# Put default values back in the storage
# Put default values back in the storage
restore_defaults
()
self
.
_restore_defaults
()
#
# NOTE: This logic needs to be replicated in
# scan.
# grep for 'PROFILE_CODE'
#
if
profile
:
dt_call
=
time
.
perf_counter
()
-
t0
dt_call
=
time
.
perf_counter
()
-
t0
pytensor
.
compile
.
profiling
.
total_fct_exec_time
+=
dt_call
pytensor
.
compile
.
profiling
.
total_fct_exec_time
+=
dt_call
self
.
maker
.
mode
.
call_time
+=
dt_call
self
.
maker
.
mode
.
call_time
+=
dt_call
if
profile
:
profile
.
fct_callcount
+=
1
profile
.
fct_callcount
+=
1
profile
.
fct_call_time
+=
dt_call
profile
.
fct_call_time
+=
dt_call
if
hasattr
(
self
.
vm
,
"update_profile"
):
if
hasattr
(
self
.
vm
,
"update_profile"
):
...
@@ -1038,6 +1025,7 @@ class Function:
...
@@ -1038,6 +1025,7 @@ class Function:
if
profile
.
ignore_first_call
:
if
profile
.
ignore_first_call
:
profile
.
reset
()
profile
.
reset
()
profile
.
ignore_first_call
=
False
profile
.
ignore_first_call
=
False
if
self
.
return_none
:
if
self
.
return_none
:
return
None
return
None
elif
self
.
unpack_single
and
len
(
outputs
)
==
1
and
output_subset
is
None
:
elif
self
.
unpack_single
and
len
(
outputs
)
==
1
and
output_subset
is
None
:
...
...
pytensor/gradient.py
浏览文件 @
82f6a14f
...
@@ -128,9 +128,6 @@ class DisconnectedType(Type):
...
@@ -128,9 +128,6 @@ class DisconnectedType(Type):
" a symbolic placeholder."
" a symbolic placeholder."
)
)
def
may_share_memory
(
a
,
b
):
return
False
def
value_eq
(
a
,
b
,
force_same_dtype
=
True
):
def
value_eq
(
a
,
b
,
force_same_dtype
=
True
):
raise
AssertionError
(
raise
AssertionError
(
"If you're assigning to a DisconnectedType you're"
"If you're assigning to a DisconnectedType you're"
...
...
pytensor/graph/null_type.py
浏览文件 @
82f6a14f
...
@@ -26,9 +26,6 @@ class NullType(Type):
...
@@ -26,9 +26,6 @@ class NullType(Type):
def
filter_variable
(
self
,
other
,
allow_convert
=
True
):
def
filter_variable
(
self
,
other
,
allow_convert
=
True
):
raise
ValueError
(
"No values may be assigned to a NullType"
)
raise
ValueError
(
"No values may be assigned to a NullType"
)
def
may_share_memory
(
a
,
b
):
return
False
def
values_eq
(
self
,
a
,
b
,
force_same_dtype
=
True
):
def
values_eq
(
self
,
a
,
b
,
force_same_dtype
=
True
):
raise
ValueError
(
"NullType has no values to compare"
)
raise
ValueError
(
"NullType has no values to compare"
)
...
...
pytensor/graph/type.py
浏览文件 @
82f6a14f
...
@@ -48,10 +48,7 @@ class Type(MetaObject, Generic[D]):
...
@@ -48,10 +48,7 @@ class Type(MetaObject, Generic[D]):
unique element (i.e. it uses `self.__eq__`).
unique element (i.e. it uses `self.__eq__`).
"""
"""
if
self
==
otype
:
return
self
==
otype
return
True
return
False
def
is_super
(
self
,
otype
:
"Type"
)
->
bool
|
None
:
def
is_super
(
self
,
otype
:
"Type"
)
->
bool
|
None
:
"""Determine if `self` is a supertype of `otype`.
"""Determine if `self` is a supertype of `otype`.
...
...
pytensor/scalar/basic.py
浏览文件 @
82f6a14f
...
@@ -303,13 +303,6 @@ class ScalarType(CType, HasDataType, HasShape):
...
@@ -303,13 +303,6 @@ class ScalarType(CType, HasDataType, HasShape):
dtype
=
self
.
dtype
dtype
=
self
.
dtype
return
type
(
self
)(
dtype
)
return
type
(
self
)(
dtype
)
@staticmethod
def
may_share_memory
(
a
,
b
):
# This class represent basic c type, represented in python
# with numpy.scalar. They are read only. So from python, they
# can never share memory.
return
False
def
filter
(
self
,
data
,
strict
=
False
,
allow_downcast
=
None
):
def
filter
(
self
,
data
,
strict
=
False
,
allow_downcast
=
None
):
py_type
=
self
.
dtype_specs
()[
0
]
py_type
=
self
.
dtype_specs
()[
0
]
if
strict
and
not
isinstance
(
data
,
py_type
):
if
strict
and
not
isinstance
(
data
,
py_type
):
...
...
pytensor/tensor/type_other.py
浏览文件 @
82f6a14f
...
@@ -126,12 +126,6 @@ class NoneTypeT(Generic):
...
@@ -126,12 +126,6 @@ class NoneTypeT(Generic):
else
:
else
:
raise
TypeError
(
"Expected None!"
)
raise
TypeError
(
"Expected None!"
)
@staticmethod
def
may_share_memory
(
a
,
b
):
# None never share memory between object, in the sense of DebugMode.
# Python None are singleton
return
False
none_type_t
=
NoneTypeT
()
none_type_t
=
NoneTypeT
()
...
...
tests/compile/function/test_types.py
浏览文件 @
82f6a14f
...
@@ -730,6 +730,8 @@ class TestFunction:
...
@@ -730,6 +730,8 @@ class TestFunction:
s1
=
shared
(
b
)
s1
=
shared
(
b
)
s2
=
shared
(
b
)
s2
=
shared
(
b
)
x1
=
vector
()
x1
=
vector
()
x2
=
vector
(
shape
=
(
3
,))
x3
=
vector
(
shape
=
(
1
,))
# Assert cases we should not check for aliased inputs
# Assert cases we should not check for aliased inputs
for
d
in
[
for
d
in
[
...
@@ -737,27 +739,29 @@ class TestFunction:
...
@@ -737,27 +739,29 @@ class TestFunction:
dict
(
outputs
=
[
s1
+
1
,
s2
+
3
]),
dict
(
outputs
=
[
s1
+
1
,
s2
+
3
]),
dict
(
outputs
=
[
s1
+
1
],
updates
=
[(
s2
,
s2
+
3
)]),
dict
(
outputs
=
[
s1
+
1
],
updates
=
[(
s2
,
s2
+
3
)]),
dict
(
inputs
=
[
x1
],
outputs
=
[
x1
+
1
],
updates
=
[(
s2
,
s2
+
3
)]),
dict
(
inputs
=
[
x1
],
outputs
=
[
x1
+
1
],
updates
=
[(
s2
,
s2
+
3
)]),
dict
(
inputs
=
[
In
(
x1
,
mutable
=
True
)],
outputs
=
[
x1
+
1
],
updates
=
[(
s2
,
s2
+
3
)]
),
dict
(
inputs
=
[
In
(
x2
,
mutable
=
True
),
In
(
x3
,
mutable
=
True
)],
outputs
=
[
x2
+
2
,
x3
+
3
],
),
]:
]:
if
"inputs"
not
in
d
:
if
"inputs"
not
in
d
:
d
[
"inputs"
]
=
[]
d
[
"inputs"
]
=
[]
f
=
function
(
**
d
)
f
=
function
(
**
d
)
assert
not
f
.
_
check_for_aliased_input
s
,
d
assert
not
f
.
_
potential_aliased_input_group
s
,
d
# Assert cases we should check for aliased inputs
# Assert cases we should check for aliased inputs
for
d
in
[
for
d
in
[
dict
(
dict
(
inputs
=
[
In
(
x1
,
borrow
=
True
)],
inputs
=
[
In
(
x1
,
mutable
=
True
),
In
(
x2
,
mutable
=
True
)],
outputs
=
[
x1
+
1
],
outputs
=
[
x1
+
1
,
x2
+
2
],
updates
=
[(
s2
,
s2
+
3
)],
),
dict
(
inputs
=
[
In
(
x1
,
borrow
=
True
,
mutable
=
True
)],
outputs
=
[
x1
+
1
],
updates
=
[(
s2
,
s2
+
3
)],
updates
=
[(
s2
,
s2
+
3
)],
),
),
dict
(
dict
(
inputs
=
[
In
(
x1
,
mutable
=
True
)],
inputs
=
[
In
(
x1
,
mutable
=
True
)
,
In
(
x3
,
mutable
=
True
)
],
outputs
=
[
x1
+
1
],
outputs
=
[
x1
+
1
,
x3
+
3
],
updates
=
[(
s2
,
s2
+
3
)],
updates
=
[(
s2
,
s2
+
3
)],
),
),
]:
]:
...
@@ -765,7 +769,7 @@ class TestFunction:
...
@@ -765,7 +769,7 @@ class TestFunction:
d
[
"inputs"
]
=
[]
d
[
"inputs"
]
=
[]
f
=
function
(
**
d
)
f
=
function
(
**
d
)
assert
f
.
_
check_for_aliased_input
s
,
d
assert
f
.
_
potential_aliased_input_group
s
,
d
def
test_output_dictionary
(
self
):
def
test_output_dictionary
(
self
):
# Tests that function works when outputs is a dictionary
# Tests that function works when outputs is a dictionary
...
@@ -879,7 +883,7 @@ class TestPicklefunction:
...
@@ -879,7 +883,7 @@ class TestPicklefunction:
f
=
function
(
f
=
function
(
[
[
x
,
x
,
In
(
a
,
value
=
1.0
,
name
=
"a"
),
In
(
a
,
value
=
1.0
,
name
=
"a"
,
mutable
=
True
),
In
(
s
,
value
=
0.0
,
update
=
s
+
a
*
x
,
mutable
=
True
),
In
(
s
,
value
=
0.0
,
update
=
s
+
a
*
x
,
mutable
=
True
),
],
],
s
+
a
*
x
,
s
+
a
*
x
,
...
@@ -901,7 +905,12 @@ class TestPicklefunction:
...
@@ -901,7 +905,12 @@ class TestPicklefunction:
assert
x
not
in
g
.
container
assert
x
not
in
g
.
container
assert
x
not
in
g
.
value
assert
x
not
in
g
.
value
assert
len
(
f
.
defaults
)
==
len
(
g
.
defaults
)
assert
len
(
f
.
defaults
)
==
len
(
g
.
defaults
)
assert
f
.
_check_for_aliased_inputs
is
g
.
_check_for_aliased_inputs
# Shared variable is the first input
assert
(
f
.
_potential_aliased_input_groups
==
g
.
_potential_aliased_input_groups
==
((
1
,
2
),)
)
assert
f
.
name
==
g
.
name
assert
f
.
name
==
g
.
name
assert
f
.
maker
.
fgraph
.
name
==
g
.
maker
.
fgraph
.
name
assert
f
.
maker
.
fgraph
.
name
==
g
.
maker
.
fgraph
.
name
# print(f"{f.defaults = }")
# print(f"{f.defaults = }")
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论