提交 8ae2a195 authored 作者: Virgile Andreani's avatar Virgile Andreani 提交者: Virgile Andreani

Remove dict.keys() when unnecessary

上级 63da6d16
...@@ -659,7 +659,7 @@ class Function: ...@@ -659,7 +659,7 @@ class Function:
exist_svs = [i.variable for i in maker.inputs] exist_svs = [i.variable for i in maker.inputs]
# Check if given ShareVariables exist # Check if given ShareVariables exist
for sv in swap.keys(): for sv in swap:
if sv not in exist_svs: if sv not in exist_svs:
raise ValueError(f"SharedVariable: {sv.name} not found") raise ValueError(f"SharedVariable: {sv.name} not found")
...@@ -711,9 +711,9 @@ class Function: ...@@ -711,9 +711,9 @@ class Function:
# it is well tested, we don't share the part of the storage_map. # it is well tested, we don't share the part of the storage_map.
if share_memory: if share_memory:
i_o_vars = maker.fgraph.inputs + maker.fgraph.outputs i_o_vars = maker.fgraph.inputs + maker.fgraph.outputs
for key in storage_map.keys(): for key, val in storage_map.items():
if key not in i_o_vars: if key not in i_o_vars:
new_storage_map[memo[key]] = storage_map[key] new_storage_map[memo[key]] = val
if not name and self.name: if not name and self.name:
name = self.name + " copy" name = self.name + " copy"
...@@ -1446,7 +1446,7 @@ class FunctionMaker: ...@@ -1446,7 +1446,7 @@ class FunctionMaker:
if not hasattr(mode.linker, "accept"): if not hasattr(mode.linker, "accept"):
raise ValueError( raise ValueError(
"'linker' parameter of FunctionMaker should be " "'linker' parameter of FunctionMaker should be "
f"a Linker with an accept method or one of {list(pytensor.compile.mode.predefined_linkers.keys())}" f"a Linker with an accept method or one of {list(pytensor.compile.mode.predefined_linkers)}"
) )
def __init__( def __init__(
......
...@@ -1446,7 +1446,7 @@ class ProfileStats: ...@@ -1446,7 +1446,7 @@ class ProfileStats:
file=file, file=file,
) )
if config.profiling__debugprint: if config.profiling__debugprint:
fcts = {fgraph for (fgraph, n) in self.apply_time.keys()} fcts = {fgraph for (fgraph, n) in self.apply_time}
pytensor.printing.debugprint(fcts, print_type=True) pytensor.printing.debugprint(fcts, print_type=True)
if self.variable_shape or self.variable_strides: if self.variable_shape or self.variable_strides:
self.summary_memory(file, n_apply_to_print) self.summary_memory(file, n_apply_to_print)
......
...@@ -1318,7 +1318,7 @@ def add_caching_dir_configvars(): ...@@ -1318,7 +1318,7 @@ def add_caching_dir_configvars():
_compiledir_format_dict["short_platform"] = short_platform() _compiledir_format_dict["short_platform"] = short_platform()
# Allow to have easily one compiledir per device. # Allow to have easily one compiledir per device.
_compiledir_format_dict["device"] = config.device _compiledir_format_dict["device"] = config.device
compiledir_format_keys = ", ".join(sorted(_compiledir_format_dict.keys())) compiledir_format_keys = ", ".join(sorted(_compiledir_format_dict))
_default_compiledir_format = ( _default_compiledir_format = (
"compiledir_%(short_platform)s-%(processor)s-" "compiledir_%(short_platform)s-%(processor)s-"
"%(python_version)s-%(python_bitwidth)s" "%(python_version)s-%(python_bitwidth)s"
......
...@@ -214,7 +214,7 @@ class PyTensorConfigParser: ...@@ -214,7 +214,7 @@ class PyTensorConfigParser:
return _ChangeFlagsDecorator(*args, _root=self, **kwargs) return _ChangeFlagsDecorator(*args, _root=self, **kwargs)
def warn_unused_flags(self): def warn_unused_flags(self):
for key in self._flags_dict.keys(): for key in self._flags_dict:
warnings.warn(f"PyTensor does not recognise this flag: {key}") warnings.warn(f"PyTensor does not recognise this flag: {key}")
......
...@@ -500,7 +500,7 @@ def grad( ...@@ -500,7 +500,7 @@ def grad(
if cost is not None: if cost is not None:
outputs.append(cost) outputs.append(cost)
if known_grads is not None: if known_grads is not None:
outputs.extend(list(known_grads.keys())) outputs.extend(list(known_grads))
var_to_app_to_idx = _populate_var_to_app_to_idx(outputs, _wrt, consider_constant) var_to_app_to_idx = _populate_var_to_app_to_idx(outputs, _wrt, consider_constant)
...@@ -966,7 +966,7 @@ def _populate_var_to_app_to_idx(outputs, wrt, consider_constant): ...@@ -966,7 +966,7 @@ def _populate_var_to_app_to_idx(outputs, wrt, consider_constant):
visit(elem) visit(elem)
# Remove variables that don't have wrt as a true ancestor # Remove variables that don't have wrt as a true ancestor
orig_vars = list(var_to_app_to_idx.keys()) orig_vars = list(var_to_app_to_idx)
for var in orig_vars: for var in orig_vars:
if var not in visited: if var not in visited:
del var_to_app_to_idx[var] del var_to_app_to_idx[var]
......
...@@ -631,7 +631,7 @@ class Variable(Node, Generic[_TypeType, OptionalApplyType]): ...@@ -631,7 +631,7 @@ class Variable(Node, Generic[_TypeType, OptionalApplyType]):
if not hasattr(self, "_fn_cache"): if not hasattr(self, "_fn_cache"):
self._fn_cache: dict = dict() self._fn_cache: dict = dict()
inputs = tuple(sorted(parsed_inputs_to_values.keys(), key=id)) inputs = tuple(sorted(parsed_inputs_to_values, key=id))
cache_key = (inputs, tuple(kwargs.items())) cache_key = (inputs, tuple(kwargs.items()))
try: try:
fn = self._fn_cache[cache_key] fn = self._fn_cache[cache_key]
......
...@@ -406,7 +406,7 @@ class DestroyHandler(Bookkeeper): ...@@ -406,7 +406,7 @@ class DestroyHandler(Bookkeeper):
# If True means that the apply node, destroys the protected_var. # If True means that the apply node, destroys the protected_var.
if idx in [dmap for sublist in destroy_maps for dmap in sublist]: if idx in [dmap for sublist in destroy_maps for dmap in sublist]:
return True return True
for var_idx in app.op.view_map.keys(): for var_idx in app.op.view_map:
if idx in app.op.view_map[var_idx]: if idx in app.op.view_map[var_idx]:
# We need to recursively check the destroy_map of all the # We need to recursively check the destroy_map of all the
# outputs that we have a view_map on. # outputs that we have a view_map on.
......
...@@ -15,7 +15,7 @@ from collections.abc import Callable, Iterable, Sequence ...@@ -15,7 +15,7 @@ from collections.abc import Callable, Iterable, Sequence
from collections.abc import Iterable as IterableType from collections.abc import Iterable as IterableType
from functools import _compose_mro, partial, reduce # type: ignore from functools import _compose_mro, partial, reduce # type: ignore
from itertools import chain from itertools import chain
from typing import TYPE_CHECKING, Literal, cast from typing import TYPE_CHECKING, Literal
import pytensor import pytensor
from pytensor.configdefaults import config from pytensor.configdefaults import config
...@@ -1924,9 +1924,9 @@ class NodeProcessingGraphRewriter(GraphRewriter): ...@@ -1924,9 +1924,9 @@ class NodeProcessingGraphRewriter(GraphRewriter):
remove: list[Variable] = [] remove: list[Variable] = []
if isinstance(replacements, dict): if isinstance(replacements, dict):
if "remove" in replacements: if "remove" in replacements:
remove = list(cast(Sequence[Variable], replacements.pop("remove"))) remove = list(replacements.pop("remove"))
old_vars = list(cast(Sequence[Variable], replacements.keys())) old_vars = list(replacements)
replacements = list(cast(Sequence[Variable], replacements.values())) replacements = list(replacements.values())
elif not isinstance(replacements, tuple | list): elif not isinstance(replacements, tuple | list):
raise TypeError( raise TypeError(
f"Node rewriter {node_rewriter} gave wrong type of replacement. " f"Node rewriter {node_rewriter} gave wrong type of replacement. "
......
...@@ -168,7 +168,7 @@ class MissingInputError(Exception): ...@@ -168,7 +168,7 @@ class MissingInputError(Exception):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
if kwargs: if kwargs:
# The call to list is needed for Python 3 # The call to list is needed for Python 3
assert list(kwargs.keys()) == ["variable"] assert list(kwargs) == ["variable"]
error_msg = get_variable_trace_string(kwargs["variable"]) error_msg = get_variable_trace_string(kwargs["variable"])
if error_msg: if error_msg:
args = (*args, error_msg) args = (*args, error_msg)
......
...@@ -264,10 +264,7 @@ class Params(dict): ...@@ -264,10 +264,7 @@ class Params(dict):
def __repr__(self): def __repr__(self):
return "Params({})".format( return "Params({})".format(
", ".join( ", ".join(
[ [(f"{k}:{type(self[k]).__name__}:{self[k]}") for k in sorted(self)]
(f"{k}:{type(self[k]).__name__}:{self[k]}")
for k in sorted(self.keys())
]
) )
) )
...@@ -365,7 +362,7 @@ class ParamsType(CType): ...@@ -365,7 +362,7 @@ class ParamsType(CType):
) )
self.length = len(kwargs) self.length = len(kwargs)
self.fields = tuple(sorted(kwargs.keys())) self.fields = tuple(sorted(kwargs))
self.types = tuple(kwargs[field] for field in self.fields) self.types = tuple(kwargs[field] for field in self.fields)
self.name = self.generate_struct_name() self.name = self.generate_struct_name()
......
...@@ -472,7 +472,7 @@ class EnumType(CType, dict): ...@@ -472,7 +472,7 @@ class EnumType(CType, dict):
""" """
Return the sorted tuple of all aliases in this enumeration. Return the sorted tuple of all aliases in this enumeration.
""" """
return tuple(sorted(self.aliases.keys())) return tuple(sorted(self.aliases))
def __repr__(self): def __repr__(self):
names_to_aliases = {constant_name: "" for constant_name in self} names_to_aliases = {constant_name: "" for constant_name in self}
...@@ -481,9 +481,7 @@ class EnumType(CType, dict): ...@@ -481,9 +481,7 @@ class EnumType(CType, dict):
return "{}<{}>({})".format( return "{}<{}>({})".format(
type(self).__name__, type(self).__name__,
self.ctype, self.ctype,
", ".join( ", ".join(f"{k}{names_to_aliases[k]}:{self[k]}" for k in sorted(self)),
f"{k}{names_to_aliases[k]}:{self[k]}" for k in sorted(self.keys())
),
) )
def __getattr__(self, key): def __getattr__(self, key):
...@@ -612,7 +610,7 @@ class EnumType(CType, dict): ...@@ -612,7 +610,7 @@ class EnumType(CType, dict):
f""" f"""
#define {k} {self[k]!s} #define {k} {self[k]!s}
""" """
for k in sorted(self.keys()) for k in sorted(self)
) )
+ self.c_to_string() + self.c_to_string()
) )
...@@ -772,7 +770,7 @@ class CEnumType(EnumList): ...@@ -772,7 +770,7 @@ class CEnumType(EnumList):
case %(i)d: %(name)s = %(constant_cname)s; break; case %(i)d: %(name)s = %(constant_cname)s; break;
""" """
% dict(i=i, name=name, constant_cname=swapped_dict[i]) % dict(i=i, name=name, constant_cname=swapped_dict[i])
for i in sorted(swapped_dict.keys()) for i in sorted(swapped_dict)
), ),
fail=sub["fail"], fail=sub["fail"],
) )
......
...@@ -117,9 +117,7 @@ def {scalar_op_fn_name}({input_names}): ...@@ -117,9 +117,7 @@ def {scalar_op_fn_name}({input_names}):
converted_call_args = ", ".join( converted_call_args = ", ".join(
[ [
f"direct_cast({i_name}, {i_tmp_dtype_name})" f"direct_cast({i_name}, {i_tmp_dtype_name})"
for i_name, i_tmp_dtype_name in zip( for i_name, i_tmp_dtype_name in zip(input_names, input_tmp_dtype_names)
input_names, input_tmp_dtype_names.keys()
)
] ]
) )
if not has_pyx_skip_dispatch: if not has_pyx_skip_dispatch:
......
...@@ -70,7 +70,7 @@ def numba_funcify_Scan(op, node, **kwargs): ...@@ -70,7 +70,7 @@ def numba_funcify_Scan(op, node, **kwargs):
outer_in_names_to_vars = { outer_in_names_to_vars = {
(f"outer_in_{i}" if i > 0 else "n_steps"): v for i, v in enumerate(node.inputs) (f"outer_in_{i}" if i > 0 else "n_steps"): v for i, v in enumerate(node.inputs)
} }
outer_in_names = list(outer_in_names_to_vars.keys()) outer_in_names = list(outer_in_names_to_vars)
outer_in_seqs_names = op.outer_seqs(outer_in_names) outer_in_seqs_names = op.outer_seqs(outer_in_names)
outer_in_mit_mot_names = op.outer_mitmot(outer_in_names) outer_in_mit_mot_names = op.outer_mitmot(outer_in_names)
outer_in_mit_sot_names = op.outer_mitsot(outer_in_names) outer_in_mit_sot_names = op.outer_mitsot(outer_in_names)
......
...@@ -990,7 +990,7 @@ class VMLinker(LocalLinker): ...@@ -990,7 +990,7 @@ class VMLinker(LocalLinker):
for pair in reallocated_info.values(): for pair in reallocated_info.values():
storage_map[pair[1]] = storage_map[pair[0]] storage_map[pair[1]] = storage_map[pair[0]]
return tuple(reallocated_info.keys()) return tuple(reallocated_info)
def make_vm( def make_vm(
self, self,
......
...@@ -1106,9 +1106,7 @@ class PPrinter(Printer): ...@@ -1106,9 +1106,7 @@ class PPrinter(Printer):
outputs = [outputs] outputs = [outputs]
current = None current = None
if display_inputs: if display_inputs:
strings = [ strings = [(0, "inputs: " + ", ".join(str(x) for x in [*inputs, *updates]))]
(0, "inputs: " + ", ".join(map(str, list(inputs) + updates.keys())))
]
else: else:
strings = [] strings = []
pprinter = self.clone_assign( pprinter = self.clone_assign(
...@@ -1116,9 +1114,7 @@ class PPrinter(Printer): ...@@ -1116,9 +1114,7 @@ class PPrinter(Printer):
) )
inv_updates = {b: a for (a, b) in updates.items()} inv_updates = {b: a for (a, b) in updates.items()}
i = 1 i = 1
for node in io_toposort( for node in io_toposort([*inputs, *updates], [*outputs, *updates.values()]):
list(inputs) + updates.keys(), list(outputs) + updates.values()
):
for output in node.outputs: for output in node.outputs:
if output in inv_updates: if output in inv_updates:
name = str(inv_updates[output]) name = str(inv_updates[output])
......
...@@ -4426,7 +4426,7 @@ class Compositef32: ...@@ -4426,7 +4426,7 @@ class Compositef32:
else: else:
ni = i ni = i
mapping[i] = ni mapping[i] = ni
if isinstance(node.op, tuple(self.special.keys())): if isinstance(node.op, tuple(self.special)):
self.special[type(node.op)](node, mapping) self.special[type(node.op)](node, mapping)
continue continue
new_node = node.clone_with_new_inputs( new_node = node.clone_with_new_inputs(
......
...@@ -1284,14 +1284,14 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph): ...@@ -1284,14 +1284,14 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
def __str__(self): def __str__(self):
inplace = "none" inplace = "none"
if len(self.destroy_map.keys()) > 0: if self.destroy_map:
# Check if all outputs are inplace # Check if all outputs are inplace
if sorted(self.destroy_map.keys()) == sorted( if sorted(self.destroy_map) == sorted(
range(self.info.n_mit_mot + self.info.n_mit_sot + self.info.n_sit_sot) range(self.info.n_mit_mot + self.info.n_mit_sot + self.info.n_sit_sot)
): ):
inplace = "all" inplace = "all"
else: else:
inplace = str(list(self.destroy_map.keys())) inplace = str(list(self.destroy_map))
return ( return (
f"Scan{{{self.name}, while_loop={self.info.as_while}, inplace={inplace}}}" f"Scan{{{self.name}, while_loop={self.info.as_while}, inplace={inplace}}}"
) )
......
...@@ -269,7 +269,7 @@ class Validator: ...@@ -269,7 +269,7 @@ class Validator:
# Mapping from invalid variables to equivalent valid ones. # Mapping from invalid variables to equivalent valid ones.
self.valid_equivalent = valid_equivalent.copy() self.valid_equivalent = valid_equivalent.copy()
self.valid.update(list(valid_equivalent.values())) self.valid.update(list(valid_equivalent.values()))
self.invalid.update(list(valid_equivalent.keys())) self.invalid.update(list(valid_equivalent))
def check(self, out): def check(self, out):
""" """
......
...@@ -524,7 +524,7 @@ csc_fmatrix = SparseTensorType(format="csc", dtype="float32") ...@@ -524,7 +524,7 @@ csc_fmatrix = SparseTensorType(format="csc", dtype="float32")
csr_fmatrix = SparseTensorType(format="csr", dtype="float32") csr_fmatrix = SparseTensorType(format="csr", dtype="float32")
bsr_fmatrix = SparseTensorType(format="bsr", dtype="float32") bsr_fmatrix = SparseTensorType(format="bsr", dtype="float32")
all_dtypes = list(SparseTensorType.dtype_specs_map.keys()) all_dtypes = list(SparseTensorType.dtype_specs_map)
complex_dtypes = [t for t in all_dtypes if t[:7] == "complex"] complex_dtypes = [t for t in all_dtypes if t[:7] == "complex"]
float_dtypes = [t for t in all_dtypes if t[:5] == "float"] float_dtypes = [t for t in all_dtypes if t[:5] == "float"]
int_dtypes = [t for t in all_dtypes if t[:3] == "int"] int_dtypes = [t for t in all_dtypes if t[:3] == "int"]
......
...@@ -71,7 +71,7 @@ class InplaceElemwiseOptimizer(GraphRewriter): ...@@ -71,7 +71,7 @@ class InplaceElemwiseOptimizer(GraphRewriter):
ndim = prof["ndim"] ndim = prof["ndim"]
if ndim: if ndim:
print(blanc, "ndim", "nb", file=stream) print(blanc, "ndim", "nb", file=stream)
for n in sorted(ndim.keys()): for n in sorted(ndim):
print(blanc, n, ndim[n], file=stream) print(blanc, n, ndim[n], file=stream)
def candidate_input_idxs(self, node): def candidate_input_idxs(self, node):
......
...@@ -88,7 +88,7 @@ def shape_of_variables( ...@@ -88,7 +88,7 @@ def shape_of_variables(
compute_shapes = pytensor.function(input_dims, output_dims) compute_shapes = pytensor.function(input_dims, output_dims)
if any(i not in fgraph.inputs for i in input_shapes.keys()): if any(i not in fgraph.inputs for i in input_shapes):
raise ValueError( raise ValueError(
"input_shapes keys aren't in the fgraph.inputs. FunctionGraph()" "input_shapes keys aren't in the fgraph.inputs. FunctionGraph()"
" interface changed. Now by default, it clones the graph it receives." " interface changed. Now by default, it clones the graph it receives."
......
...@@ -889,9 +889,9 @@ class TestPicklefunction: ...@@ -889,9 +889,9 @@ class TestPicklefunction:
return return
else: else:
raise raise
# if they both return, assume that they return equivalent things. # if they both return, assume that they return equivalent things.
# print [(k,id(k)) for k in f.finder.keys()] # print [(k, id(k)) for k in f.finder]
# print [(k,id(k)) for k in g.finder.keys()] # print [(k, id(k)) for k in g.finder]
assert g.container[0].storage is not f.container[0].storage assert g.container[0].storage is not f.container[0].storage
assert g.container[1].storage is not f.container[1].storage assert g.container[1].storage is not f.container[1].storage
...@@ -1012,9 +1012,9 @@ class TestPicklefunction: ...@@ -1012,9 +1012,9 @@ class TestPicklefunction:
return return
else: else:
raise raise
# if they both return, assume that they return equivalent things. # if they both return, assume that they return equivalent things.
# print [(k,id(k)) for k in f.finder.keys()] # print [(k, id(k)) for k in f.finder]
# print [(k,id(k)) for k in g.finder.keys()] # print [(k, id(k)) for k in g.finder]
assert g.container[0].storage is not f.container[0].storage assert g.container[0].storage is not f.container[0].storage
assert g.container[1].storage is not f.container[1].storage assert g.container[1].storage is not f.container[1].storage
......
...@@ -829,7 +829,7 @@ def test_config_options_fastmath(): ...@@ -829,7 +829,7 @@ def test_config_options_fastmath():
with config.change_flags(numba__fastmath=True): with config.change_flags(numba__fastmath=True):
pytensor_numba_fn = function([x], pt.sum(x), mode=numba_mode) pytensor_numba_fn = function([x], pt.sum(x), mode=numba_mode)
print(list(pytensor_numba_fn.vm.jit_fn.py_func.__globals__.keys())) print(list(pytensor_numba_fn.vm.jit_fn.py_func.__globals__))
numba_mul_fn = pytensor_numba_fn.vm.jit_fn.py_func.__globals__["impl_sum"] numba_mul_fn = pytensor_numba_fn.vm.jit_fn.py_func.__globals__["impl_sum"]
assert numba_mul_fn.targetoptions["fastmath"] is True assert numba_mul_fn.targetoptions["fastmath"] is True
......
...@@ -479,14 +479,14 @@ def test_vector_taps_benchmark(benchmark): ...@@ -479,14 +479,14 @@ def test_vector_taps_benchmark(benchmark):
sitsot_init: rng.normal(), sitsot_init: rng.normal(),
} }
numba_fn = pytensor.function(list(test.keys()), outs, mode=get_mode("NUMBA")) numba_fn = pytensor.function(list(test), outs, mode=get_mode("NUMBA"))
scan_nodes = [ scan_nodes = [
node for node in numba_fn.maker.fgraph.apply_nodes if isinstance(node.op, Scan) node for node in numba_fn.maker.fgraph.apply_nodes if isinstance(node.op, Scan)
] ]
assert len(scan_nodes) == 1 assert len(scan_nodes) == 1
numba_res = numba_fn(*test.values()) numba_res = numba_fn(*test.values())
ref_fn = pytensor.function(list(test.keys()), outs, mode=get_mode("FAST_COMPILE")) ref_fn = pytensor.function(list(test), outs, mode=get_mode("FAST_COMPILE"))
ref_res = ref_fn(*test.values()) ref_res = ref_fn(*test.values())
for numba_r, ref_r in zip(numba_res, ref_res): for numba_r, ref_r in zip(numba_res, ref_res):
np.testing.assert_array_almost_equal(numba_r, ref_r) np.testing.assert_array_almost_equal(numba_r, ref_r)
......
...@@ -57,7 +57,7 @@ def test_fgraph_to_python_names(): ...@@ -57,7 +57,7 @@ def test_fgraph_to_python_names():
"scalar_variable", "scalar_variable",
"tensor_variable_1", "tensor_variable_1",
r.name, r.name,
) == tuple(sig.parameters.keys()) ) == tuple(sig.parameters)
assert (1, 2, 3, 4, 5) == out_jx(1, 2, 3, 4, 5) assert (1, 2, 3, 4, 5) == out_jx(1, 2, 3, 4, 5)
obj = object() obj = object()
......
...@@ -337,7 +337,7 @@ def test_reallocation(): ...@@ -337,7 +337,7 @@ def test_reallocation():
def check_storage(storage_map): def check_storage(storage_map):
for i in storage_map: for i in storage_map:
if not isinstance(i, TensorConstant): if not isinstance(i, TensorConstant):
keys_copy = list(storage_map.keys())[:] keys_copy = list(storage_map)[:]
keys_copy.remove(i) keys_copy.remove(i)
for o in keys_copy: for o in keys_copy:
if storage_map[i][0] and storage_map[i][0] is storage_map[o][0]: if storage_map[i][0] and storage_map[i][0] is storage_map[o][0]:
......
...@@ -1097,8 +1097,8 @@ class TestScanInplaceOptimizer: ...@@ -1097,8 +1097,8 @@ class TestScanInplaceOptimizer:
allow_input_downcast=True, allow_input_downcast=True,
) )
scan_node = [x for x in f9.maker.fgraph.toposort() if isinstance(x.op, Scan)] scan_node = [x for x in f9.maker.fgraph.toposort() if isinstance(x.op, Scan)]
assert 0 in scan_node[0].op.destroy_map.keys() assert 0 in scan_node[0].op.destroy_map
assert 1 in scan_node[0].op.destroy_map.keys() assert 1 in scan_node[0].op.destroy_map
# compute output in numpy # compute output in numpy
numpy_x0 = np.zeros((3,)) numpy_x0 = np.zeros((3,))
numpy_x1 = np.zeros((3,)) numpy_x1 = np.zeros((3,))
...@@ -1163,8 +1163,8 @@ class TestScanInplaceOptimizer: ...@@ -1163,8 +1163,8 @@ class TestScanInplaceOptimizer:
) )
scan_node = [x for x in f9.maker.fgraph.toposort() if isinstance(x.op, Scan)] scan_node = [x for x in f9.maker.fgraph.toposort() if isinstance(x.op, Scan)]
assert 0 in scan_node[0].op.destroy_map.keys() assert 0 in scan_node[0].op.destroy_map
assert 1 in scan_node[0].op.destroy_map.keys() assert 1 in scan_node[0].op.destroy_map
# compute output in numpy # compute output in numpy
numpy_x0 = np.zeros((3,)) numpy_x0 = np.zeros((3,))
numpy_x1 = np.zeros((3,)) numpy_x1 = np.zeros((3,))
...@@ -1203,8 +1203,8 @@ class TestScanInplaceOptimizer: ...@@ -1203,8 +1203,8 @@ class TestScanInplaceOptimizer:
f9 = function([], outputs, updates=updates, mode=self.mode) f9 = function([], outputs, updates=updates, mode=self.mode)
scan_node = [x for x in f9.maker.fgraph.toposort() if isinstance(x.op, Scan)] scan_node = [x for x in f9.maker.fgraph.toposort() if isinstance(x.op, Scan)]
assert 0 not in scan_node[0].op.destroy_map.keys() assert 0 not in scan_node[0].op.destroy_map
assert 1 in scan_node[0].op.destroy_map.keys() assert 1 in scan_node[0].op.destroy_map
class TestSaveMem: class TestSaveMem:
......
...@@ -222,7 +222,7 @@ def test_config_pickling(): ...@@ -222,7 +222,7 @@ def test_config_pickling():
buffer.seek(0) buffer.seek(0)
restored = pickle.load(buffer) restored = pickle.load(buffer)
# ...without a change in the config values # ...without a change in the config values
for name in root._config_var_dict.keys(): for name in root._config_var_dict:
v_original = getattr(root, name) v_original = getattr(root, name)
v_restored = getattr(restored, name) v_restored = getattr(restored, name)
assert ( assert (
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论