Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
8ae2a195
提交
8ae2a195
authored
7月 07, 2024
作者:
Virgile Andreani
提交者:
Virgile Andreani
7月 09, 2024
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Remove dict.keys() when unnecessary
上级
63da6d16
隐藏空白字符变更
内嵌
并排
正在显示
28 个修改的文件
包含
53 行增加
和
64 行删除
+53
-64
types.py
pytensor/compile/function/types.py
+4
-4
profiling.py
pytensor/compile/profiling.py
+1
-1
configdefaults.py
pytensor/configdefaults.py
+1
-1
configparser.py
pytensor/configparser.py
+1
-1
gradient.py
pytensor/gradient.py
+2
-2
basic.py
pytensor/graph/basic.py
+1
-1
destroyhandler.py
pytensor/graph/destroyhandler.py
+1
-1
basic.py
pytensor/graph/rewriting/basic.py
+4
-4
utils.py
pytensor/graph/utils.py
+1
-1
params_type.py
pytensor/link/c/params_type.py
+2
-5
type.py
pytensor/link/c/type.py
+4
-6
scalar.py
pytensor/link/numba/dispatch/scalar.py
+1
-3
scan.py
pytensor/link/numba/dispatch/scan.py
+1
-1
vm.py
pytensor/link/vm.py
+1
-1
printing.py
pytensor/printing.py
+2
-6
basic.py
pytensor/scalar/basic.py
+1
-1
op.py
pytensor/scan/op.py
+3
-3
utils.py
pytensor/scan/utils.py
+1
-1
basic.py
pytensor/sparse/basic.py
+1
-1
elemwise.py
pytensor/tensor/rewriting/elemwise.py
+1
-1
utils.py
pytensor/tensor/utils.py
+1
-1
test_types.py
tests/compile/function/test_types.py
+6
-6
test_basic.py
tests/link/numba/test_basic.py
+1
-1
test_scan.py
tests/link/numba/test_scan.py
+2
-2
test_utils.py
tests/link/test_utils.py
+1
-1
test_vm.py
tests/link/test_vm.py
+1
-1
test_rewriting.py
tests/scan/test_rewriting.py
+6
-6
test_config.py
tests/test_config.py
+1
-1
没有找到文件。
pytensor/compile/function/types.py
浏览文件 @
8ae2a195
...
...
@@ -659,7 +659,7 @@ class Function:
exist_svs
=
[
i
.
variable
for
i
in
maker
.
inputs
]
# Check if given ShareVariables exist
for
sv
in
swap
.
keys
()
:
for
sv
in
swap
:
if
sv
not
in
exist_svs
:
raise
ValueError
(
f
"SharedVariable: {sv.name} not found"
)
...
...
@@ -711,9 +711,9 @@ class Function:
# it is well tested, we don't share the part of the storage_map.
if
share_memory
:
i_o_vars
=
maker
.
fgraph
.
inputs
+
maker
.
fgraph
.
outputs
for
key
in
storage_map
.
key
s
():
for
key
,
val
in
storage_map
.
item
s
():
if
key
not
in
i_o_vars
:
new_storage_map
[
memo
[
key
]]
=
storage_map
[
key
]
new_storage_map
[
memo
[
key
]]
=
val
if
not
name
and
self
.
name
:
name
=
self
.
name
+
" copy"
...
...
@@ -1446,7 +1446,7 @@ class FunctionMaker:
if
not
hasattr
(
mode
.
linker
,
"accept"
):
raise
ValueError
(
"'linker' parameter of FunctionMaker should be "
f
"a Linker with an accept method or one of {list(pytensor.compile.mode.predefined_linkers
.keys()
)}"
f
"a Linker with an accept method or one of {list(pytensor.compile.mode.predefined_linkers)}"
)
def
__init__
(
...
...
pytensor/compile/profiling.py
浏览文件 @
8ae2a195
...
...
@@ -1446,7 +1446,7 @@ class ProfileStats:
file
=
file
,
)
if
config
.
profiling__debugprint
:
fcts
=
{
fgraph
for
(
fgraph
,
n
)
in
self
.
apply_time
.
keys
()
}
fcts
=
{
fgraph
for
(
fgraph
,
n
)
in
self
.
apply_time
}
pytensor
.
printing
.
debugprint
(
fcts
,
print_type
=
True
)
if
self
.
variable_shape
or
self
.
variable_strides
:
self
.
summary_memory
(
file
,
n_apply_to_print
)
...
...
pytensor/configdefaults.py
浏览文件 @
8ae2a195
...
...
@@ -1318,7 +1318,7 @@ def add_caching_dir_configvars():
_compiledir_format_dict
[
"short_platform"
]
=
short_platform
()
# Allow to have easily one compiledir per device.
_compiledir_format_dict
[
"device"
]
=
config
.
device
compiledir_format_keys
=
", "
.
join
(
sorted
(
_compiledir_format_dict
.
keys
()
))
compiledir_format_keys
=
", "
.
join
(
sorted
(
_compiledir_format_dict
))
_default_compiledir_format
=
(
"compiledir_
%(short_platform)
s-
%(processor)
s-"
"
%(python_version)
s-
%(python_bitwidth)
s"
...
...
pytensor/configparser.py
浏览文件 @
8ae2a195
...
...
@@ -214,7 +214,7 @@ class PyTensorConfigParser:
return
_ChangeFlagsDecorator
(
*
args
,
_root
=
self
,
**
kwargs
)
def
warn_unused_flags
(
self
):
for
key
in
self
.
_flags_dict
.
keys
()
:
for
key
in
self
.
_flags_dict
:
warnings
.
warn
(
f
"PyTensor does not recognise this flag: {key}"
)
...
...
pytensor/gradient.py
浏览文件 @
8ae2a195
...
...
@@ -500,7 +500,7 @@ def grad(
if
cost
is
not
None
:
outputs
.
append
(
cost
)
if
known_grads
is
not
None
:
outputs
.
extend
(
list
(
known_grads
.
keys
()
))
outputs
.
extend
(
list
(
known_grads
))
var_to_app_to_idx
=
_populate_var_to_app_to_idx
(
outputs
,
_wrt
,
consider_constant
)
...
...
@@ -966,7 +966,7 @@ def _populate_var_to_app_to_idx(outputs, wrt, consider_constant):
visit
(
elem
)
# Remove variables that don't have wrt as a true ancestor
orig_vars
=
list
(
var_to_app_to_idx
.
keys
()
)
orig_vars
=
list
(
var_to_app_to_idx
)
for
var
in
orig_vars
:
if
var
not
in
visited
:
del
var_to_app_to_idx
[
var
]
...
...
pytensor/graph/basic.py
浏览文件 @
8ae2a195
...
...
@@ -631,7 +631,7 @@ class Variable(Node, Generic[_TypeType, OptionalApplyType]):
if
not
hasattr
(
self
,
"_fn_cache"
):
self
.
_fn_cache
:
dict
=
dict
()
inputs
=
tuple
(
sorted
(
parsed_inputs_to_values
.
keys
()
,
key
=
id
))
inputs
=
tuple
(
sorted
(
parsed_inputs_to_values
,
key
=
id
))
cache_key
=
(
inputs
,
tuple
(
kwargs
.
items
()))
try
:
fn
=
self
.
_fn_cache
[
cache_key
]
...
...
pytensor/graph/destroyhandler.py
浏览文件 @
8ae2a195
...
...
@@ -406,7 +406,7 @@ class DestroyHandler(Bookkeeper):
# If True means that the apply node, destroys the protected_var.
if
idx
in
[
dmap
for
sublist
in
destroy_maps
for
dmap
in
sublist
]:
return
True
for
var_idx
in
app
.
op
.
view_map
.
keys
()
:
for
var_idx
in
app
.
op
.
view_map
:
if
idx
in
app
.
op
.
view_map
[
var_idx
]:
# We need to recursively check the destroy_map of all the
# outputs that we have a view_map on.
...
...
pytensor/graph/rewriting/basic.py
浏览文件 @
8ae2a195
...
...
@@ -15,7 +15,7 @@ from collections.abc import Callable, Iterable, Sequence
from
collections.abc
import
Iterable
as
IterableType
from
functools
import
_compose_mro
,
partial
,
reduce
# type: ignore
from
itertools
import
chain
from
typing
import
TYPE_CHECKING
,
Literal
,
cast
from
typing
import
TYPE_CHECKING
,
Literal
import
pytensor
from
pytensor.configdefaults
import
config
...
...
@@ -1924,9 +1924,9 @@ class NodeProcessingGraphRewriter(GraphRewriter):
remove
:
list
[
Variable
]
=
[]
if
isinstance
(
replacements
,
dict
):
if
"remove"
in
replacements
:
remove
=
list
(
cast
(
Sequence
[
Variable
],
replacements
.
pop
(
"remove"
)
))
old_vars
=
list
(
cast
(
Sequence
[
Variable
],
replacements
.
keys
())
)
replacements
=
list
(
cast
(
Sequence
[
Variable
],
replacements
.
values
()
))
remove
=
list
(
replacements
.
pop
(
"remove"
))
old_vars
=
list
(
replacements
)
replacements
=
list
(
replacements
.
values
(
))
elif
not
isinstance
(
replacements
,
tuple
|
list
):
raise
TypeError
(
f
"Node rewriter {node_rewriter} gave wrong type of replacement. "
...
...
pytensor/graph/utils.py
浏览文件 @
8ae2a195
...
...
@@ -168,7 +168,7 @@ class MissingInputError(Exception):
def
__init__
(
self
,
*
args
,
**
kwargs
):
if
kwargs
:
# The call to list is needed for Python 3
assert
list
(
kwargs
.
keys
()
)
==
[
"variable"
]
assert
list
(
kwargs
)
==
[
"variable"
]
error_msg
=
get_variable_trace_string
(
kwargs
[
"variable"
])
if
error_msg
:
args
=
(
*
args
,
error_msg
)
...
...
pytensor/link/c/params_type.py
浏览文件 @
8ae2a195
...
...
@@ -264,10 +264,7 @@ class Params(dict):
def
__repr__
(
self
):
return
"Params({})"
.
format
(
", "
.
join
(
[
(
f
"{k}:{type(self[k]).__name__}:{self[k]}"
)
for
k
in
sorted
(
self
.
keys
())
]
[(
f
"{k}:{type(self[k]).__name__}:{self[k]}"
)
for
k
in
sorted
(
self
)]
)
)
...
...
@@ -365,7 +362,7 @@ class ParamsType(CType):
)
self
.
length
=
len
(
kwargs
)
self
.
fields
=
tuple
(
sorted
(
kwargs
.
keys
()
))
self
.
fields
=
tuple
(
sorted
(
kwargs
))
self
.
types
=
tuple
(
kwargs
[
field
]
for
field
in
self
.
fields
)
self
.
name
=
self
.
generate_struct_name
()
...
...
pytensor/link/c/type.py
浏览文件 @
8ae2a195
...
...
@@ -472,7 +472,7 @@ class EnumType(CType, dict):
"""
Return the sorted tuple of all aliases in this enumeration.
"""
return
tuple
(
sorted
(
self
.
aliases
.
keys
()
))
return
tuple
(
sorted
(
self
.
aliases
))
def
__repr__
(
self
):
names_to_aliases
=
{
constant_name
:
""
for
constant_name
in
self
}
...
...
@@ -481,9 +481,7 @@ class EnumType(CType, dict):
return
"{}<{}>({})"
.
format
(
type
(
self
)
.
__name__
,
self
.
ctype
,
", "
.
join
(
f
"{k}{names_to_aliases[k]}:{self[k]}"
for
k
in
sorted
(
self
.
keys
())
),
", "
.
join
(
f
"{k}{names_to_aliases[k]}:{self[k]}"
for
k
in
sorted
(
self
)),
)
def
__getattr__
(
self
,
key
):
...
...
@@ -612,7 +610,7 @@ class EnumType(CType, dict):
f
"""
#define {k} {self[k]!s}
"""
for
k
in
sorted
(
self
.
keys
()
)
for
k
in
sorted
(
self
)
)
+
self
.
c_to_string
()
)
...
...
@@ -772,7 +770,7 @@ class CEnumType(EnumList):
case
%(i)
d:
%(name)
s =
%(constant_cname)
s; break;
"""
%
dict
(
i
=
i
,
name
=
name
,
constant_cname
=
swapped_dict
[
i
])
for
i
in
sorted
(
swapped_dict
.
keys
()
)
for
i
in
sorted
(
swapped_dict
)
),
fail
=
sub
[
"fail"
],
)
...
...
pytensor/link/numba/dispatch/scalar.py
浏览文件 @
8ae2a195
...
...
@@ -117,9 +117,7 @@ def {scalar_op_fn_name}({input_names}):
converted_call_args
=
", "
.
join
(
[
f
"direct_cast({i_name}, {i_tmp_dtype_name})"
for
i_name
,
i_tmp_dtype_name
in
zip
(
input_names
,
input_tmp_dtype_names
.
keys
()
)
for
i_name
,
i_tmp_dtype_name
in
zip
(
input_names
,
input_tmp_dtype_names
)
]
)
if
not
has_pyx_skip_dispatch
:
...
...
pytensor/link/numba/dispatch/scan.py
浏览文件 @
8ae2a195
...
...
@@ -70,7 +70,7 @@ def numba_funcify_Scan(op, node, **kwargs):
outer_in_names_to_vars
=
{
(
f
"outer_in_{i}"
if
i
>
0
else
"n_steps"
):
v
for
i
,
v
in
enumerate
(
node
.
inputs
)
}
outer_in_names
=
list
(
outer_in_names_to_vars
.
keys
()
)
outer_in_names
=
list
(
outer_in_names_to_vars
)
outer_in_seqs_names
=
op
.
outer_seqs
(
outer_in_names
)
outer_in_mit_mot_names
=
op
.
outer_mitmot
(
outer_in_names
)
outer_in_mit_sot_names
=
op
.
outer_mitsot
(
outer_in_names
)
...
...
pytensor/link/vm.py
浏览文件 @
8ae2a195
...
...
@@ -990,7 +990,7 @@ class VMLinker(LocalLinker):
for
pair
in
reallocated_info
.
values
():
storage_map
[
pair
[
1
]]
=
storage_map
[
pair
[
0
]]
return
tuple
(
reallocated_info
.
keys
()
)
return
tuple
(
reallocated_info
)
def
make_vm
(
self
,
...
...
pytensor/printing.py
浏览文件 @
8ae2a195
...
...
@@ -1106,9 +1106,7 @@ class PPrinter(Printer):
outputs
=
[
outputs
]
current
=
None
if
display_inputs
:
strings
=
[
(
0
,
"inputs: "
+
", "
.
join
(
map
(
str
,
list
(
inputs
)
+
updates
.
keys
())))
]
strings
=
[(
0
,
"inputs: "
+
", "
.
join
(
str
(
x
)
for
x
in
[
*
inputs
,
*
updates
]))]
else
:
strings
=
[]
pprinter
=
self
.
clone_assign
(
...
...
@@ -1116,9 +1114,7 @@ class PPrinter(Printer):
)
inv_updates
=
{
b
:
a
for
(
a
,
b
)
in
updates
.
items
()}
i
=
1
for
node
in
io_toposort
(
list
(
inputs
)
+
updates
.
keys
(),
list
(
outputs
)
+
updates
.
values
()
):
for
node
in
io_toposort
([
*
inputs
,
*
updates
],
[
*
outputs
,
*
updates
.
values
()]):
for
output
in
node
.
outputs
:
if
output
in
inv_updates
:
name
=
str
(
inv_updates
[
output
])
...
...
pytensor/scalar/basic.py
浏览文件 @
8ae2a195
...
...
@@ -4426,7 +4426,7 @@ class Compositef32:
else
:
ni
=
i
mapping
[
i
]
=
ni
if
isinstance
(
node
.
op
,
tuple
(
self
.
special
.
keys
()
)):
if
isinstance
(
node
.
op
,
tuple
(
self
.
special
)):
self
.
special
[
type
(
node
.
op
)](
node
,
mapping
)
continue
new_node
=
node
.
clone_with_new_inputs
(
...
...
pytensor/scan/op.py
浏览文件 @
8ae2a195
...
...
@@ -1284,14 +1284,14 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
def
__str__
(
self
):
inplace
=
"none"
if
len
(
self
.
destroy_map
.
keys
())
>
0
:
if
self
.
destroy_map
:
# Check if all outputs are inplace
if
sorted
(
self
.
destroy_map
.
keys
()
)
==
sorted
(
if
sorted
(
self
.
destroy_map
)
==
sorted
(
range
(
self
.
info
.
n_mit_mot
+
self
.
info
.
n_mit_sot
+
self
.
info
.
n_sit_sot
)
):
inplace
=
"all"
else
:
inplace
=
str
(
list
(
self
.
destroy_map
.
keys
()
))
inplace
=
str
(
list
(
self
.
destroy_map
))
return
(
f
"Scan{{{self.name}, while_loop={self.info.as_while}, inplace={inplace}}}"
)
...
...
pytensor/scan/utils.py
浏览文件 @
8ae2a195
...
...
@@ -269,7 +269,7 @@ class Validator:
# Mapping from invalid variables to equivalent valid ones.
self
.
valid_equivalent
=
valid_equivalent
.
copy
()
self
.
valid
.
update
(
list
(
valid_equivalent
.
values
()))
self
.
invalid
.
update
(
list
(
valid_equivalent
.
keys
()
))
self
.
invalid
.
update
(
list
(
valid_equivalent
))
def
check
(
self
,
out
):
"""
...
...
pytensor/sparse/basic.py
浏览文件 @
8ae2a195
...
...
@@ -524,7 +524,7 @@ csc_fmatrix = SparseTensorType(format="csc", dtype="float32")
csr_fmatrix
=
SparseTensorType
(
format
=
"csr"
,
dtype
=
"float32"
)
bsr_fmatrix
=
SparseTensorType
(
format
=
"bsr"
,
dtype
=
"float32"
)
all_dtypes
=
list
(
SparseTensorType
.
dtype_specs_map
.
keys
()
)
all_dtypes
=
list
(
SparseTensorType
.
dtype_specs_map
)
complex_dtypes
=
[
t
for
t
in
all_dtypes
if
t
[:
7
]
==
"complex"
]
float_dtypes
=
[
t
for
t
in
all_dtypes
if
t
[:
5
]
==
"float"
]
int_dtypes
=
[
t
for
t
in
all_dtypes
if
t
[:
3
]
==
"int"
]
...
...
pytensor/tensor/rewriting/elemwise.py
浏览文件 @
8ae2a195
...
...
@@ -71,7 +71,7 @@ class InplaceElemwiseOptimizer(GraphRewriter):
ndim
=
prof
[
"ndim"
]
if
ndim
:
print
(
blanc
,
"ndim"
,
"nb"
,
file
=
stream
)
for
n
in
sorted
(
ndim
.
keys
()
):
for
n
in
sorted
(
ndim
):
print
(
blanc
,
n
,
ndim
[
n
],
file
=
stream
)
def
candidate_input_idxs
(
self
,
node
):
...
...
pytensor/tensor/utils.py
浏览文件 @
8ae2a195
...
...
@@ -88,7 +88,7 @@ def shape_of_variables(
compute_shapes
=
pytensor
.
function
(
input_dims
,
output_dims
)
if
any
(
i
not
in
fgraph
.
inputs
for
i
in
input_shapes
.
keys
()
):
if
any
(
i
not
in
fgraph
.
inputs
for
i
in
input_shapes
):
raise
ValueError
(
"input_shapes keys aren't in the fgraph.inputs. FunctionGraph()"
" interface changed. Now by default, it clones the graph it receives."
...
...
tests/compile/function/test_types.py
浏览文件 @
8ae2a195
...
...
@@ -889,9 +889,9 @@ class TestPicklefunction:
return
else
:
raise
# if they both return, assume
that they return equivalent things.
# print [(k,
id(k)) for k in f.finder.keys()
]
# print [(k,
id(k)) for k in g.finder.keys()
]
# if they both return, assume that they return equivalent things.
# print [(k,
id(k)) for k in f.finder
]
# print [(k,
id(k)) for k in g.finder
]
assert
g
.
container
[
0
]
.
storage
is
not
f
.
container
[
0
]
.
storage
assert
g
.
container
[
1
]
.
storage
is
not
f
.
container
[
1
]
.
storage
...
...
@@ -1012,9 +1012,9 @@ class TestPicklefunction:
return
else
:
raise
# if they both return, assume
that they return equivalent things.
# print [(k,
id(k)) for k in f.finder.keys()
]
# print [(k,
id(k)) for k in g.finder.keys()
]
# if they both return, assume that they return equivalent things.
# print [(k,
id(k)) for k in f.finder
]
# print [(k,
id(k)) for k in g.finder
]
assert
g
.
container
[
0
]
.
storage
is
not
f
.
container
[
0
]
.
storage
assert
g
.
container
[
1
]
.
storage
is
not
f
.
container
[
1
]
.
storage
...
...
tests/link/numba/test_basic.py
浏览文件 @
8ae2a195
...
...
@@ -829,7 +829,7 @@ def test_config_options_fastmath():
with
config
.
change_flags
(
numba__fastmath
=
True
):
pytensor_numba_fn
=
function
([
x
],
pt
.
sum
(
x
),
mode
=
numba_mode
)
print
(
list
(
pytensor_numba_fn
.
vm
.
jit_fn
.
py_func
.
__globals__
.
keys
()
))
print
(
list
(
pytensor_numba_fn
.
vm
.
jit_fn
.
py_func
.
__globals__
))
numba_mul_fn
=
pytensor_numba_fn
.
vm
.
jit_fn
.
py_func
.
__globals__
[
"impl_sum"
]
assert
numba_mul_fn
.
targetoptions
[
"fastmath"
]
is
True
...
...
tests/link/numba/test_scan.py
浏览文件 @
8ae2a195
...
...
@@ -479,14 +479,14 @@ def test_vector_taps_benchmark(benchmark):
sitsot_init
:
rng
.
normal
(),
}
numba_fn
=
pytensor
.
function
(
list
(
test
.
keys
()
),
outs
,
mode
=
get_mode
(
"NUMBA"
))
numba_fn
=
pytensor
.
function
(
list
(
test
),
outs
,
mode
=
get_mode
(
"NUMBA"
))
scan_nodes
=
[
node
for
node
in
numba_fn
.
maker
.
fgraph
.
apply_nodes
if
isinstance
(
node
.
op
,
Scan
)
]
assert
len
(
scan_nodes
)
==
1
numba_res
=
numba_fn
(
*
test
.
values
())
ref_fn
=
pytensor
.
function
(
list
(
test
.
keys
()
),
outs
,
mode
=
get_mode
(
"FAST_COMPILE"
))
ref_fn
=
pytensor
.
function
(
list
(
test
),
outs
,
mode
=
get_mode
(
"FAST_COMPILE"
))
ref_res
=
ref_fn
(
*
test
.
values
())
for
numba_r
,
ref_r
in
zip
(
numba_res
,
ref_res
):
np
.
testing
.
assert_array_almost_equal
(
numba_r
,
ref_r
)
...
...
tests/link/test_utils.py
浏览文件 @
8ae2a195
...
...
@@ -57,7 +57,7 @@ def test_fgraph_to_python_names():
"scalar_variable"
,
"tensor_variable_1"
,
r
.
name
,
)
==
tuple
(
sig
.
parameters
.
keys
()
)
)
==
tuple
(
sig
.
parameters
)
assert
(
1
,
2
,
3
,
4
,
5
)
==
out_jx
(
1
,
2
,
3
,
4
,
5
)
obj
=
object
()
...
...
tests/link/test_vm.py
浏览文件 @
8ae2a195
...
...
@@ -337,7 +337,7 @@ def test_reallocation():
def
check_storage
(
storage_map
):
for
i
in
storage_map
:
if
not
isinstance
(
i
,
TensorConstant
):
keys_copy
=
list
(
storage_map
.
keys
()
)[:]
keys_copy
=
list
(
storage_map
)[:]
keys_copy
.
remove
(
i
)
for
o
in
keys_copy
:
if
storage_map
[
i
][
0
]
and
storage_map
[
i
][
0
]
is
storage_map
[
o
][
0
]:
...
...
tests/scan/test_rewriting.py
浏览文件 @
8ae2a195
...
...
@@ -1097,8 +1097,8 @@ class TestScanInplaceOptimizer:
allow_input_downcast
=
True
,
)
scan_node
=
[
x
for
x
in
f9
.
maker
.
fgraph
.
toposort
()
if
isinstance
(
x
.
op
,
Scan
)]
assert
0
in
scan_node
[
0
]
.
op
.
destroy_map
.
keys
()
assert
1
in
scan_node
[
0
]
.
op
.
destroy_map
.
keys
()
assert
0
in
scan_node
[
0
]
.
op
.
destroy_map
assert
1
in
scan_node
[
0
]
.
op
.
destroy_map
# compute output in numpy
numpy_x0
=
np
.
zeros
((
3
,))
numpy_x1
=
np
.
zeros
((
3
,))
...
...
@@ -1163,8 +1163,8 @@ class TestScanInplaceOptimizer:
)
scan_node
=
[
x
for
x
in
f9
.
maker
.
fgraph
.
toposort
()
if
isinstance
(
x
.
op
,
Scan
)]
assert
0
in
scan_node
[
0
]
.
op
.
destroy_map
.
keys
()
assert
1
in
scan_node
[
0
]
.
op
.
destroy_map
.
keys
()
assert
0
in
scan_node
[
0
]
.
op
.
destroy_map
assert
1
in
scan_node
[
0
]
.
op
.
destroy_map
# compute output in numpy
numpy_x0
=
np
.
zeros
((
3
,))
numpy_x1
=
np
.
zeros
((
3
,))
...
...
@@ -1203,8 +1203,8 @@ class TestScanInplaceOptimizer:
f9
=
function
([],
outputs
,
updates
=
updates
,
mode
=
self
.
mode
)
scan_node
=
[
x
for
x
in
f9
.
maker
.
fgraph
.
toposort
()
if
isinstance
(
x
.
op
,
Scan
)]
assert
0
not
in
scan_node
[
0
]
.
op
.
destroy_map
.
keys
()
assert
1
in
scan_node
[
0
]
.
op
.
destroy_map
.
keys
()
assert
0
not
in
scan_node
[
0
]
.
op
.
destroy_map
assert
1
in
scan_node
[
0
]
.
op
.
destroy_map
class
TestSaveMem
:
...
...
tests/test_config.py
浏览文件 @
8ae2a195
...
...
@@ -222,7 +222,7 @@ def test_config_pickling():
buffer
.
seek
(
0
)
restored
=
pickle
.
load
(
buffer
)
# ...without a change in the config values
for
name
in
root
.
_config_var_dict
.
keys
()
:
for
name
in
root
.
_config_var_dict
:
v_original
=
getattr
(
root
,
name
)
v_restored
=
getattr
(
restored
,
name
)
assert
(
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论