提交 e85c7fd0 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Replace Scan info dict with ScanInfo dataclass

上级 d83ff33c
...@@ -13,7 +13,7 @@ from aesara.graph.fg import MissingInputError ...@@ -13,7 +13,7 @@ from aesara.graph.fg import MissingInputError
from aesara.graph.op import get_test_value from aesara.graph.op import get_test_value
from aesara.graph.utils import TestValueError from aesara.graph.utils import TestValueError
from aesara.scan import utils from aesara.scan import utils
from aesara.scan.op import Scan from aesara.scan.op import Scan, ScanInfo
from aesara.scan.utils import safe_new, traverse from aesara.scan.utils import safe_new, traverse
from aesara.tensor.exceptions import NotScalarConstantError from aesara.tensor.exceptions import NotScalarConstantError
from aesara.tensor.math import minimum from aesara.tensor.math import minimum
...@@ -1022,31 +1022,32 @@ def scan( ...@@ -1022,31 +1022,32 @@ def scan(
# Step 7. Create the Scan Op # Step 7. Create the Scan Op
## ##
tap_array = mit_sot_tap_array + [[-1] for x in range(n_sit_sot)] tap_array = tuple(tuple(v) for v in mit_sot_tap_array) + tuple(
(-1,) for x in range(n_sit_sot)
)
if allow_gc is None: if allow_gc is None:
allow_gc = config.scan__allow_gc allow_gc = config.scan__allow_gc
info = OrderedDict()
info = ScanInfo(
info["tap_array"] = tap_array tap_array=tap_array,
info["n_seqs"] = n_seqs n_seqs=n_seqs,
info["n_mit_mot"] = n_mit_mot n_mit_mot=n_mit_mot,
info["n_mit_mot_outs"] = n_mit_mot_outs n_mit_mot_outs=n_mit_mot_outs,
info["mit_mot_out_slices"] = mit_mot_out_slices mit_mot_out_slices=tuple(tuple(v) for v in mit_mot_out_slices),
info["n_mit_sot"] = n_mit_sot n_mit_sot=n_mit_sot,
info["n_sit_sot"] = n_sit_sot n_sit_sot=n_sit_sot,
info["n_shared_outs"] = n_shared_outs n_shared_outs=n_shared_outs,
info["n_nit_sot"] = n_nit_sot n_nit_sot=n_nit_sot,
info["truncate_gradient"] = truncate_gradient truncate_gradient=truncate_gradient,
info["name"] = name name=name,
info["mode"] = mode gpua=False,
info["destroy_map"] = OrderedDict() as_while=as_while,
info["gpua"] = False profile=profile,
info["as_while"] = as_while allow_gc=allow_gc,
info["profile"] = profile strict=strict,
info["allow_gc"] = allow_gc )
info["strict"] = strict
local_op = Scan(inner_inputs, new_outs, info, mode)
local_op = Scan(inner_inputs, new_outs, info)
## ##
# Step 8. Compute the outputs using the scan op # Step 8. Compute the outputs using the scan op
......
...@@ -44,11 +44,12 @@ relies on the following elements to work properly : ...@@ -44,11 +44,12 @@ relies on the following elements to work properly :
""" """
import copy import dataclasses
import itertools import itertools
import logging import logging
import time import time
from collections import OrderedDict from collections import OrderedDict
from typing import Callable, List, Optional, Union
import numpy as np import numpy as np
...@@ -57,7 +58,7 @@ from aesara import tensor as aet ...@@ -57,7 +58,7 @@ from aesara import tensor as aet
from aesara.compile.builders import infer_shape from aesara.compile.builders import infer_shape
from aesara.compile.function import function from aesara.compile.function import function
from aesara.compile.io import In, Out from aesara.compile.io import In, Out
from aesara.compile.mode import AddFeatureOptimizer, get_mode from aesara.compile.mode import AddFeatureOptimizer, Mode, get_default_mode, get_mode
from aesara.compile.profiling import ScanProfileStats, register_profiler_printer from aesara.compile.profiling import ScanProfileStats, register_profiler_printer
from aesara.configdefaults import config from aesara.configdefaults import config
from aesara.gradient import DisconnectedType, NullType, Rop, grad, grad_undefined from aesara.gradient import DisconnectedType, NullType, Rop, grad, grad_undefined
...@@ -76,7 +77,7 @@ from aesara.graph.op import Op, ops_with_inner_function ...@@ -76,7 +77,7 @@ from aesara.graph.op import Op, ops_with_inner_function
from aesara.link.c.basic import CLinker from aesara.link.c.basic import CLinker
from aesara.link.c.exceptions import MissingGXX from aesara.link.c.exceptions import MissingGXX
from aesara.link.utils import raise_with_op from aesara.link.utils import raise_with_op
from aesara.scan.utils import Validator, forced_replace, hash_listsDictsTuples, safe_new from aesara.scan.utils import Validator, forced_replace, safe_new
from aesara.tensor.basic import as_tensor_variable from aesara.tensor.basic import as_tensor_variable
from aesara.tensor.math import minimum from aesara.tensor.math import minimum
from aesara.tensor.shape import Shape_i from aesara.tensor.shape import Shape_i
...@@ -88,25 +89,58 @@ from aesara.tensor.var import TensorVariable ...@@ -88,25 +89,58 @@ from aesara.tensor.var import TensorVariable
_logger = logging.getLogger("aesara.scan.op") _logger = logging.getLogger("aesara.scan.op")
@dataclasses.dataclass(frozen=True)
class ScanInfo:
tap_array: tuple
n_seqs: int
n_mit_mot: int
n_mit_mot_outs: int
mit_mot_out_slices: tuple
n_mit_sot: int
n_sit_sot: int
n_shared_outs: int
n_nit_sot: int
truncate_gradient: bool = False
name: Optional[str] = None
gpua: bool = False
as_while: bool = False
profile: Optional[Union[str, bool]] = None
allow_gc: bool = True
strict: bool = True
TensorConstructorType = Callable[[List[bool], Union[str, np.generic]], TensorType]
class Scan(Op): class Scan(Op):
""" def __init__(
self,
inputs: List[Variable],
outputs: List[Variable],
info: ScanInfo,
mode: Optional[Mode] = None,
typeConstructor: Optional[TensorConstructorType] = None,
):
r"""
Parameters Parameters
---------- ----------
inputs inputs
Inputs of the inner function of scan. Inputs of the inner function of `Scan`.
outputs outputs
Outputs of the inner function of scan. Outputs of the inner function of `Scan`.
info info
Dictionary containing different properties of the scan op (like number Dictionary containing different properties of the `Scan` `Op` (like
of different types of arguments, name, mode, if it should run on GPU or number of different types of arguments, name, mode, if it should run on
not, etc.). GPU or not, etc.).
mode
The compilation mode for the inner graph.
typeConstructor typeConstructor
Function that constructs an equivalent to Aesara TensorType. Function that constructs an equivalent to Aesara `TensorType`.
Notes Notes
----- -----
``typeConstructor`` had been added to refactor how `typeConstructor` had been added to refactor how
Aesara deals with the GPU. If it runs on the GPU, scan needs Aesara deals with the GPU. If it runs on the GPU, scan needs
to construct certain outputs (those who reside in the GPU to construct certain outputs (those who reside in the GPU
memory) as the GPU-specific type. However we can not import memory) as the GPU-specific type. However we can not import
...@@ -114,31 +148,31 @@ class Scan(Op): ...@@ -114,31 +148,31 @@ class Scan(Op):
on each machine) so the workaround is that the GPU on each machine) so the workaround is that the GPU
optimization passes to the constructor of this class a optimization passes to the constructor of this class a
function that is able to construct a GPU type. This way the function that is able to construct a GPU type. This way the
class Scan does not need to be aware of the details for the class `Scan` does not need to be aware of the details for the
GPU, it just constructs any tensor using this function (which GPU, it just constructs any tensor using this function (which
by default constructs normal tensors). by default constructs normal tensors).
""" """
def __init__(
self,
inputs,
outputs,
info,
typeConstructor=None,
):
# adding properties into self # adding properties into self
self.inputs = inputs self.inputs = inputs
self.outputs = outputs self.outputs = outputs
self.__dict__.update(info)
# I keep a version of info in self, to use in __eq__ and __hash__,
# since info contains all tunable parameters of the op, so for two
# scan to be equal this tunable parameters should be the same
self.info = info self.info = info
self.__dict__.update(dataclasses.asdict(info))
# Clone mode_instance, altering "allow_gc" for the linker,
# and adding a message if we profile
if self.name:
message = self.name + " sub profile"
else:
message = "Scan sub profile"
self.mode = get_default_mode() if mode is None else mode
self.mode_instance = get_mode(self.mode).clone(
link_kwargs=dict(allow_gc=self.allow_gc), message=message
)
# build a list of output types for any Apply node using this op. # build a list of output types for any Apply node using this op.
self.output_types = [] self.output_types = []
idx = 0
jdx = 0
def tensorConstructor(broadcastable, dtype): def tensorConstructor(broadcastable, dtype):
return TensorType(broadcastable=broadcastable, dtype=dtype) return TensorType(broadcastable=broadcastable, dtype=dtype)
...@@ -146,6 +180,8 @@ class Scan(Op): ...@@ -146,6 +180,8 @@ class Scan(Op):
if typeConstructor is None: if typeConstructor is None:
typeConstructor = tensorConstructor typeConstructor = tensorConstructor
idx = 0
jdx = 0
while idx < self.n_mit_mot_outs: while idx < self.n_mit_mot_outs:
# Not that for mit_mot there are several output slices per # Not that for mit_mot there are several output slices per
# output sequence # output sequence
...@@ -176,24 +212,13 @@ class Scan(Op): ...@@ -176,24 +212,13 @@ class Scan(Op):
if self.as_while: if self.as_while:
self.output_types = self.output_types[:-1] self.output_types = self.output_types[:-1]
mode_instance = get_mode(self.mode)
# Clone mode_instance, altering "allow_gc" for the linker,
# and adding a message if we profile
if self.name:
message = self.name + " sub profile"
else:
message = "Scan sub profile"
self.mode_instance = mode_instance.clone(
link_kwargs=dict(allow_gc=self.allow_gc), message=message
)
if not hasattr(self, "name") or self.name is None: if not hasattr(self, "name") or self.name is None:
self.name = "scan_fn" self.name = "scan_fn"
# to have a fair __eq__ comparison later on, we update the info with # to have a fair __eq__ comparison later on, we update the info with
# the actual mode used to compile the function and the name of the # the actual mode used to compile the function and the name of the
# function that we set in case none was given # function that we set in case none was given
self.info["name"] = self.name self.info = dataclasses.replace(self.info, name=self.name)
# Pre-computing some values to speed up perform # Pre-computing some values to speed up perform
self.mintaps = [np.min(x) for x in self.tap_array] self.mintaps = [np.min(x) for x in self.tap_array]
...@@ -205,8 +230,8 @@ class Scan(Op): ...@@ -205,8 +230,8 @@ class Scan(Op):
self.nit_sot_arg_offset = self.shared_arg_offset + self.n_shared_outs self.nit_sot_arg_offset = self.shared_arg_offset + self.n_shared_outs
self.n_outs = self.n_mit_mot + self.n_mit_sot + self.n_sit_sot self.n_outs = self.n_mit_mot + self.n_mit_sot + self.n_sit_sot
self.n_tap_outs = self.n_mit_mot + self.n_mit_sot self.n_tap_outs = self.n_mit_mot + self.n_mit_sot
if self.info["gpua"]: if self.info.gpua:
self._hash_inner_graph = self.info["gpu_hash"] self._hash_inner_graph = self.info.gpu_hash
else: else:
# Do the missing inputs check here to have the error early. # Do the missing inputs check here to have the error early.
for var in graph_inputs(self.outputs, self.inputs): for var in graph_inputs(self.outputs, self.inputs):
...@@ -256,7 +281,7 @@ class Scan(Op): ...@@ -256,7 +281,7 @@ class Scan(Op):
# output with type GpuArrayType # output with type GpuArrayType
from aesara.gpuarray import GpuArrayType from aesara.gpuarray import GpuArrayType
if not self.info.get("gpua", False): if not self.info.gpua:
for inp in self.inputs: for inp in self.inputs:
if isinstance(inp.type, GpuArrayType): if isinstance(inp.type, GpuArrayType):
raise TypeError( raise TypeError(
...@@ -279,9 +304,7 @@ class Scan(Op): ...@@ -279,9 +304,7 @@ class Scan(Op):
def __setstate__(self, d): def __setstate__(self, d):
self.__dict__.update(d) self.__dict__.update(d)
if "allow_gc" not in self.__dict__:
self.allow_gc = True
self.info["allow_gc"] = True
if not hasattr(self, "var_mappings"): if not hasattr(self, "var_mappings"):
# Generate the mappings between inner and outer inputs and outputs # Generate the mappings between inner and outer inputs and outputs
# if they haven't already been generated. # if they haven't already been generated.
...@@ -731,41 +754,20 @@ class Scan(Op): ...@@ -731,41 +754,20 @@ class Scan(Op):
return apply_node return apply_node
def __eq__(self, other): def __eq__(self, other):
# Check if we are dealing with same type of objects if type(self) != type(other):
if not type(self) == type(other):
return False return False
if "destroy_map" not in self.info:
self.info["destroy_map"] = OrderedDict() if self.info != other.info:
if "destroy_map" not in other.info:
other.info["destroy_map"] = OrderedDict()
keys_to_check = [
"truncate_gradient",
"profile",
"n_seqs",
"tap_array",
"as_while",
"n_mit_sot",
"destroy_map",
"n_nit_sot",
"n_shared_outs",
"n_sit_sot",
"gpua",
"n_mit_mot_outs",
"n_mit_mot",
"mit_mot_out_slices",
]
# This are some safety checks ( namely that the inner graph has the
# same number of inputs and same number of outputs )
if not len(self.inputs) == len(other.inputs):
return False return False
elif not len(self.outputs) == len(other.outputs):
# Compare inner graphs
# TODO: Use `self.inner_fgraph == other.inner_fgraph`
if len(self.inputs) != len(other.inputs):
return False return False
for key in keys_to_check:
if self.info[key] != other.info[key]: if len(self.outputs) != len(other.outputs):
return False return False
# If everything went OK up to here, there is still one thing to
# check. Namely, do the internal graph represent same
# computations
for self_in, other_in in zip(self.inputs, other.inputs): for self_in, other_in in zip(self.inputs, other.inputs):
if self_in.type != other_in.type: if self_in.type != other_in.type:
return False return False
...@@ -801,15 +803,7 @@ class Scan(Op): ...@@ -801,15 +803,7 @@ class Scan(Op):
return aux_txt return aux_txt
def __hash__(self): def __hash__(self):
return hash( return hash((type(self), self._hash_inner_graph, self.info))
(
type(self),
# and a hash representing the inner graph using the
# CLinker.cmodule_key_
self._hash_inner_graph,
hash_listsDictsTuples(self.info),
)
)
def make_thunk(self, node, storage_map, compute_map, no_recycling, impl=None): def make_thunk(self, node, storage_map, compute_map, no_recycling, impl=None):
""" """
...@@ -864,7 +858,6 @@ class Scan(Op): ...@@ -864,7 +858,6 @@ class Scan(Op):
wrapped_inputs = [In(x, borrow=False) for x in self.inputs[: self.n_seqs]] wrapped_inputs = [In(x, borrow=False) for x in self.inputs[: self.n_seqs]]
new_outputs = [x for x in self.outputs] new_outputs = [x for x in self.outputs]
preallocated_mitmot_outs = [] preallocated_mitmot_outs = []
new_mit_mot_out_slices = copy.deepcopy(self.mit_mot_out_slices)
input_idx = self.n_seqs input_idx = self.n_seqs
for mitmot_idx in range(self.n_mit_mot): for mitmot_idx in range(self.n_mit_mot):
...@@ -894,7 +887,6 @@ class Scan(Op): ...@@ -894,7 +887,6 @@ class Scan(Op):
) )
wrapped_inputs.append(wrapped_inp) wrapped_inputs.append(wrapped_inp)
preallocated_mitmot_outs.append(output_idx) preallocated_mitmot_outs.append(output_idx)
new_mit_mot_out_slices[mitmot_idx].remove(inp_tap)
else: else:
# Wrap the corresponding input as usual. Leave the # Wrap the corresponding input as usual. Leave the
# output as-is. # output as-is.
...@@ -1963,7 +1955,7 @@ class Scan(Op): ...@@ -1963,7 +1955,7 @@ class Scan(Op):
outer_oidx = 0 outer_oidx = 0
# Handle sequences inputs # Handle sequences inputs
for i in range(self.info["n_seqs"]): for i in range(self.info.n_seqs):
outer_input_indices.append(outer_iidx) outer_input_indices.append(outer_iidx)
inner_input_indices.append([inner_iidx]) inner_input_indices.append([inner_iidx])
inner_output_indices.append([]) inner_output_indices.append([])
...@@ -1975,8 +1967,8 @@ class Scan(Op): ...@@ -1975,8 +1967,8 @@ class Scan(Op):
outer_oidx += 0 outer_oidx += 0
# Handle mitmots, mitsots and sitsots variables # Handle mitmots, mitsots and sitsots variables
for i in range(len(self.info["tap_array"])): for i in range(len(self.info.tap_array)):
nb_input_taps = len(self.info["tap_array"][i]) nb_input_taps = len(self.info.tap_array[i])
if i < self.n_mit_mot: if i < self.n_mit_mot:
nb_output_taps = len(self.mit_mot_out_slices[i]) nb_output_taps = len(self.mit_mot_out_slices[i])
...@@ -1999,7 +1991,7 @@ class Scan(Op): ...@@ -1999,7 +1991,7 @@ class Scan(Op):
# This is needed because, for outer inputs (and for outer inputs only) # This is needed because, for outer inputs (and for outer inputs only)
# nitsots come *after* shared variables. # nitsots come *after* shared variables.
outer_iidx += self.info["n_shared_outs"] outer_iidx += self.info.n_shared_outs
# Handle nitsots variables # Handle nitsots variables
for i in range(self.n_nit_sot): for i in range(self.n_nit_sot):
...@@ -2015,10 +2007,10 @@ class Scan(Op): ...@@ -2015,10 +2007,10 @@ class Scan(Op):
# This is needed because, for outer inputs (and for outer inputs only) # This is needed because, for outer inputs (and for outer inputs only)
# nitsots come *after* shared variables. # nitsots come *after* shared variables.
outer_iidx -= self.info["n_shared_outs"] + self.n_nit_sot outer_iidx -= self.info.n_shared_outs + self.n_nit_sot
# Handle shared states # Handle shared states
for i in range(self.info["n_shared_outs"]): for i in range(self.info.n_shared_outs):
outer_input_indices.append(outer_iidx) outer_input_indices.append(outer_iidx)
inner_input_indices.append([inner_iidx]) inner_input_indices.append([inner_iidx])
inner_output_indices.append([inner_oidx]) inner_output_indices.append([inner_oidx])
...@@ -2158,10 +2150,10 @@ class Scan(Op): ...@@ -2158,10 +2150,10 @@ class Scan(Op):
oidx += 1 oidx += 1
iidx -= len(taps) iidx -= len(taps)
if iidx < self.info["n_sit_sot"]: if iidx < self.info.n_sit_sot:
return oidx + iidx return oidx + iidx
else: else:
return oidx + iidx + self.info["n_nit_sot"] return oidx + iidx + self.info.n_nit_sot
def get_out_idx(iidx): def get_out_idx(iidx):
oidx = 0 oidx = 0
...@@ -2241,9 +2233,9 @@ class Scan(Op): ...@@ -2241,9 +2233,9 @@ class Scan(Op):
# "if Xt not in self.inner_nitsot_outs(self_outputs)" because # "if Xt not in self.inner_nitsot_outs(self_outputs)" because
# the exact same variable can be used as multiple outputs. # the exact same variable can be used as multiple outputs.
idx_nitsot_start = ( idx_nitsot_start = (
self.info["n_mit_mot"] + self.info["n_mit_sot"] + self.info["n_sit_sot"] self.info.n_mit_mot + self.info.n_mit_sot + self.info.n_sit_sot
) )
idx_nitsot_end = idx_nitsot_start + self.info["n_nit_sot"] idx_nitsot_end = idx_nitsot_start + self.info.n_nit_sot
if idx < idx_nitsot_start or idx >= idx_nitsot_end: if idx < idx_nitsot_start or idx >= idx_nitsot_end:
# What we do here is loop through dC_douts and collect all # What we do here is loop through dC_douts and collect all
# those that are connected to the specific one and do an # those that are connected to the specific one and do an
...@@ -2668,27 +2660,23 @@ class Scan(Op): ...@@ -2668,27 +2660,23 @@ class Scan(Op):
n_sitsot_outs = len(outer_inp_sitsot) n_sitsot_outs = len(outer_inp_sitsot)
new_tap_array = mitmot_inp_taps + [[-1] for k in range(n_sitsot_outs)] new_tap_array = mitmot_inp_taps + [[-1] for k in range(n_sitsot_outs)]
info = OrderedDict() info = ScanInfo(
info["n_seqs"] = len(outer_inp_seqs) n_seqs=len(outer_inp_seqs),
info["n_mit_sot"] = 0 n_mit_sot=0,
info["tap_array"] = new_tap_array tap_array=tuple(tuple(v) for v in new_tap_array),
info["gpua"] = False gpua=False,
info["n_mit_mot"] = len(outer_inp_mitmot) n_mit_mot=len(outer_inp_mitmot),
info["n_mit_mot_outs"] = n_mitmot_outs n_mit_mot_outs=n_mitmot_outs,
info["mit_mot_out_slices"] = mitmot_out_taps mit_mot_out_slices=tuple(tuple(v) for v in mitmot_out_taps),
info["truncate_gradient"] = self.truncate_gradient truncate_gradient=self.truncate_gradient,
info["n_sit_sot"] = n_sitsot_outs n_sit_sot=n_sitsot_outs,
info["n_shared_outs"] = 0 n_shared_outs=0,
info["n_nit_sot"] = n_nit_sot n_nit_sot=n_nit_sot,
info["as_while"] = False as_while=False,
info["profile"] = self.profile profile=self.profile,
info["destroy_map"] = OrderedDict() name=f"grad_of_{self.name}" if self.name else None,
if self.name: allow_gc=self.allow_gc,
info["name"] = "grad_of_" + self.name )
else:
info["name"] = None
info["mode"] = self.mode
info["allow_gc"] = self.allow_gc
outer_inputs = ( outer_inputs = (
[grad_steps] [grad_steps]
...@@ -2709,7 +2697,7 @@ class Scan(Op): ...@@ -2709,7 +2697,7 @@ class Scan(Op):
) )
inner_gfn_outs = inner_out_mitmot + inner_out_sitsot + inner_out_nitsot inner_gfn_outs = inner_out_mitmot + inner_out_sitsot + inner_out_nitsot
local_op = Scan(inner_gfn_ins, inner_gfn_outs, info) local_op = Scan(inner_gfn_ins, inner_gfn_outs, info, self.mode)
outputs = local_op(*outer_inputs) outputs = local_op(*outer_inputs)
if type(outputs) not in (list, tuple): if type(outputs) not in (list, tuple):
outputs = [outputs] outputs = [outputs]
...@@ -2852,8 +2840,8 @@ class Scan(Op): ...@@ -2852,8 +2840,8 @@ class Scan(Op):
rop_self_outputs = self_outputs[:-1] rop_self_outputs = self_outputs[:-1]
else: else:
rop_self_outputs = self_outputs rop_self_outputs = self_outputs
if self.info["n_shared_outs"] > 0: if self.info.n_shared_outs > 0:
rop_self_outputs = rop_self_outputs[: -self.info["n_shared_outs"]] rop_self_outputs = rop_self_outputs[: -self.info.n_shared_outs]
rop_outs = Rop(rop_self_outputs, rop_of_inputs, inner_eval_points) rop_outs = Rop(rop_self_outputs, rop_of_inputs, inner_eval_points)
if type(rop_outs) not in (list, tuple): if type(rop_outs) not in (list, tuple):
rop_outs = [rop_outs] rop_outs = [rop_outs]
...@@ -2867,25 +2855,7 @@ class Scan(Op): ...@@ -2867,25 +2855,7 @@ class Scan(Op):
# The only exception is the eval point for the number of sequences, and # The only exception is the eval point for the number of sequences, and
# evan point for the number of nit_sot which I think should just be # evan point for the number of nit_sot which I think should just be
# ignored (?) # ignored (?)
info = OrderedDict()
info["n_seqs"] = self.n_seqs * 2
info["n_mit_sot"] = self.n_mit_sot * 2
info["n_sit_sot"] = self.n_sit_sot * 2
info["n_mit_mot"] = self.n_mit_mot * 2
info["n_nit_sot"] = self.n_nit_sot * 2
info["n_shared_outs"] = self.n_shared_outs
info["gpua"] = False
info["as_while"] = self.as_while
info["profile"] = self.profile
info["truncate_gradient"] = self.truncate_gradient
if self.name:
info["name"] = "rop_of_" + self.name
else:
info["name"] = None
info["mode"] = self.mode
info["allow_gc"] = self.allow_gc
info["mit_mot_out_slices"] = self.mit_mot_out_slices * 2
info["destroy_map"] = OrderedDict()
new_tap_array = [] new_tap_array = []
b = 0 b = 0
e = self.n_mit_mot e = self.n_mit_mot
...@@ -2896,7 +2866,6 @@ class Scan(Op): ...@@ -2896,7 +2866,6 @@ class Scan(Op):
b = e b = e
e += self.n_sit_sot e += self.n_sit_sot
new_tap_array += self.tap_array[b:e] * 2 new_tap_array += self.tap_array[b:e] * 2
info["tap_array"] = new_tap_array
# Sequences ... # Sequences ...
b = 1 b = 1
...@@ -2993,7 +2962,7 @@ class Scan(Op): ...@@ -2993,7 +2962,7 @@ class Scan(Op):
# Outputs # Outputs
n_mit_mot_outs = int(np.sum([len(x) for x in self.mit_mot_out_slices])) n_mit_mot_outs = int(np.sum([len(x) for x in self.mit_mot_out_slices]))
info["n_mit_mot_outs"] = n_mit_mot_outs * 2
b = 0 b = 0
e = n_mit_mot_outs e = n_mit_mot_outs
inner_out_mit_mot = self_outputs[b:e] + rop_outs[b:e] inner_out_mit_mot = self_outputs[b:e] + rop_outs[b:e]
...@@ -3039,7 +3008,25 @@ class Scan(Op): ...@@ -3039,7 +3008,25 @@ class Scan(Op):
+ scan_other + scan_other
) )
local_op = Scan(inner_ins, inner_outs, info) info = ScanInfo(
n_seqs=self.n_seqs * 2,
n_mit_sot=self.n_mit_sot * 2,
n_sit_sot=self.n_sit_sot * 2,
n_mit_mot=self.n_mit_mot * 2,
n_nit_sot=self.n_nit_sot * 2,
n_shared_outs=self.n_shared_outs,
n_mit_mot_outs=n_mit_mot_outs * 2,
gpua=False,
as_while=self.as_while,
profile=self.profile,
truncate_gradient=self.truncate_gradient,
name=f"rop_of_{self.name}" if self.name else None,
allow_gc=self.allow_gc,
tap_array=tuple(tuple(v) for v in new_tap_array),
mit_mot_out_slices=tuple(tuple(v) for v in self.mit_mot_out_slices) * 2,
)
local_op = Scan(inner_ins, inner_outs, info, self.mode)
outputs = local_op(*scan_inputs) outputs = local_op(*scan_inputs)
if type(outputs) not in (list, tuple): if type(outputs) not in (list, tuple):
outputs = [outputs] outputs = [outputs]
......
...@@ -51,8 +51,8 @@ scan_eqopt2 -> They are all global optimizer. (in2out convert local to global). ...@@ -51,8 +51,8 @@ scan_eqopt2 -> They are all global optimizer. (in2out convert local to global).
""" """
import copy import copy
import dataclasses
import logging import logging
from collections import OrderedDict
from sys import maxsize from sys import maxsize
import numpy as np import numpy as np
...@@ -78,7 +78,7 @@ from aesara.graph.fg import InconsistencyError ...@@ -78,7 +78,7 @@ from aesara.graph.fg import InconsistencyError
from aesara.graph.op import compute_test_value from aesara.graph.op import compute_test_value
from aesara.graph.opt import GlobalOptimizer, in2out, local_optimizer from aesara.graph.opt import GlobalOptimizer, in2out, local_optimizer
from aesara.graph.optdb import EquilibriumDB, SequenceDB from aesara.graph.optdb import EquilibriumDB, SequenceDB
from aesara.scan.op import Scan from aesara.scan.op import Scan, ScanInfo
from aesara.scan.utils import ( from aesara.scan.utils import (
ScanArgs, ScanArgs,
compress_outs, compress_outs,
...@@ -156,7 +156,7 @@ def remove_constants_and_unused_inputs_scan(fgraph, node): ...@@ -156,7 +156,7 @@ def remove_constants_and_unused_inputs_scan(fgraph, node):
out_stuff_outer = node.inputs[1 + op.n_seqs : st] out_stuff_outer = node.inputs[1 + op.n_seqs : st]
# To replace constants in the outer graph by clones in the inner graph # To replace constants in the outer graph by clones in the inner graph
givens = OrderedDict() givens = {}
# All the inputs of the inner graph of the new scan # All the inputs of the inner graph of the new scan
nw_inner = [] nw_inner = []
# Same for the outer graph, initialized w/ number of steps # Same for the outer graph, initialized w/ number of steps
...@@ -217,12 +217,10 @@ def remove_constants_and_unused_inputs_scan(fgraph, node): ...@@ -217,12 +217,10 @@ def remove_constants_and_unused_inputs_scan(fgraph, node):
if len(nw_inner) != len(op_ins): if len(nw_inner) != len(op_ins):
op_outs = clone_replace(op_outs, replace=givens) op_outs = clone_replace(op_outs, replace=givens)
nw_info = copy.deepcopy(op.info) nw_info = dataclasses.replace(op.info, n_seqs=nw_n_seqs)
nw_info["n_seqs"] = nw_n_seqs nwScan = Scan(nw_inner, op_outs, nw_info, op.mode)
# DEBUG CHECK
nwScan = Scan(nw_inner, op_outs, nw_info)
nw_outs = nwScan(*nw_outer, **dict(return_list=True)) nw_outs = nwScan(*nw_outer, **dict(return_list=True))
return OrderedDict([("remove", [node])] + list(zip(node.outputs, nw_outs))) return dict([("remove", [node])] + list(zip(node.outputs, nw_outs)))
else: else:
return False return False
...@@ -263,7 +261,7 @@ class PushOutNonSeqScan(GlobalOptimizer): ...@@ -263,7 +261,7 @@ class PushOutNonSeqScan(GlobalOptimizer):
to_remove_set = set() to_remove_set = set()
to_replace_set = set() to_replace_set = set()
to_replace_map = OrderedDict() to_replace_map = {}
def add_to_replace(y): def add_to_replace(y):
to_replace_set.add(y) to_replace_set.add(y)
...@@ -377,7 +375,7 @@ class PushOutNonSeqScan(GlobalOptimizer): ...@@ -377,7 +375,7 @@ class PushOutNonSeqScan(GlobalOptimizer):
if len(clean_to_replace) > 0: if len(clean_to_replace) > 0:
# We can finally put an end to all this madness # We can finally put an end to all this madness
givens = OrderedDict() givens = {}
nw_outer = [] nw_outer = []
nw_inner = [] nw_inner = []
for to_repl, repl_in, repl_out in zip( for to_repl, repl_in, repl_out in zip(
...@@ -394,7 +392,7 @@ class PushOutNonSeqScan(GlobalOptimizer): ...@@ -394,7 +392,7 @@ class PushOutNonSeqScan(GlobalOptimizer):
op_ins = clean_inputs + nw_inner op_ins = clean_inputs + nw_inner
# Reconstruct node # Reconstruct node
nwScan = Scan(op_ins, op_outs, op.info) nwScan = Scan(op_ins, op_outs, op.info, op.mode)
# Do not call make_node for test_value # Do not call make_node for test_value
nw_node = nwScan(*(node.inputs + nw_outer), **dict(return_list=True))[ nw_node = nwScan(*(node.inputs + nw_outer), **dict(return_list=True))[
...@@ -409,7 +407,7 @@ class PushOutNonSeqScan(GlobalOptimizer): ...@@ -409,7 +407,7 @@ class PushOutNonSeqScan(GlobalOptimizer):
return True return True
elif not to_keep_set: elif not to_keep_set:
# Nothing in the inner graph should be kept # Nothing in the inner graph should be kept
replace_with = OrderedDict() replace_with = {}
for out, idx in to_replace_map.items(): for out, idx in to_replace_map.items():
if out in local_fgraph_outs_set: if out in local_fgraph_outs_set:
x = node.outputs[local_fgraph_outs_map[out]] x = node.outputs[local_fgraph_outs_map[out]]
...@@ -481,7 +479,7 @@ class PushOutSeqScan(GlobalOptimizer): ...@@ -481,7 +479,7 @@ class PushOutSeqScan(GlobalOptimizer):
to_remove_set = set() to_remove_set = set()
to_replace_set = set() to_replace_set = set()
to_replace_map = OrderedDict() to_replace_map = {}
def add_to_replace(y): def add_to_replace(y):
to_replace_set.add(y) to_replace_set.add(y)
...@@ -638,7 +636,7 @@ class PushOutSeqScan(GlobalOptimizer): ...@@ -638,7 +636,7 @@ class PushOutSeqScan(GlobalOptimizer):
if len(clean_to_replace) > 0: if len(clean_to_replace) > 0:
# We can finally put an end to all this madness # We can finally put an end to all this madness
givens = OrderedDict() givens = {}
nw_outer = [] nw_outer = []
nw_inner = [] nw_inner = []
for to_repl, repl_in, repl_out in zip( for to_repl, repl_in, repl_out in zip(
...@@ -656,9 +654,10 @@ class PushOutSeqScan(GlobalOptimizer): ...@@ -656,9 +654,10 @@ class PushOutSeqScan(GlobalOptimizer):
op_ins = nw_inner + clean_inputs op_ins = nw_inner + clean_inputs
# Reconstruct node # Reconstruct node
nw_info = op.info.copy() nw_info = dataclasses.replace(
nw_info["n_seqs"] += len(nw_inner) op.info, n_seqs=op.info.n_seqs + len(nw_inner)
nwScan = Scan(op_ins, op_outs, nw_info) )
nwScan = Scan(op_ins, op_outs, nw_info, op.mode)
# Do not call make_node for test_value # Do not call make_node for test_value
nw_node = nwScan( nw_node = nwScan(
*(node.inputs[:1] + nw_outer + node.inputs[1:]), *(node.inputs[:1] + nw_outer + node.inputs[1:]),
...@@ -673,7 +672,7 @@ class PushOutSeqScan(GlobalOptimizer): ...@@ -673,7 +672,7 @@ class PushOutSeqScan(GlobalOptimizer):
return True return True
elif not to_keep_set and not op.as_while and not op.outer_mitmot(node): elif not to_keep_set and not op.as_while and not op.outer_mitmot(node):
# Nothing in the inner graph should be kept # Nothing in the inner graph should be kept
replace_with = OrderedDict() replace_with = {}
for out, idx in to_replace_map.items(): for out, idx in to_replace_map.items():
if out in local_fgraph_outs_set: if out in local_fgraph_outs_set:
x = node.outputs[local_fgraph_outs_map[out]] x = node.outputs[local_fgraph_outs_map[out]]
...@@ -937,7 +936,10 @@ class PushOutScanOutput(GlobalOptimizer): ...@@ -937,7 +936,10 @@ class PushOutScanOutput(GlobalOptimizer):
# Create the `Scan` `Op` from the `ScanArgs` # Create the `Scan` `Op` from the `ScanArgs`
new_scan_op = Scan( new_scan_op = Scan(
new_scan_args.inner_inputs, new_scan_args.inner_outputs, new_scan_args.info new_scan_args.inner_inputs,
new_scan_args.inner_outputs,
new_scan_args.info,
old_scan_node.op.mode,
) )
# Create the Apply node for the scan op # Create the Apply node for the scan op
...@@ -1000,13 +1002,6 @@ class ScanInplaceOptimizer(GlobalOptimizer): ...@@ -1000,13 +1002,6 @@ class ScanInplaceOptimizer(GlobalOptimizer):
op = node.op op = node.op
info = copy.deepcopy(op.info)
if "destroy_map" not in info:
info["destroy_map"] = OrderedDict()
for out_idx in output_indices:
info["destroy_map"][out_idx] = [out_idx + 1 + op.info["n_seqs"]]
# inputs corresponding to sequences and n_steps # inputs corresponding to sequences and n_steps
ls_begin = node.inputs[: 1 + op.n_seqs] ls_begin = node.inputs[: 1 + op.n_seqs]
ls = op.outer_mitmot(node.inputs) ls = op.outer_mitmot(node.inputs)
...@@ -1048,7 +1043,15 @@ class ScanInplaceOptimizer(GlobalOptimizer): ...@@ -1048,7 +1043,15 @@ class ScanInplaceOptimizer(GlobalOptimizer):
else: else:
typeConstructor = self.typeInfer(node) typeConstructor = self.typeInfer(node)
new_op = Scan(op.inputs, op.outputs, info, typeConstructor=typeConstructor) new_op = Scan(
op.inputs, op.outputs, op.info, op.mode, typeConstructor=typeConstructor
)
destroy_map = op.destroy_map.copy()
for out_idx in output_indices:
destroy_map[out_idx] = [out_idx + 1 + op.info.n_seqs]
new_op.destroy_map = destroy_map
# Do not call make_node for test_value # Do not call make_node for test_value
new_outs = new_op(*inputs, **dict(return_list=True)) new_outs = new_op(*inputs, **dict(return_list=True))
...@@ -1070,7 +1073,7 @@ class ScanInplaceOptimizer(GlobalOptimizer): ...@@ -1070,7 +1073,7 @@ class ScanInplaceOptimizer(GlobalOptimizer):
scan_nodes = [ scan_nodes = [
x x
for x in nodes for x in nodes
if (isinstance(x.op, Scan) and x.op.info["gpua"] == self.gpua_flag) if (isinstance(x.op, Scan) and x.op.info.gpua == self.gpua_flag)
] ]
for scan_idx in range(len(scan_nodes)): for scan_idx in range(len(scan_nodes)):
...@@ -1080,7 +1083,7 @@ class ScanInplaceOptimizer(GlobalOptimizer): ...@@ -1080,7 +1083,7 @@ class ScanInplaceOptimizer(GlobalOptimizer):
# them. # them.
original_node = scan_nodes[scan_idx] original_node = scan_nodes[scan_idx]
op = original_node.op op = original_node.op
n_outs = op.info["n_mit_mot"] + op.info["n_mit_sot"] + op.info["n_sit_sot"] n_outs = op.info.n_mit_mot + op.info.n_mit_sot + op.info.n_sit_sot
# Generate a list of outputs on which the node could potentially # Generate a list of outputs on which the node could potentially
# operate inplace. # operate inplace.
...@@ -1171,7 +1174,7 @@ class ScanSaveMem(GlobalOptimizer): ...@@ -1171,7 +1174,7 @@ class ScanSaveMem(GlobalOptimizer):
# Each access to shape_of is in a try..except block in order to # Each access to shape_of is in a try..except block in order to
# use a default version when the variable is not in the shape_of # use a default version when the variable is not in the shape_of
# dictionary. # dictionary.
shape_of = OrderedDict() shape_of = {}
# 1. Initialization of variables # 1. Initialization of variables
# Note 1) We do not actually care about outputs representing shared # Note 1) We do not actually care about outputs representing shared
# variables (those have no intermediate values) so it is safer to # variables (those have no intermediate values) so it is safer to
...@@ -1548,7 +1551,7 @@ class ScanSaveMem(GlobalOptimizer): ...@@ -1548,7 +1551,7 @@ class ScanSaveMem(GlobalOptimizer):
(inps, outs, info, node_ins, compress_map) = compress_outs( (inps, outs, info, node_ins, compress_map) = compress_outs(
op, not_required, nw_inputs op, not_required, nw_inputs
) )
inv_compress_map = OrderedDict() inv_compress_map = {}
for k, v in compress_map.items(): for k, v in compress_map.items():
inv_compress_map[v] = k inv_compress_map[v] = k
...@@ -1559,7 +1562,9 @@ class ScanSaveMem(GlobalOptimizer): ...@@ -1559,7 +1562,9 @@ class ScanSaveMem(GlobalOptimizer):
return return
# Do not call make_node for test_value # Do not call make_node for test_value
new_outs = Scan(inps, outs, info)(*node_ins, **dict(return_list=True)) new_outs = Scan(inps, outs, info, op.mode)(
*node_ins, **dict(return_list=True)
)
old_new = [] old_new = []
# 3.7 Get replace pairs for those outputs that do not change # 3.7 Get replace pairs for those outputs that do not change
...@@ -1688,24 +1693,6 @@ class ScanMerge(GlobalOptimizer): ...@@ -1688,24 +1693,6 @@ class ScanMerge(GlobalOptimizer):
else: else:
as_while = False as_while = False
info = OrderedDict()
info["tap_array"] = []
info["n_seqs"] = sum([nd.op.n_seqs for nd in nodes])
info["n_mit_mot"] = sum([nd.op.n_mit_mot for nd in nodes])
info["n_mit_mot_outs"] = sum([nd.op.n_mit_mot_outs for nd in nodes])
info["mit_mot_out_slices"] = []
info["n_mit_sot"] = sum([nd.op.n_mit_sot for nd in nodes])
info["n_sit_sot"] = sum([nd.op.n_sit_sot for nd in nodes])
info["n_shared_outs"] = sum([nd.op.n_shared_outs for nd in nodes])
info["n_nit_sot"] = sum([nd.op.n_nit_sot for nd in nodes])
info["truncate_gradient"] = nodes[0].op.truncate_gradient
info["name"] = "&".join([nd.op.name for nd in nodes])
info["mode"] = nodes[0].op.mode
info["gpua"] = False
info["as_while"] = as_while
info["profile"] = nodes[0].op.profile
info["allow_gc"] = nodes[0].op.allow_gc
# We keep the inner_ins and inner_outs of each original node separated. # We keep the inner_ins and inner_outs of each original node separated.
# To be able to recombine them in the right order after the clone, # To be able to recombine them in the right order after the clone,
# we also need to split them by types (seq, mitmot, ...). # we also need to split them by types (seq, mitmot, ...).
...@@ -1725,12 +1712,15 @@ class ScanMerge(GlobalOptimizer): ...@@ -1725,12 +1712,15 @@ class ScanMerge(GlobalOptimizer):
inner_ins[idx].append(rename(nd.op.inner_seqs(nd.op.inputs), idx)) inner_ins[idx].append(rename(nd.op.inner_seqs(nd.op.inputs), idx))
outer_ins += rename(nd.op.outer_seqs(nd.inputs), idx) outer_ins += rename(nd.op.outer_seqs(nd.inputs), idx)
tap_array = ()
mit_mot_out_slices = ()
for idx, nd in enumerate(nodes): for idx, nd in enumerate(nodes):
# MitMot # MitMot
inner_ins[idx].append(rename(nd.op.inner_mitmot(nd.op.inputs), idx)) inner_ins[idx].append(rename(nd.op.inner_mitmot(nd.op.inputs), idx))
inner_outs[idx].append(nd.op.inner_mitmot_outs(nd.op.outputs)) inner_outs[idx].append(nd.op.inner_mitmot_outs(nd.op.outputs))
info["tap_array"] += nd.op.mitmot_taps() tap_array += nd.op.mitmot_taps()
info["mit_mot_out_slices"] += nd.op.mitmot_out_taps() mit_mot_out_slices += nd.op.mitmot_out_taps()
outer_ins += rename(nd.op.outer_mitmot(nd.inputs), idx) outer_ins += rename(nd.op.outer_mitmot(nd.inputs), idx)
outer_outs += nd.op.outer_mitmot_outs(nd.outputs) outer_outs += nd.op.outer_mitmot_outs(nd.outputs)
...@@ -1738,14 +1728,14 @@ class ScanMerge(GlobalOptimizer): ...@@ -1738,14 +1728,14 @@ class ScanMerge(GlobalOptimizer):
# MitSot # MitSot
inner_ins[idx].append(rename(nd.op.inner_mitsot(nd.op.inputs), idx)) inner_ins[idx].append(rename(nd.op.inner_mitsot(nd.op.inputs), idx))
inner_outs[idx].append(nd.op.inner_mitsot_outs(nd.op.outputs)) inner_outs[idx].append(nd.op.inner_mitsot_outs(nd.op.outputs))
info["tap_array"] += nd.op.mitsot_taps() tap_array += nd.op.mitsot_taps()
outer_ins += rename(nd.op.outer_mitsot(nd.inputs), idx) outer_ins += rename(nd.op.outer_mitsot(nd.inputs), idx)
outer_outs += nd.op.outer_mitsot_outs(nd.outputs) outer_outs += nd.op.outer_mitsot_outs(nd.outputs)
for idx, nd in enumerate(nodes): for idx, nd in enumerate(nodes):
# SitSot # SitSot
inner_ins[idx].append(rename(nd.op.inner_sitsot(nd.op.inputs), idx)) inner_ins[idx].append(rename(nd.op.inner_sitsot(nd.op.inputs), idx))
info["tap_array"] += [[-1] for x in range(nd.op.n_sit_sot)] tap_array += tuple((-1,) for x in range(nd.op.n_sit_sot))
inner_outs[idx].append(nd.op.inner_sitsot_outs(nd.op.outputs)) inner_outs[idx].append(nd.op.inner_sitsot_outs(nd.op.outputs))
outer_ins += rename(nd.op.outer_sitsot(nd.inputs), idx) outer_ins += rename(nd.op.outer_sitsot(nd.inputs), idx)
outer_outs += nd.op.outer_sitsot_outs(nd.outputs) outer_outs += nd.op.outer_sitsot_outs(nd.outputs)
...@@ -1834,7 +1824,25 @@ class ScanMerge(GlobalOptimizer): ...@@ -1834,7 +1824,25 @@ class ScanMerge(GlobalOptimizer):
else: else:
new_inner_outs += inner_outs[idx][gr_idx] new_inner_outs += inner_outs[idx][gr_idx]
new_op = Scan(new_inner_ins, new_inner_outs, info) info = ScanInfo(
tap_array=tap_array,
n_seqs=sum([nd.op.n_seqs for nd in nodes]),
n_mit_mot=sum([nd.op.n_mit_mot for nd in nodes]),
n_mit_mot_outs=sum([nd.op.n_mit_mot_outs for nd in nodes]),
mit_mot_out_slices=mit_mot_out_slices,
n_mit_sot=sum([nd.op.n_mit_sot for nd in nodes]),
n_sit_sot=sum([nd.op.n_sit_sot for nd in nodes]),
n_shared_outs=sum([nd.op.n_shared_outs for nd in nodes]),
n_nit_sot=sum([nd.op.n_nit_sot for nd in nodes]),
truncate_gradient=nodes[0].op.truncate_gradient,
name="&".join([nd.op.name for nd in nodes]),
gpua=False,
as_while=as_while,
profile=nodes[0].op.profile,
allow_gc=nodes[0].op.allow_gc,
)
new_op = Scan(new_inner_ins, new_inner_outs, info, nodes[0].op.mode)
new_outs = new_op(*outer_ins) new_outs = new_op(*outer_ins)
if not isinstance(new_outs, (list, tuple)): if not isinstance(new_outs, (list, tuple)):
...@@ -1932,7 +1940,7 @@ def make_equiv(lo, li): ...@@ -1932,7 +1940,7 @@ def make_equiv(lo, li):
the equivalence of their corresponding outer inputs. the equivalence of their corresponding outer inputs.
""" """
seeno = OrderedDict() seeno = {}
left = [] left = []
right = [] right = []
for o, i in zip(lo, li): for o, i in zip(lo, li):
...@@ -1956,7 +1964,7 @@ def scan_merge_inouts(fgraph, node): ...@@ -1956,7 +1964,7 @@ def scan_merge_inouts(fgraph, node):
node.inputs, node.outputs, node.op.inputs, node.op.outputs, node.op.info node.inputs, node.outputs, node.op.inputs, node.op.outputs, node.op.info
) )
inp_equiv = OrderedDict() inp_equiv = {}
if has_duplicates(a.outer_in_seqs): if has_duplicates(a.outer_in_seqs):
new_outer_seqs = [] new_outer_seqs = []
...@@ -1992,7 +2000,7 @@ def scan_merge_inouts(fgraph, node): ...@@ -1992,7 +2000,7 @@ def scan_merge_inouts(fgraph, node):
a_inner_outs = a.inner_outputs a_inner_outs = a.inner_outputs
inner_outputs = clone_replace(a_inner_outs, replace=inp_equiv) inner_outputs = clone_replace(a_inner_outs, replace=inp_equiv)
op = Scan(inner_inputs, inner_outputs, info) op = Scan(inner_inputs, inner_outputs, info, node.op.mode)
outputs = op(*outer_inputs) outputs = op(*outer_inputs)
if not isinstance(outputs, (list, tuple)): if not isinstance(outputs, (list, tuple)):
...@@ -2019,7 +2027,7 @@ def scan_merge_inouts(fgraph, node): ...@@ -2019,7 +2027,7 @@ def scan_merge_inouts(fgraph, node):
left += _left left += _left
right += _right right += _right
if has_duplicates(na.outer_in_mit_mot): if has_duplicates(na.outer_in_mit_mot):
seen = OrderedDict() seen = {}
for omm, imm, _sl in zip( for omm, imm, _sl in zip(
na.outer_in_mit_mot, na.inner_in_mit_mot, na.mit_mot_in_slices na.outer_in_mit_mot, na.inner_in_mit_mot, na.mit_mot_in_slices
): ):
...@@ -2032,7 +2040,7 @@ def scan_merge_inouts(fgraph, node): ...@@ -2032,7 +2040,7 @@ def scan_merge_inouts(fgraph, node):
seen[(omm, sl)] = imm seen[(omm, sl)] = imm
if has_duplicates(na.outer_in_mit_sot): if has_duplicates(na.outer_in_mit_sot):
seen = OrderedDict() seen = {}
for oms, ims, _sl in zip( for oms, ims, _sl in zip(
na.outer_in_mit_sot, na.inner_in_mit_sot, na.mit_sot_in_slices na.outer_in_mit_sot, na.inner_in_mit_sot, na.mit_sot_in_slices
): ):
...@@ -2117,9 +2125,7 @@ def scan_merge_inouts(fgraph, node): ...@@ -2117,9 +2125,7 @@ def scan_merge_inouts(fgraph, node):
new_outer_out_mit_mot.append(outer_omm) new_outer_out_mit_mot.append(outer_omm)
na.outer_out_mit_mot = new_outer_out_mit_mot na.outer_out_mit_mot = new_outer_out_mit_mot
if remove: if remove:
return OrderedDict( return dict([("remove", remove)] + list(zip(node.outputs, na.outer_outputs)))
[("remove", remove)] + list(zip(node.outputs, na.outer_outputs))
)
return na.outer_outputs return na.outer_outputs
...@@ -2214,15 +2220,17 @@ class PushOutDot1(GlobalOptimizer): ...@@ -2214,15 +2220,17 @@ class PushOutDot1(GlobalOptimizer):
inner_non_seqs = op.inner_non_seqs(op.inputs) inner_non_seqs = op.inner_non_seqs(op.inputs)
outer_non_seqs = op.outer_non_seqs(node) outer_non_seqs = op.outer_non_seqs(node)
new_info = op.info.copy()
st = len(op.mitmot_taps()) + len(op.mitsot_taps()) st = len(op.mitmot_taps()) + len(op.mitsot_taps())
new_info["tap_array"] = ( new_info = dataclasses.replace(
new_info["tap_array"][: st + idx] op.info,
+ new_info["tap_array"][st + idx + 1 :] tap_array=(
op.info.tap_array[: st + idx]
+ op.info.tap_array[st + idx + 1 :]
),
n_sit_sot=op.info.n_sit_sot - 1,
n_nit_sot=op.info.n_nit_sot + 1,
) )
new_info["n_sit_sot"] -= 1
new_info["n_nit_sot"] += 1
inner_sitsot = inner_sitsot[:idx] + inner_sitsot[idx + 1 :] inner_sitsot = inner_sitsot[:idx] + inner_sitsot[idx + 1 :]
outer_sitsot = outer_sitsot[:idx] + outer_sitsot[idx + 1 :] outer_sitsot = outer_sitsot[:idx] + outer_sitsot[idx + 1 :]
inner_sitsot_outs = ( inner_sitsot_outs = (
...@@ -2249,7 +2257,7 @@ class PushOutDot1(GlobalOptimizer): ...@@ -2249,7 +2257,7 @@ class PushOutDot1(GlobalOptimizer):
new_inner_inps, new_inner_outs = reconstruct_graph( new_inner_inps, new_inner_outs = reconstruct_graph(
_new_inner_inps, _new_inner_outs _new_inner_inps, _new_inner_outs
) )
new_op = Scan(new_inner_inps, new_inner_outs, new_info) new_op = Scan(new_inner_inps, new_inner_outs, new_info, op.mode)
_scan_inputs = ( _scan_inputs = (
[node.inputs[0]] [node.inputs[0]]
+ outer_seqs + outer_seqs
......
"""This module provides utility functions for the `Scan` `Op`.""" """This module provides utility functions for the `Scan` `Op`."""
import copy import copy
import dataclasses
import logging import logging
import warnings import warnings
from collections import OrderedDict, namedtuple from collections import OrderedDict, namedtuple
...@@ -157,21 +158,6 @@ def traverse(out, x, x_copy, d, visited=None): ...@@ -157,21 +158,6 @@ def traverse(out, x, x_copy, d, visited=None):
return d return d
# Hashing a dictionary/list/tuple by xoring the hash of each element
def hash_listsDictsTuples(x):
hash_value = 0
if isinstance(x, dict):
for k, v in x.items():
hash_value ^= hash_listsDictsTuples(k)
hash_value ^= hash_listsDictsTuples(v)
elif isinstance(x, (list, tuple)):
for v in x:
hash_value ^= hash_listsDictsTuples(v)
else:
hash_value ^= hash(x)
return hash_value
def map_variables(replacer, graphs, additional_inputs=None): def map_variables(replacer, graphs, additional_inputs=None):
"""Construct new graphs based on 'graphs' with some variables replaced """Construct new graphs based on 'graphs' with some variables replaced
according to 'replacer'. according to 'replacer'.
...@@ -264,6 +250,7 @@ def map_variables(replacer, graphs, additional_inputs=None): ...@@ -264,6 +250,7 @@ def map_variables(replacer, graphs, additional_inputs=None):
new_inner_inputs, new_inner_inputs,
new_inner_outputs, new_inner_outputs,
node.op.info, node.op.info,
node.op.mode,
# FIXME: infer this someday? # FIXME: infer this someday?
typeConstructor=None, typeConstructor=None,
) )
...@@ -669,7 +656,7 @@ def scan_can_remove_outs(op, out_idxs): ...@@ -669,7 +656,7 @@ def scan_can_remove_outs(op, out_idxs):
offset = op.n_seqs offset = op.n_seqs
lim = op.n_mit_mot + op.n_mit_sot + op.n_sit_sot lim = op.n_mit_mot + op.n_mit_sot + op.n_sit_sot
for idx in range(lim): for idx in range(lim):
n_ins = len(op.info["tap_array"][idx]) n_ins = len(op.info.tap_array[idx])
out_ins += [op.inputs[offset : offset + n_ins]] out_ins += [op.inputs[offset : offset + n_ins]]
offset += n_ins offset += n_ins
out_ins += [[] for k in range(op.n_nit_sot)] out_ins += [[] for k in range(op.n_nit_sot)]
...@@ -702,23 +689,25 @@ def compress_outs(op, not_required, inputs): ...@@ -702,23 +689,25 @@ def compress_outs(op, not_required, inputs):
node inputs, and changing the dictionary. node inputs, and changing the dictionary.
""" """
info = OrderedDict() from aesara.scan.op import ScanInfo
info["tap_array"] = []
info["n_seqs"] = op.info["n_seqs"] info = ScanInfo(
info["n_mit_mot"] = 0 tap_array=(),
info["n_mit_mot_outs"] = 0 n_seqs=op.info.n_seqs,
info["mit_mot_out_slices"] = [] n_mit_mot=0,
info["n_mit_sot"] = 0 n_mit_mot_outs=0,
info["n_sit_sot"] = 0 mit_mot_out_slices=(),
info["n_shared_outs"] = 0 n_mit_sot=0,
info["n_nit_sot"] = 0 n_sit_sot=0,
info["truncate_gradient"] = op.info["truncate_gradient"] n_shared_outs=0,
info["name"] = op.info["name"] n_nit_sot=0,
info["gpua"] = op.info["gpua"] truncate_gradient=op.info.truncate_gradient,
info["mode"] = op.info["mode"] name=op.info.name,
info["as_while"] = op.info["as_while"] gpua=op.info.gpua,
info["profile"] = op.info["profile"] as_while=op.info.as_while,
info["allow_gc"] = op.info["allow_gc"] profile=op.info.profile,
allow_gc=op.info.allow_gc,
)
op_inputs = op.inputs[: op.n_seqs] op_inputs = op.inputs[: op.n_seqs]
op_outputs = [] op_outputs = []
...@@ -730,13 +719,17 @@ def compress_outs(op, not_required, inputs): ...@@ -730,13 +719,17 @@ def compress_outs(op, not_required, inputs):
i_offset = op.n_seqs i_offset = op.n_seqs
o_offset = 0 o_offset = 0
curr_pos = 0 curr_pos = 0
for idx in range(op.info["n_mit_mot"]): for idx in range(op.info.n_mit_mot):
if offset + idx not in not_required: if offset + idx not in not_required:
map_old_new[offset + idx] = curr_pos map_old_new[offset + idx] = curr_pos
curr_pos += 1 curr_pos += 1
info["n_mit_mot"] += 1 info = dataclasses.replace(
info["tap_array"] += [op.tap_array[offset + idx]] info,
info["mit_mot_out_slices"] += [op.mit_mot_out_slices[offset + idx]] n_mit_mot=info.n_mit_mot + 1,
tap_array=info.tap_array + (tuple(op.tap_array[offset + idx]),),
mit_mot_out_slices=info.mit_mot_out_slices
+ (tuple(op.mit_mot_out_slices[offset + idx]),),
)
# input taps # input taps
for jdx in op.tap_array[offset + idx]: for jdx in op.tap_array[offset + idx]:
op_inputs += [op.inputs[i_offset]] op_inputs += [op.inputs[i_offset]]
...@@ -750,16 +743,19 @@ def compress_outs(op, not_required, inputs): ...@@ -750,16 +743,19 @@ def compress_outs(op, not_required, inputs):
else: else:
o_offset += len(op.mit_mot_out_slices[offset + idx]) o_offset += len(op.mit_mot_out_slices[offset + idx])
i_offset += len(op.tap_array[offset + idx]) i_offset += len(op.tap_array[offset + idx])
info["n_mit_mot_outs"] = len(op_outputs) info = dataclasses.replace(info, n_mit_mot_outs=len(op_outputs))
offset += op.n_mit_mot offset += op.n_mit_mot
ni_offset += op.n_mit_mot ni_offset += op.n_mit_mot
for idx in range(op.info["n_mit_sot"]): for idx in range(op.info.n_mit_sot):
if offset + idx not in not_required: if offset + idx not in not_required:
map_old_new[offset + idx] = curr_pos map_old_new[offset + idx] = curr_pos
curr_pos += 1 curr_pos += 1
info["n_mit_sot"] += 1 info = dataclasses.replace(
info["tap_array"] += [op.tap_array[offset + idx]] info,
n_mit_sot=info.n_mit_sot + 1,
tap_array=info.tap_array + (tuple(op.tap_array[offset + idx]),),
)
# input taps # input taps
for jdx in op.tap_array[offset + idx]: for jdx in op.tap_array[offset + idx]:
op_inputs += [op.inputs[i_offset]] op_inputs += [op.inputs[i_offset]]
...@@ -775,12 +771,15 @@ def compress_outs(op, not_required, inputs): ...@@ -775,12 +771,15 @@ def compress_outs(op, not_required, inputs):
offset += op.n_mit_sot offset += op.n_mit_sot
ni_offset += op.n_mit_sot ni_offset += op.n_mit_sot
for idx in range(op.info["n_sit_sot"]): for idx in range(op.info.n_sit_sot):
if offset + idx not in not_required: if offset + idx not in not_required:
map_old_new[offset + idx] = curr_pos map_old_new[offset + idx] = curr_pos
curr_pos += 1 curr_pos += 1
info["n_sit_sot"] += 1 info = dataclasses.replace(
info["tap_array"] += [op.tap_array[offset + idx]] info,
n_sit_sot=info.n_sit_sot + 1,
tap_array=info.tap_array + (tuple(op.tap_array[offset + idx]),),
)
# input taps # input taps
op_inputs += [op.inputs[i_offset]] op_inputs += [op.inputs[i_offset]]
i_offset += 1 i_offset += 1
...@@ -796,11 +795,11 @@ def compress_outs(op, not_required, inputs): ...@@ -796,11 +795,11 @@ def compress_outs(op, not_required, inputs):
offset += op.n_sit_sot offset += op.n_sit_sot
ni_offset += op.n_sit_sot ni_offset += op.n_sit_sot
nit_sot_ins = [] nit_sot_ins = []
for idx in range(op.info["n_nit_sot"]): for idx in range(op.info.n_nit_sot):
if offset + idx not in not_required: if offset + idx not in not_required:
map_old_new[offset + idx] = curr_pos map_old_new[offset + idx] = curr_pos
curr_pos += 1 curr_pos += 1
info["n_nit_sot"] += 1 info = dataclasses.replace(info, n_nit_sot=info.n_nit_sot + 1)
op_outputs += [op.outputs[o_offset]] op_outputs += [op.outputs[o_offset]]
o_offset += 1 o_offset += 1
nit_sot_ins += [inputs[ni_offset + idx + op.n_shared_outs]] nit_sot_ins += [inputs[ni_offset + idx + op.n_shared_outs]]
...@@ -809,11 +808,11 @@ def compress_outs(op, not_required, inputs): ...@@ -809,11 +808,11 @@ def compress_outs(op, not_required, inputs):
offset += op.n_nit_sot offset += op.n_nit_sot
shared_ins = [] shared_ins = []
for idx in range(op.info["n_shared_outs"]): for idx in range(op.info.n_shared_outs):
if offset + idx not in not_required: if offset + idx not in not_required:
map_old_new[offset + idx] = curr_pos map_old_new[offset + idx] = curr_pos
curr_pos += 1 curr_pos += 1
info["n_shared_outs"] += 1 info = dataclasses.replace(info, n_shared_outs=info.n_shared_outs + 1)
op_outputs += [op.outputs[o_offset]] op_outputs += [op.outputs[o_offset]]
o_offset += 1 o_offset += 1
op_inputs += [op.inputs[i_offset]] op_inputs += [op.inputs[i_offset]]
...@@ -896,7 +895,7 @@ class ScanArgs: ...@@ -896,7 +895,7 @@ class ScanArgs:
else: else:
rval = (_inner_inputs, _inner_outputs) rval = (_inner_inputs, _inner_outputs)
if info["as_while"]: if info.as_while:
self.cond = [rval[1][-1]] self.cond = [rval[1][-1]]
inner_outputs = rval[1][:-1] inner_outputs = rval[1][:-1]
else: else:
...@@ -907,17 +906,17 @@ class ScanArgs: ...@@ -907,17 +906,17 @@ class ScanArgs:
p = 1 p = 1
q = 0 q = 0
n_seqs = info["n_seqs"] n_seqs = info.n_seqs
self.outer_in_seqs = outer_inputs[p : p + n_seqs] self.outer_in_seqs = outer_inputs[p : p + n_seqs]
self.inner_in_seqs = inner_inputs[q : q + n_seqs] self.inner_in_seqs = inner_inputs[q : q + n_seqs]
p += n_seqs p += n_seqs
q += n_seqs q += n_seqs
n_mit_mot = info["n_mit_mot"] n_mit_mot = info.n_mit_mot
n_mit_sot = info["n_mit_sot"] n_mit_sot = info.n_mit_sot
self.mit_mot_in_slices = info["tap_array"][:n_mit_mot] self.mit_mot_in_slices = info.tap_array[:n_mit_mot]
self.mit_sot_in_slices = info["tap_array"][n_mit_mot : n_mit_mot + n_mit_sot] self.mit_sot_in_slices = info.tap_array[n_mit_mot : n_mit_mot + n_mit_sot]
n_mit_mot_ins = sum(len(s) for s in self.mit_mot_in_slices) n_mit_mot_ins = sum(len(s) for s in self.mit_mot_in_slices)
n_mit_sot_ins = sum(len(s) for s in self.mit_sot_in_slices) n_mit_sot_ins = sum(len(s) for s in self.mit_sot_in_slices)
...@@ -943,19 +942,19 @@ class ScanArgs: ...@@ -943,19 +942,19 @@ class ScanArgs:
self.outer_in_mit_sot = outer_inputs[p : p + n_mit_sot] self.outer_in_mit_sot = outer_inputs[p : p + n_mit_sot]
p += n_mit_sot p += n_mit_sot
n_sit_sot = info["n_sit_sot"] n_sit_sot = info.n_sit_sot
self.outer_in_sit_sot = outer_inputs[p : p + n_sit_sot] self.outer_in_sit_sot = outer_inputs[p : p + n_sit_sot]
self.inner_in_sit_sot = inner_inputs[q : q + n_sit_sot] self.inner_in_sit_sot = inner_inputs[q : q + n_sit_sot]
p += n_sit_sot p += n_sit_sot
q += n_sit_sot q += n_sit_sot
n_shared_outs = info["n_shared_outs"] n_shared_outs = info.n_shared_outs
self.outer_in_shared = outer_inputs[p : p + n_shared_outs] self.outer_in_shared = outer_inputs[p : p + n_shared_outs]
self.inner_in_shared = inner_inputs[q : q + n_shared_outs] self.inner_in_shared = inner_inputs[q : q + n_shared_outs]
p += n_shared_outs p += n_shared_outs
q += n_shared_outs q += n_shared_outs
n_nit_sot = info["n_nit_sot"] n_nit_sot = info.n_nit_sot
self.outer_in_nit_sot = outer_inputs[p : p + n_nit_sot] self.outer_in_nit_sot = outer_inputs[p : p + n_nit_sot]
p += n_nit_sot p += n_nit_sot
...@@ -966,14 +965,14 @@ class ScanArgs: ...@@ -966,14 +965,14 @@ class ScanArgs:
p = 0 p = 0
q = 0 q = 0
self.mit_mot_out_slices = info["mit_mot_out_slices"] self.mit_mot_out_slices = info.mit_mot_out_slices
n_mit_mot_outs = info["n_mit_mot_outs"] n_mit_mot_outs = info.n_mit_mot_outs
self.outer_out_mit_mot = outer_outputs[p : p + n_mit_mot] self.outer_out_mit_mot = outer_outputs[p : p + n_mit_mot]
iomm = inner_outputs[q : q + n_mit_mot_outs] iomm = inner_outputs[q : q + n_mit_mot_outs]
self.inner_out_mit_mot = [] self.inner_out_mit_mot = ()
qq = 0 qq = 0
for sl in self.mit_mot_out_slices: for sl in self.mit_mot_out_slices:
self.inner_out_mit_mot.append(iomm[qq : qq + len(sl)]) self.inner_out_mit_mot += (iomm[qq : qq + len(sl)],)
qq += len(sl) qq += len(sl)
p += n_mit_mot p += n_mit_mot
q += n_mit_mot_outs q += n_mit_mot_outs
...@@ -1001,19 +1000,17 @@ class ScanArgs: ...@@ -1001,19 +1000,17 @@ class ScanArgs:
assert p == len(outer_outputs) assert p == len(outer_outputs)
assert q == len(inner_outputs) assert q == len(inner_outputs)
self.other_info = OrderedDict() self.other_info = {
k: getattr(info, k)
for k in ( for k in (
"truncate_gradient", "truncate_gradient",
"name", "name",
"mode",
"destroy_map",
"gpua", "gpua",
"as_while", "as_while",
"profile", "profile",
"allow_gc", "allow_gc",
): )
if k in info: }
self.other_info[k] = info[k]
@staticmethod @staticmethod
def from_node(node, clone=False): def from_node(node, clone=False):
...@@ -1032,26 +1029,24 @@ class ScanArgs: ...@@ -1032,26 +1029,24 @@ class ScanArgs:
@classmethod @classmethod
def create_empty(cls): def create_empty(cls):
info = OrderedDict( from aesara.scan.op import ScanInfo
[
("n_seqs", 0), info = ScanInfo(
("n_mit_mot", 0), n_seqs=0,
("n_mit_sot", 0), n_mit_mot=0,
("tap_array", []), n_mit_sot=0,
("n_sit_sot", 0), tap_array=(),
("n_nit_sot", 0), n_sit_sot=0,
("n_shared_outs", 0), n_nit_sot=0,
("n_mit_mot_outs", 0), n_shared_outs=0,
("mit_mot_out_slices", []), n_mit_mot_outs=0,
("truncate_gradient", -1), mit_mot_out_slices=(),
("name", None), truncate_gradient=-1,
("mode", None), name=None,
("destroy_map", OrderedDict()), gpua=False,
("gpua", False), as_while=False,
("as_while", False), profile=False,
("profile", False), allow_gc=False,
("allow_gc", False),
]
) )
res = cls([1], [], [], [], info) res = cls([1], [], [], [], info)
res.n_steps = None res.n_steps = None
...@@ -1060,7 +1055,7 @@ class ScanArgs: ...@@ -1060,7 +1055,7 @@ class ScanArgs:
@property @property
def n_nit_sot(self): def n_nit_sot(self):
# This is just a hack that allows us to use `Scan.get_oinp_iinp_iout_oout_mappings` # This is just a hack that allows us to use `Scan.get_oinp_iinp_iout_oout_mappings`
return self.info["n_nit_sot"] return self.info.n_nit_sot
@property @property
def inputs(self): def inputs(self):
...@@ -1070,7 +1065,7 @@ class ScanArgs: ...@@ -1070,7 +1065,7 @@ class ScanArgs:
@property @property
def n_mit_mot(self): def n_mit_mot(self):
# This is just a hack that allows us to use `Scan.get_oinp_iinp_iout_oout_mappings` # This is just a hack that allows us to use `Scan.get_oinp_iinp_iout_oout_mappings`
return self.info["n_mit_mot"] return self.info.n_mit_mot
@property @property
def var_mappings(self): def var_mappings(self):
...@@ -1141,20 +1136,22 @@ class ScanArgs: ...@@ -1141,20 +1136,22 @@ class ScanArgs:
@property @property
def info(self): def info(self):
return OrderedDict( from aesara.scan.op import ScanInfo
return ScanInfo(
n_seqs=len(self.outer_in_seqs), n_seqs=len(self.outer_in_seqs),
n_mit_mot=len(self.outer_in_mit_mot), n_mit_mot=len(self.outer_in_mit_mot),
n_mit_sot=len(self.outer_in_mit_sot), n_mit_sot=len(self.outer_in_mit_sot),
tap_array=( tap_array=(
self.mit_mot_in_slices tuple(tuple(v) for v in self.mit_mot_in_slices)
+ self.mit_sot_in_slices + tuple(tuple(v) for v in self.mit_sot_in_slices)
+ [[-1]] * len(self.inner_in_sit_sot) + ((-1,),) * len(self.inner_in_sit_sot)
), ),
n_sit_sot=len(self.outer_in_sit_sot), n_sit_sot=len(self.outer_in_sit_sot),
n_nit_sot=len(self.outer_in_nit_sot), n_nit_sot=len(self.outer_in_nit_sot),
n_shared_outs=len(self.outer_in_shared), n_shared_outs=len(self.outer_in_shared),
n_mit_mot_outs=sum(len(s) for s in self.mit_mot_out_slices), n_mit_mot_outs=sum(len(s) for s in self.mit_mot_out_slices),
mit_mot_out_slices=self.mit_mot_out_slices, mit_mot_out_slices=tuple(self.mit_mot_out_slices),
**self.other_info, **self.other_info,
) )
......
-e ./ -e ./
dataclasses>=0.7; python_version < '3.7'
filelock filelock
flake8==3.8.4 flake8==3.8.4
pep8 pep8
......
#!/usr/bin/env python #!/usr/bin/env python
import sys
from setuptools import find_packages, setup from setuptools import find_packages, setup
import versioneer import versioneer
...@@ -43,6 +45,11 @@ Programming Language :: Python :: 3.9 ...@@ -43,6 +45,11 @@ Programming Language :: Python :: 3.9
""" """
CLASSIFIERS = [_f for _f in CLASSIFIERS.split("\n") if _f] CLASSIFIERS = [_f for _f in CLASSIFIERS.split("\n") if _f]
install_requires = ["numpy>=1.17.0", "scipy>=0.14", "filelock"]
if sys.version_info[0:2] < (3, 7):
install_requires += ["dataclasses"]
if __name__ == "__main__": if __name__ == "__main__":
setup( setup(
name=NAME, name=NAME,
...@@ -57,7 +64,7 @@ if __name__ == "__main__": ...@@ -57,7 +64,7 @@ if __name__ == "__main__":
license=LICENSE, license=LICENSE,
platforms=PLATFORMS, platforms=PLATFORMS,
packages=find_packages(exclude=["tests", "tests.*"]), packages=find_packages(exclude=["tests", "tests.*"]),
install_requires=["numpy>=1.17.0", "scipy>=0.14", "filelock"], install_requires=install_requires,
package_data={ package_data={
"": [ "": [
"*.txt", "*.txt",
......
...@@ -252,10 +252,7 @@ def test_ScanArgs(): ...@@ -252,10 +252,7 @@ def test_ScanArgs():
# The `scan_args` base class always clones the inner-graph; # The `scan_args` base class always clones the inner-graph;
# here we make sure it doesn't (and that all the inputs are the same) # here we make sure it doesn't (and that all the inputs are the same)
assert scan_args.inputs == scan_op.inputs assert scan_args.inputs == scan_op.inputs
scan_op_info = dict(scan_op.info) assert scan_args.info == scan_op.info
# The `ScanInfo` dictionary has the wrong order and an extra entry
del scan_op_info["strict"]
assert dict(scan_args.info) == scan_op_info
assert scan_args.var_mappings == scan_op.var_mappings assert scan_args.var_mappings == scan_op.var_mappings
# Check that `ScanArgs.find_among_fields` works # Check that `ScanArgs.find_among_fields` works
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论