Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
8ae2a195
提交
8ae2a195
authored
7月 07, 2024
作者:
Virgile Andreani
提交者:
Virgile Andreani
7月 09, 2024
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Remove dict.keys() when unnecessary
上级
63da6d16
隐藏空白字符变更
内嵌
并排
正在显示
28 个修改的文件
包含
53 行增加
和
64 行删除
+53
-64
types.py
pytensor/compile/function/types.py
+4
-4
profiling.py
pytensor/compile/profiling.py
+1
-1
configdefaults.py
pytensor/configdefaults.py
+1
-1
configparser.py
pytensor/configparser.py
+1
-1
gradient.py
pytensor/gradient.py
+2
-2
basic.py
pytensor/graph/basic.py
+1
-1
destroyhandler.py
pytensor/graph/destroyhandler.py
+1
-1
basic.py
pytensor/graph/rewriting/basic.py
+4
-4
utils.py
pytensor/graph/utils.py
+1
-1
params_type.py
pytensor/link/c/params_type.py
+2
-5
type.py
pytensor/link/c/type.py
+4
-6
scalar.py
pytensor/link/numba/dispatch/scalar.py
+1
-3
scan.py
pytensor/link/numba/dispatch/scan.py
+1
-1
vm.py
pytensor/link/vm.py
+1
-1
printing.py
pytensor/printing.py
+2
-6
basic.py
pytensor/scalar/basic.py
+1
-1
op.py
pytensor/scan/op.py
+3
-3
utils.py
pytensor/scan/utils.py
+1
-1
basic.py
pytensor/sparse/basic.py
+1
-1
elemwise.py
pytensor/tensor/rewriting/elemwise.py
+1
-1
utils.py
pytensor/tensor/utils.py
+1
-1
test_types.py
tests/compile/function/test_types.py
+6
-6
test_basic.py
tests/link/numba/test_basic.py
+1
-1
test_scan.py
tests/link/numba/test_scan.py
+2
-2
test_utils.py
tests/link/test_utils.py
+1
-1
test_vm.py
tests/link/test_vm.py
+1
-1
test_rewriting.py
tests/scan/test_rewriting.py
+6
-6
test_config.py
tests/test_config.py
+1
-1
没有找到文件。
pytensor/compile/function/types.py
浏览文件 @
8ae2a195
...
@@ -659,7 +659,7 @@ class Function:
...
@@ -659,7 +659,7 @@ class Function:
exist_svs
=
[
i
.
variable
for
i
in
maker
.
inputs
]
exist_svs
=
[
i
.
variable
for
i
in
maker
.
inputs
]
# Check if given ShareVariables exist
# Check if given ShareVariables exist
for
sv
in
swap
.
keys
()
:
for
sv
in
swap
:
if
sv
not
in
exist_svs
:
if
sv
not
in
exist_svs
:
raise
ValueError
(
f
"SharedVariable: {sv.name} not found"
)
raise
ValueError
(
f
"SharedVariable: {sv.name} not found"
)
...
@@ -711,9 +711,9 @@ class Function:
...
@@ -711,9 +711,9 @@ class Function:
# it is well tested, we don't share the part of the storage_map.
# it is well tested, we don't share the part of the storage_map.
if
share_memory
:
if
share_memory
:
i_o_vars
=
maker
.
fgraph
.
inputs
+
maker
.
fgraph
.
outputs
i_o_vars
=
maker
.
fgraph
.
inputs
+
maker
.
fgraph
.
outputs
for
key
in
storage_map
.
key
s
():
for
key
,
val
in
storage_map
.
item
s
():
if
key
not
in
i_o_vars
:
if
key
not
in
i_o_vars
:
new_storage_map
[
memo
[
key
]]
=
storage_map
[
key
]
new_storage_map
[
memo
[
key
]]
=
val
if
not
name
and
self
.
name
:
if
not
name
and
self
.
name
:
name
=
self
.
name
+
" copy"
name
=
self
.
name
+
" copy"
...
@@ -1446,7 +1446,7 @@ class FunctionMaker:
...
@@ -1446,7 +1446,7 @@ class FunctionMaker:
if
not
hasattr
(
mode
.
linker
,
"accept"
):
if
not
hasattr
(
mode
.
linker
,
"accept"
):
raise
ValueError
(
raise
ValueError
(
"'linker' parameter of FunctionMaker should be "
"'linker' parameter of FunctionMaker should be "
f
"a Linker with an accept method or one of {list(pytensor.compile.mode.predefined_linkers
.keys()
)}"
f
"a Linker with an accept method or one of {list(pytensor.compile.mode.predefined_linkers)}"
)
)
def
__init__
(
def
__init__
(
...
...
pytensor/compile/profiling.py
浏览文件 @
8ae2a195
...
@@ -1446,7 +1446,7 @@ class ProfileStats:
...
@@ -1446,7 +1446,7 @@ class ProfileStats:
file
=
file
,
file
=
file
,
)
)
if
config
.
profiling__debugprint
:
if
config
.
profiling__debugprint
:
fcts
=
{
fgraph
for
(
fgraph
,
n
)
in
self
.
apply_time
.
keys
()
}
fcts
=
{
fgraph
for
(
fgraph
,
n
)
in
self
.
apply_time
}
pytensor
.
printing
.
debugprint
(
fcts
,
print_type
=
True
)
pytensor
.
printing
.
debugprint
(
fcts
,
print_type
=
True
)
if
self
.
variable_shape
or
self
.
variable_strides
:
if
self
.
variable_shape
or
self
.
variable_strides
:
self
.
summary_memory
(
file
,
n_apply_to_print
)
self
.
summary_memory
(
file
,
n_apply_to_print
)
...
...
pytensor/configdefaults.py
浏览文件 @
8ae2a195
...
@@ -1318,7 +1318,7 @@ def add_caching_dir_configvars():
...
@@ -1318,7 +1318,7 @@ def add_caching_dir_configvars():
_compiledir_format_dict
[
"short_platform"
]
=
short_platform
()
_compiledir_format_dict
[
"short_platform"
]
=
short_platform
()
# Allow to have easily one compiledir per device.
# Allow to have easily one compiledir per device.
_compiledir_format_dict
[
"device"
]
=
config
.
device
_compiledir_format_dict
[
"device"
]
=
config
.
device
compiledir_format_keys
=
", "
.
join
(
sorted
(
_compiledir_format_dict
.
keys
()
))
compiledir_format_keys
=
", "
.
join
(
sorted
(
_compiledir_format_dict
))
_default_compiledir_format
=
(
_default_compiledir_format
=
(
"compiledir_
%(short_platform)
s-
%(processor)
s-"
"compiledir_
%(short_platform)
s-
%(processor)
s-"
"
%(python_version)
s-
%(python_bitwidth)
s"
"
%(python_version)
s-
%(python_bitwidth)
s"
...
...
pytensor/configparser.py
浏览文件 @
8ae2a195
...
@@ -214,7 +214,7 @@ class PyTensorConfigParser:
...
@@ -214,7 +214,7 @@ class PyTensorConfigParser:
return
_ChangeFlagsDecorator
(
*
args
,
_root
=
self
,
**
kwargs
)
return
_ChangeFlagsDecorator
(
*
args
,
_root
=
self
,
**
kwargs
)
def
warn_unused_flags
(
self
):
def
warn_unused_flags
(
self
):
for
key
in
self
.
_flags_dict
.
keys
()
:
for
key
in
self
.
_flags_dict
:
warnings
.
warn
(
f
"PyTensor does not recognise this flag: {key}"
)
warnings
.
warn
(
f
"PyTensor does not recognise this flag: {key}"
)
...
...
pytensor/gradient.py
浏览文件 @
8ae2a195
...
@@ -500,7 +500,7 @@ def grad(
...
@@ -500,7 +500,7 @@ def grad(
if
cost
is
not
None
:
if
cost
is
not
None
:
outputs
.
append
(
cost
)
outputs
.
append
(
cost
)
if
known_grads
is
not
None
:
if
known_grads
is
not
None
:
outputs
.
extend
(
list
(
known_grads
.
keys
()
))
outputs
.
extend
(
list
(
known_grads
))
var_to_app_to_idx
=
_populate_var_to_app_to_idx
(
outputs
,
_wrt
,
consider_constant
)
var_to_app_to_idx
=
_populate_var_to_app_to_idx
(
outputs
,
_wrt
,
consider_constant
)
...
@@ -966,7 +966,7 @@ def _populate_var_to_app_to_idx(outputs, wrt, consider_constant):
...
@@ -966,7 +966,7 @@ def _populate_var_to_app_to_idx(outputs, wrt, consider_constant):
visit
(
elem
)
visit
(
elem
)
# Remove variables that don't have wrt as a true ancestor
# Remove variables that don't have wrt as a true ancestor
orig_vars
=
list
(
var_to_app_to_idx
.
keys
()
)
orig_vars
=
list
(
var_to_app_to_idx
)
for
var
in
orig_vars
:
for
var
in
orig_vars
:
if
var
not
in
visited
:
if
var
not
in
visited
:
del
var_to_app_to_idx
[
var
]
del
var_to_app_to_idx
[
var
]
...
...
pytensor/graph/basic.py
浏览文件 @
8ae2a195
...
@@ -631,7 +631,7 @@ class Variable(Node, Generic[_TypeType, OptionalApplyType]):
...
@@ -631,7 +631,7 @@ class Variable(Node, Generic[_TypeType, OptionalApplyType]):
if
not
hasattr
(
self
,
"_fn_cache"
):
if
not
hasattr
(
self
,
"_fn_cache"
):
self
.
_fn_cache
:
dict
=
dict
()
self
.
_fn_cache
:
dict
=
dict
()
inputs
=
tuple
(
sorted
(
parsed_inputs_to_values
.
keys
()
,
key
=
id
))
inputs
=
tuple
(
sorted
(
parsed_inputs_to_values
,
key
=
id
))
cache_key
=
(
inputs
,
tuple
(
kwargs
.
items
()))
cache_key
=
(
inputs
,
tuple
(
kwargs
.
items
()))
try
:
try
:
fn
=
self
.
_fn_cache
[
cache_key
]
fn
=
self
.
_fn_cache
[
cache_key
]
...
...
pytensor/graph/destroyhandler.py
浏览文件 @
8ae2a195
...
@@ -406,7 +406,7 @@ class DestroyHandler(Bookkeeper):
...
@@ -406,7 +406,7 @@ class DestroyHandler(Bookkeeper):
# If True means that the apply node, destroys the protected_var.
# If True means that the apply node, destroys the protected_var.
if
idx
in
[
dmap
for
sublist
in
destroy_maps
for
dmap
in
sublist
]:
if
idx
in
[
dmap
for
sublist
in
destroy_maps
for
dmap
in
sublist
]:
return
True
return
True
for
var_idx
in
app
.
op
.
view_map
.
keys
()
:
for
var_idx
in
app
.
op
.
view_map
:
if
idx
in
app
.
op
.
view_map
[
var_idx
]:
if
idx
in
app
.
op
.
view_map
[
var_idx
]:
# We need to recursively check the destroy_map of all the
# We need to recursively check the destroy_map of all the
# outputs that we have a view_map on.
# outputs that we have a view_map on.
...
...
pytensor/graph/rewriting/basic.py
浏览文件 @
8ae2a195
...
@@ -15,7 +15,7 @@ from collections.abc import Callable, Iterable, Sequence
...
@@ -15,7 +15,7 @@ from collections.abc import Callable, Iterable, Sequence
from
collections.abc
import
Iterable
as
IterableType
from
collections.abc
import
Iterable
as
IterableType
from
functools
import
_compose_mro
,
partial
,
reduce
# type: ignore
from
functools
import
_compose_mro
,
partial
,
reduce
# type: ignore
from
itertools
import
chain
from
itertools
import
chain
from
typing
import
TYPE_CHECKING
,
Literal
,
cast
from
typing
import
TYPE_CHECKING
,
Literal
import
pytensor
import
pytensor
from
pytensor.configdefaults
import
config
from
pytensor.configdefaults
import
config
...
@@ -1924,9 +1924,9 @@ class NodeProcessingGraphRewriter(GraphRewriter):
...
@@ -1924,9 +1924,9 @@ class NodeProcessingGraphRewriter(GraphRewriter):
remove
:
list
[
Variable
]
=
[]
remove
:
list
[
Variable
]
=
[]
if
isinstance
(
replacements
,
dict
):
if
isinstance
(
replacements
,
dict
):
if
"remove"
in
replacements
:
if
"remove"
in
replacements
:
remove
=
list
(
cast
(
Sequence
[
Variable
],
replacements
.
pop
(
"remove"
)
))
remove
=
list
(
replacements
.
pop
(
"remove"
))
old_vars
=
list
(
cast
(
Sequence
[
Variable
],
replacements
.
keys
())
)
old_vars
=
list
(
replacements
)
replacements
=
list
(
cast
(
Sequence
[
Variable
],
replacements
.
values
()
))
replacements
=
list
(
replacements
.
values
(
))
elif
not
isinstance
(
replacements
,
tuple
|
list
):
elif
not
isinstance
(
replacements
,
tuple
|
list
):
raise
TypeError
(
raise
TypeError
(
f
"Node rewriter {node_rewriter} gave wrong type of replacement. "
f
"Node rewriter {node_rewriter} gave wrong type of replacement. "
...
...
pytensor/graph/utils.py
浏览文件 @
8ae2a195
...
@@ -168,7 +168,7 @@ class MissingInputError(Exception):
...
@@ -168,7 +168,7 @@ class MissingInputError(Exception):
def
__init__
(
self
,
*
args
,
**
kwargs
):
def
__init__
(
self
,
*
args
,
**
kwargs
):
if
kwargs
:
if
kwargs
:
# The call to list is needed for Python 3
# The call to list is needed for Python 3
assert
list
(
kwargs
.
keys
()
)
==
[
"variable"
]
assert
list
(
kwargs
)
==
[
"variable"
]
error_msg
=
get_variable_trace_string
(
kwargs
[
"variable"
])
error_msg
=
get_variable_trace_string
(
kwargs
[
"variable"
])
if
error_msg
:
if
error_msg
:
args
=
(
*
args
,
error_msg
)
args
=
(
*
args
,
error_msg
)
...
...
pytensor/link/c/params_type.py
浏览文件 @
8ae2a195
...
@@ -264,10 +264,7 @@ class Params(dict):
...
@@ -264,10 +264,7 @@ class Params(dict):
def
__repr__
(
self
):
def
__repr__
(
self
):
return
"Params({})"
.
format
(
return
"Params({})"
.
format
(
", "
.
join
(
", "
.
join
(
[
[(
f
"{k}:{type(self[k]).__name__}:{self[k]}"
)
for
k
in
sorted
(
self
)]
(
f
"{k}:{type(self[k]).__name__}:{self[k]}"
)
for
k
in
sorted
(
self
.
keys
())
]
)
)
)
)
...
@@ -365,7 +362,7 @@ class ParamsType(CType):
...
@@ -365,7 +362,7 @@ class ParamsType(CType):
)
)
self
.
length
=
len
(
kwargs
)
self
.
length
=
len
(
kwargs
)
self
.
fields
=
tuple
(
sorted
(
kwargs
.
keys
()
))
self
.
fields
=
tuple
(
sorted
(
kwargs
))
self
.
types
=
tuple
(
kwargs
[
field
]
for
field
in
self
.
fields
)
self
.
types
=
tuple
(
kwargs
[
field
]
for
field
in
self
.
fields
)
self
.
name
=
self
.
generate_struct_name
()
self
.
name
=
self
.
generate_struct_name
()
...
...
pytensor/link/c/type.py
浏览文件 @
8ae2a195
...
@@ -472,7 +472,7 @@ class EnumType(CType, dict):
...
@@ -472,7 +472,7 @@ class EnumType(CType, dict):
"""
"""
Return the sorted tuple of all aliases in this enumeration.
Return the sorted tuple of all aliases in this enumeration.
"""
"""
return
tuple
(
sorted
(
self
.
aliases
.
keys
()
))
return
tuple
(
sorted
(
self
.
aliases
))
def
__repr__
(
self
):
def
__repr__
(
self
):
names_to_aliases
=
{
constant_name
:
""
for
constant_name
in
self
}
names_to_aliases
=
{
constant_name
:
""
for
constant_name
in
self
}
...
@@ -481,9 +481,7 @@ class EnumType(CType, dict):
...
@@ -481,9 +481,7 @@ class EnumType(CType, dict):
return
"{}<{}>({})"
.
format
(
return
"{}<{}>({})"
.
format
(
type
(
self
)
.
__name__
,
type
(
self
)
.
__name__
,
self
.
ctype
,
self
.
ctype
,
", "
.
join
(
", "
.
join
(
f
"{k}{names_to_aliases[k]}:{self[k]}"
for
k
in
sorted
(
self
)),
f
"{k}{names_to_aliases[k]}:{self[k]}"
for
k
in
sorted
(
self
.
keys
())
),
)
)
def
__getattr__
(
self
,
key
):
def
__getattr__
(
self
,
key
):
...
@@ -612,7 +610,7 @@ class EnumType(CType, dict):
...
@@ -612,7 +610,7 @@ class EnumType(CType, dict):
f
"""
f
"""
#define {k} {self[k]!s}
#define {k} {self[k]!s}
"""
"""
for
k
in
sorted
(
self
.
keys
()
)
for
k
in
sorted
(
self
)
)
)
+
self
.
c_to_string
()
+
self
.
c_to_string
()
)
)
...
@@ -772,7 +770,7 @@ class CEnumType(EnumList):
...
@@ -772,7 +770,7 @@ class CEnumType(EnumList):
case
%(i)
d:
%(name)
s =
%(constant_cname)
s; break;
case
%(i)
d:
%(name)
s =
%(constant_cname)
s; break;
"""
"""
%
dict
(
i
=
i
,
name
=
name
,
constant_cname
=
swapped_dict
[
i
])
%
dict
(
i
=
i
,
name
=
name
,
constant_cname
=
swapped_dict
[
i
])
for
i
in
sorted
(
swapped_dict
.
keys
()
)
for
i
in
sorted
(
swapped_dict
)
),
),
fail
=
sub
[
"fail"
],
fail
=
sub
[
"fail"
],
)
)
...
...
pytensor/link/numba/dispatch/scalar.py
浏览文件 @
8ae2a195
...
@@ -117,9 +117,7 @@ def {scalar_op_fn_name}({input_names}):
...
@@ -117,9 +117,7 @@ def {scalar_op_fn_name}({input_names}):
converted_call_args
=
", "
.
join
(
converted_call_args
=
", "
.
join
(
[
[
f
"direct_cast({i_name}, {i_tmp_dtype_name})"
f
"direct_cast({i_name}, {i_tmp_dtype_name})"
for
i_name
,
i_tmp_dtype_name
in
zip
(
for
i_name
,
i_tmp_dtype_name
in
zip
(
input_names
,
input_tmp_dtype_names
)
input_names
,
input_tmp_dtype_names
.
keys
()
)
]
]
)
)
if
not
has_pyx_skip_dispatch
:
if
not
has_pyx_skip_dispatch
:
...
...
pytensor/link/numba/dispatch/scan.py
浏览文件 @
8ae2a195
...
@@ -70,7 +70,7 @@ def numba_funcify_Scan(op, node, **kwargs):
...
@@ -70,7 +70,7 @@ def numba_funcify_Scan(op, node, **kwargs):
outer_in_names_to_vars
=
{
outer_in_names_to_vars
=
{
(
f
"outer_in_{i}"
if
i
>
0
else
"n_steps"
):
v
for
i
,
v
in
enumerate
(
node
.
inputs
)
(
f
"outer_in_{i}"
if
i
>
0
else
"n_steps"
):
v
for
i
,
v
in
enumerate
(
node
.
inputs
)
}
}
outer_in_names
=
list
(
outer_in_names_to_vars
.
keys
()
)
outer_in_names
=
list
(
outer_in_names_to_vars
)
outer_in_seqs_names
=
op
.
outer_seqs
(
outer_in_names
)
outer_in_seqs_names
=
op
.
outer_seqs
(
outer_in_names
)
outer_in_mit_mot_names
=
op
.
outer_mitmot
(
outer_in_names
)
outer_in_mit_mot_names
=
op
.
outer_mitmot
(
outer_in_names
)
outer_in_mit_sot_names
=
op
.
outer_mitsot
(
outer_in_names
)
outer_in_mit_sot_names
=
op
.
outer_mitsot
(
outer_in_names
)
...
...
pytensor/link/vm.py
浏览文件 @
8ae2a195
...
@@ -990,7 +990,7 @@ class VMLinker(LocalLinker):
...
@@ -990,7 +990,7 @@ class VMLinker(LocalLinker):
for
pair
in
reallocated_info
.
values
():
for
pair
in
reallocated_info
.
values
():
storage_map
[
pair
[
1
]]
=
storage_map
[
pair
[
0
]]
storage_map
[
pair
[
1
]]
=
storage_map
[
pair
[
0
]]
return
tuple
(
reallocated_info
.
keys
()
)
return
tuple
(
reallocated_info
)
def
make_vm
(
def
make_vm
(
self
,
self
,
...
...
pytensor/printing.py
浏览文件 @
8ae2a195
...
@@ -1106,9 +1106,7 @@ class PPrinter(Printer):
...
@@ -1106,9 +1106,7 @@ class PPrinter(Printer):
outputs
=
[
outputs
]
outputs
=
[
outputs
]
current
=
None
current
=
None
if
display_inputs
:
if
display_inputs
:
strings
=
[
strings
=
[(
0
,
"inputs: "
+
", "
.
join
(
str
(
x
)
for
x
in
[
*
inputs
,
*
updates
]))]
(
0
,
"inputs: "
+
", "
.
join
(
map
(
str
,
list
(
inputs
)
+
updates
.
keys
())))
]
else
:
else
:
strings
=
[]
strings
=
[]
pprinter
=
self
.
clone_assign
(
pprinter
=
self
.
clone_assign
(
...
@@ -1116,9 +1114,7 @@ class PPrinter(Printer):
...
@@ -1116,9 +1114,7 @@ class PPrinter(Printer):
)
)
inv_updates
=
{
b
:
a
for
(
a
,
b
)
in
updates
.
items
()}
inv_updates
=
{
b
:
a
for
(
a
,
b
)
in
updates
.
items
()}
i
=
1
i
=
1
for
node
in
io_toposort
(
for
node
in
io_toposort
([
*
inputs
,
*
updates
],
[
*
outputs
,
*
updates
.
values
()]):
list
(
inputs
)
+
updates
.
keys
(),
list
(
outputs
)
+
updates
.
values
()
):
for
output
in
node
.
outputs
:
for
output
in
node
.
outputs
:
if
output
in
inv_updates
:
if
output
in
inv_updates
:
name
=
str
(
inv_updates
[
output
])
name
=
str
(
inv_updates
[
output
])
...
...
pytensor/scalar/basic.py
浏览文件 @
8ae2a195
...
@@ -4426,7 +4426,7 @@ class Compositef32:
...
@@ -4426,7 +4426,7 @@ class Compositef32:
else
:
else
:
ni
=
i
ni
=
i
mapping
[
i
]
=
ni
mapping
[
i
]
=
ni
if
isinstance
(
node
.
op
,
tuple
(
self
.
special
.
keys
()
)):
if
isinstance
(
node
.
op
,
tuple
(
self
.
special
)):
self
.
special
[
type
(
node
.
op
)](
node
,
mapping
)
self
.
special
[
type
(
node
.
op
)](
node
,
mapping
)
continue
continue
new_node
=
node
.
clone_with_new_inputs
(
new_node
=
node
.
clone_with_new_inputs
(
...
...
pytensor/scan/op.py
浏览文件 @
8ae2a195
...
@@ -1284,14 +1284,14 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
...
@@ -1284,14 +1284,14 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
def
__str__
(
self
):
def
__str__
(
self
):
inplace
=
"none"
inplace
=
"none"
if
len
(
self
.
destroy_map
.
keys
())
>
0
:
if
self
.
destroy_map
:
# Check if all outputs are inplace
# Check if all outputs are inplace
if
sorted
(
self
.
destroy_map
.
keys
()
)
==
sorted
(
if
sorted
(
self
.
destroy_map
)
==
sorted
(
range
(
self
.
info
.
n_mit_mot
+
self
.
info
.
n_mit_sot
+
self
.
info
.
n_sit_sot
)
range
(
self
.
info
.
n_mit_mot
+
self
.
info
.
n_mit_sot
+
self
.
info
.
n_sit_sot
)
):
):
inplace
=
"all"
inplace
=
"all"
else
:
else
:
inplace
=
str
(
list
(
self
.
destroy_map
.
keys
()
))
inplace
=
str
(
list
(
self
.
destroy_map
))
return
(
return
(
f
"Scan{{{self.name}, while_loop={self.info.as_while}, inplace={inplace}}}"
f
"Scan{{{self.name}, while_loop={self.info.as_while}, inplace={inplace}}}"
)
)
...
...
pytensor/scan/utils.py
浏览文件 @
8ae2a195
...
@@ -269,7 +269,7 @@ class Validator:
...
@@ -269,7 +269,7 @@ class Validator:
# Mapping from invalid variables to equivalent valid ones.
# Mapping from invalid variables to equivalent valid ones.
self
.
valid_equivalent
=
valid_equivalent
.
copy
()
self
.
valid_equivalent
=
valid_equivalent
.
copy
()
self
.
valid
.
update
(
list
(
valid_equivalent
.
values
()))
self
.
valid
.
update
(
list
(
valid_equivalent
.
values
()))
self
.
invalid
.
update
(
list
(
valid_equivalent
.
keys
()
))
self
.
invalid
.
update
(
list
(
valid_equivalent
))
def
check
(
self
,
out
):
def
check
(
self
,
out
):
"""
"""
...
...
pytensor/sparse/basic.py
浏览文件 @
8ae2a195
...
@@ -524,7 +524,7 @@ csc_fmatrix = SparseTensorType(format="csc", dtype="float32")
...
@@ -524,7 +524,7 @@ csc_fmatrix = SparseTensorType(format="csc", dtype="float32")
csr_fmatrix
=
SparseTensorType
(
format
=
"csr"
,
dtype
=
"float32"
)
csr_fmatrix
=
SparseTensorType
(
format
=
"csr"
,
dtype
=
"float32"
)
bsr_fmatrix
=
SparseTensorType
(
format
=
"bsr"
,
dtype
=
"float32"
)
bsr_fmatrix
=
SparseTensorType
(
format
=
"bsr"
,
dtype
=
"float32"
)
all_dtypes
=
list
(
SparseTensorType
.
dtype_specs_map
.
keys
()
)
all_dtypes
=
list
(
SparseTensorType
.
dtype_specs_map
)
complex_dtypes
=
[
t
for
t
in
all_dtypes
if
t
[:
7
]
==
"complex"
]
complex_dtypes
=
[
t
for
t
in
all_dtypes
if
t
[:
7
]
==
"complex"
]
float_dtypes
=
[
t
for
t
in
all_dtypes
if
t
[:
5
]
==
"float"
]
float_dtypes
=
[
t
for
t
in
all_dtypes
if
t
[:
5
]
==
"float"
]
int_dtypes
=
[
t
for
t
in
all_dtypes
if
t
[:
3
]
==
"int"
]
int_dtypes
=
[
t
for
t
in
all_dtypes
if
t
[:
3
]
==
"int"
]
...
...
pytensor/tensor/rewriting/elemwise.py
浏览文件 @
8ae2a195
...
@@ -71,7 +71,7 @@ class InplaceElemwiseOptimizer(GraphRewriter):
...
@@ -71,7 +71,7 @@ class InplaceElemwiseOptimizer(GraphRewriter):
ndim
=
prof
[
"ndim"
]
ndim
=
prof
[
"ndim"
]
if
ndim
:
if
ndim
:
print
(
blanc
,
"ndim"
,
"nb"
,
file
=
stream
)
print
(
blanc
,
"ndim"
,
"nb"
,
file
=
stream
)
for
n
in
sorted
(
ndim
.
keys
()
):
for
n
in
sorted
(
ndim
):
print
(
blanc
,
n
,
ndim
[
n
],
file
=
stream
)
print
(
blanc
,
n
,
ndim
[
n
],
file
=
stream
)
def
candidate_input_idxs
(
self
,
node
):
def
candidate_input_idxs
(
self
,
node
):
...
...
pytensor/tensor/utils.py
浏览文件 @
8ae2a195
...
@@ -88,7 +88,7 @@ def shape_of_variables(
...
@@ -88,7 +88,7 @@ def shape_of_variables(
compute_shapes
=
pytensor
.
function
(
input_dims
,
output_dims
)
compute_shapes
=
pytensor
.
function
(
input_dims
,
output_dims
)
if
any
(
i
not
in
fgraph
.
inputs
for
i
in
input_shapes
.
keys
()
):
if
any
(
i
not
in
fgraph
.
inputs
for
i
in
input_shapes
):
raise
ValueError
(
raise
ValueError
(
"input_shapes keys aren't in the fgraph.inputs. FunctionGraph()"
"input_shapes keys aren't in the fgraph.inputs. FunctionGraph()"
" interface changed. Now by default, it clones the graph it receives."
" interface changed. Now by default, it clones the graph it receives."
...
...
tests/compile/function/test_types.py
浏览文件 @
8ae2a195
...
@@ -889,9 +889,9 @@ class TestPicklefunction:
...
@@ -889,9 +889,9 @@ class TestPicklefunction:
return
return
else
:
else
:
raise
raise
# if they both return, assume
that they return equivalent things.
# if they both return, assume that they return equivalent things.
# print [(k,
id(k)) for k in f.finder.keys()
]
# print [(k,
id(k)) for k in f.finder
]
# print [(k,
id(k)) for k in g.finder.keys()
]
# print [(k,
id(k)) for k in g.finder
]
assert
g
.
container
[
0
]
.
storage
is
not
f
.
container
[
0
]
.
storage
assert
g
.
container
[
0
]
.
storage
is
not
f
.
container
[
0
]
.
storage
assert
g
.
container
[
1
]
.
storage
is
not
f
.
container
[
1
]
.
storage
assert
g
.
container
[
1
]
.
storage
is
not
f
.
container
[
1
]
.
storage
...
@@ -1012,9 +1012,9 @@ class TestPicklefunction:
...
@@ -1012,9 +1012,9 @@ class TestPicklefunction:
return
return
else
:
else
:
raise
raise
# if they both return, assume
that they return equivalent things.
# if they both return, assume that they return equivalent things.
# print [(k,
id(k)) for k in f.finder.keys()
]
# print [(k,
id(k)) for k in f.finder
]
# print [(k,
id(k)) for k in g.finder.keys()
]
# print [(k,
id(k)) for k in g.finder
]
assert
g
.
container
[
0
]
.
storage
is
not
f
.
container
[
0
]
.
storage
assert
g
.
container
[
0
]
.
storage
is
not
f
.
container
[
0
]
.
storage
assert
g
.
container
[
1
]
.
storage
is
not
f
.
container
[
1
]
.
storage
assert
g
.
container
[
1
]
.
storage
is
not
f
.
container
[
1
]
.
storage
...
...
tests/link/numba/test_basic.py
浏览文件 @
8ae2a195
...
@@ -829,7 +829,7 @@ def test_config_options_fastmath():
...
@@ -829,7 +829,7 @@ def test_config_options_fastmath():
with
config
.
change_flags
(
numba__fastmath
=
True
):
with
config
.
change_flags
(
numba__fastmath
=
True
):
pytensor_numba_fn
=
function
([
x
],
pt
.
sum
(
x
),
mode
=
numba_mode
)
pytensor_numba_fn
=
function
([
x
],
pt
.
sum
(
x
),
mode
=
numba_mode
)
print
(
list
(
pytensor_numba_fn
.
vm
.
jit_fn
.
py_func
.
__globals__
.
keys
()
))
print
(
list
(
pytensor_numba_fn
.
vm
.
jit_fn
.
py_func
.
__globals__
))
numba_mul_fn
=
pytensor_numba_fn
.
vm
.
jit_fn
.
py_func
.
__globals__
[
"impl_sum"
]
numba_mul_fn
=
pytensor_numba_fn
.
vm
.
jit_fn
.
py_func
.
__globals__
[
"impl_sum"
]
assert
numba_mul_fn
.
targetoptions
[
"fastmath"
]
is
True
assert
numba_mul_fn
.
targetoptions
[
"fastmath"
]
is
True
...
...
tests/link/numba/test_scan.py
浏览文件 @
8ae2a195
...
@@ -479,14 +479,14 @@ def test_vector_taps_benchmark(benchmark):
...
@@ -479,14 +479,14 @@ def test_vector_taps_benchmark(benchmark):
sitsot_init
:
rng
.
normal
(),
sitsot_init
:
rng
.
normal
(),
}
}
numba_fn
=
pytensor
.
function
(
list
(
test
.
keys
()
),
outs
,
mode
=
get_mode
(
"NUMBA"
))
numba_fn
=
pytensor
.
function
(
list
(
test
),
outs
,
mode
=
get_mode
(
"NUMBA"
))
scan_nodes
=
[
scan_nodes
=
[
node
for
node
in
numba_fn
.
maker
.
fgraph
.
apply_nodes
if
isinstance
(
node
.
op
,
Scan
)
node
for
node
in
numba_fn
.
maker
.
fgraph
.
apply_nodes
if
isinstance
(
node
.
op
,
Scan
)
]
]
assert
len
(
scan_nodes
)
==
1
assert
len
(
scan_nodes
)
==
1
numba_res
=
numba_fn
(
*
test
.
values
())
numba_res
=
numba_fn
(
*
test
.
values
())
ref_fn
=
pytensor
.
function
(
list
(
test
.
keys
()
),
outs
,
mode
=
get_mode
(
"FAST_COMPILE"
))
ref_fn
=
pytensor
.
function
(
list
(
test
),
outs
,
mode
=
get_mode
(
"FAST_COMPILE"
))
ref_res
=
ref_fn
(
*
test
.
values
())
ref_res
=
ref_fn
(
*
test
.
values
())
for
numba_r
,
ref_r
in
zip
(
numba_res
,
ref_res
):
for
numba_r
,
ref_r
in
zip
(
numba_res
,
ref_res
):
np
.
testing
.
assert_array_almost_equal
(
numba_r
,
ref_r
)
np
.
testing
.
assert_array_almost_equal
(
numba_r
,
ref_r
)
...
...
tests/link/test_utils.py
浏览文件 @
8ae2a195
...
@@ -57,7 +57,7 @@ def test_fgraph_to_python_names():
...
@@ -57,7 +57,7 @@ def test_fgraph_to_python_names():
"scalar_variable"
,
"scalar_variable"
,
"tensor_variable_1"
,
"tensor_variable_1"
,
r
.
name
,
r
.
name
,
)
==
tuple
(
sig
.
parameters
.
keys
()
)
)
==
tuple
(
sig
.
parameters
)
assert
(
1
,
2
,
3
,
4
,
5
)
==
out_jx
(
1
,
2
,
3
,
4
,
5
)
assert
(
1
,
2
,
3
,
4
,
5
)
==
out_jx
(
1
,
2
,
3
,
4
,
5
)
obj
=
object
()
obj
=
object
()
...
...
tests/link/test_vm.py
浏览文件 @
8ae2a195
...
@@ -337,7 +337,7 @@ def test_reallocation():
...
@@ -337,7 +337,7 @@ def test_reallocation():
def
check_storage
(
storage_map
):
def
check_storage
(
storage_map
):
for
i
in
storage_map
:
for
i
in
storage_map
:
if
not
isinstance
(
i
,
TensorConstant
):
if
not
isinstance
(
i
,
TensorConstant
):
keys_copy
=
list
(
storage_map
.
keys
()
)[:]
keys_copy
=
list
(
storage_map
)[:]
keys_copy
.
remove
(
i
)
keys_copy
.
remove
(
i
)
for
o
in
keys_copy
:
for
o
in
keys_copy
:
if
storage_map
[
i
][
0
]
and
storage_map
[
i
][
0
]
is
storage_map
[
o
][
0
]:
if
storage_map
[
i
][
0
]
and
storage_map
[
i
][
0
]
is
storage_map
[
o
][
0
]:
...
...
tests/scan/test_rewriting.py
浏览文件 @
8ae2a195
...
@@ -1097,8 +1097,8 @@ class TestScanInplaceOptimizer:
...
@@ -1097,8 +1097,8 @@ class TestScanInplaceOptimizer:
allow_input_downcast
=
True
,
allow_input_downcast
=
True
,
)
)
scan_node
=
[
x
for
x
in
f9
.
maker
.
fgraph
.
toposort
()
if
isinstance
(
x
.
op
,
Scan
)]
scan_node
=
[
x
for
x
in
f9
.
maker
.
fgraph
.
toposort
()
if
isinstance
(
x
.
op
,
Scan
)]
assert
0
in
scan_node
[
0
]
.
op
.
destroy_map
.
keys
()
assert
0
in
scan_node
[
0
]
.
op
.
destroy_map
assert
1
in
scan_node
[
0
]
.
op
.
destroy_map
.
keys
()
assert
1
in
scan_node
[
0
]
.
op
.
destroy_map
# compute output in numpy
# compute output in numpy
numpy_x0
=
np
.
zeros
((
3
,))
numpy_x0
=
np
.
zeros
((
3
,))
numpy_x1
=
np
.
zeros
((
3
,))
numpy_x1
=
np
.
zeros
((
3
,))
...
@@ -1163,8 +1163,8 @@ class TestScanInplaceOptimizer:
...
@@ -1163,8 +1163,8 @@ class TestScanInplaceOptimizer:
)
)
scan_node
=
[
x
for
x
in
f9
.
maker
.
fgraph
.
toposort
()
if
isinstance
(
x
.
op
,
Scan
)]
scan_node
=
[
x
for
x
in
f9
.
maker
.
fgraph
.
toposort
()
if
isinstance
(
x
.
op
,
Scan
)]
assert
0
in
scan_node
[
0
]
.
op
.
destroy_map
.
keys
()
assert
0
in
scan_node
[
0
]
.
op
.
destroy_map
assert
1
in
scan_node
[
0
]
.
op
.
destroy_map
.
keys
()
assert
1
in
scan_node
[
0
]
.
op
.
destroy_map
# compute output in numpy
# compute output in numpy
numpy_x0
=
np
.
zeros
((
3
,))
numpy_x0
=
np
.
zeros
((
3
,))
numpy_x1
=
np
.
zeros
((
3
,))
numpy_x1
=
np
.
zeros
((
3
,))
...
@@ -1203,8 +1203,8 @@ class TestScanInplaceOptimizer:
...
@@ -1203,8 +1203,8 @@ class TestScanInplaceOptimizer:
f9
=
function
([],
outputs
,
updates
=
updates
,
mode
=
self
.
mode
)
f9
=
function
([],
outputs
,
updates
=
updates
,
mode
=
self
.
mode
)
scan_node
=
[
x
for
x
in
f9
.
maker
.
fgraph
.
toposort
()
if
isinstance
(
x
.
op
,
Scan
)]
scan_node
=
[
x
for
x
in
f9
.
maker
.
fgraph
.
toposort
()
if
isinstance
(
x
.
op
,
Scan
)]
assert
0
not
in
scan_node
[
0
]
.
op
.
destroy_map
.
keys
()
assert
0
not
in
scan_node
[
0
]
.
op
.
destroy_map
assert
1
in
scan_node
[
0
]
.
op
.
destroy_map
.
keys
()
assert
1
in
scan_node
[
0
]
.
op
.
destroy_map
class
TestSaveMem
:
class
TestSaveMem
:
...
...
tests/test_config.py
浏览文件 @
8ae2a195
...
@@ -222,7 +222,7 @@ def test_config_pickling():
...
@@ -222,7 +222,7 @@ def test_config_pickling():
buffer
.
seek
(
0
)
buffer
.
seek
(
0
)
restored
=
pickle
.
load
(
buffer
)
restored
=
pickle
.
load
(
buffer
)
# ...without a change in the config values
# ...without a change in the config values
for
name
in
root
.
_config_var_dict
.
keys
()
:
for
name
in
root
.
_config_var_dict
:
v_original
=
getattr
(
root
,
name
)
v_original
=
getattr
(
root
,
name
)
v_restored
=
getattr
(
restored
,
name
)
v_restored
=
getattr
(
restored
,
name
)
assert
(
assert
(
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论