提交 3b0e071e authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Remove Scan.inputs and Scan.outputs

上级 b979dd6e
......@@ -34,7 +34,7 @@ def array0d_range(x):
@numba_funcify.register(Scan)
def numba_funcify_Scan(op, node, **kwargs):
inner_fg = FunctionGraph(op.inputs, op.outputs)
inner_fg = FunctionGraph(op.inner_inputs, op.inner_outputs)
numba_at_inner_func = numba_basic.numba_njit(numba_funcify(inner_fg, **kwargs))
n_seqs = op.info.n_seqs
......
......@@ -1537,7 +1537,7 @@ def pydotprint(
if hasattr(scan_op.op, "_fn"):
to_print = scan_op.op.fn
else:
to_print = scan_op.op.outputs
to_print = scan_op.op.inner_outputs
pydotprint(
to_print,
new_name,
......
......@@ -555,7 +555,7 @@ class ScanMethodsMixin:
# Note : the number of non-sequence inputs is not stored in self.info
# so it has to be inferred from the number of inner inputs that remain
# to be handled
for i in range(len(self.inputs) - inner_iidx):
for i in range(len(self.inner_inputs) - inner_iidx):
outer_input_indices.append(outer_iidx)
inner_input_indices.append([inner_iidx])
inner_output_indices.append([])
......@@ -633,8 +633,8 @@ class ScanMethodsMixin:
for (inner_iidx, inner_oidx) in product(inner_iidxs, inner_oidxs):
type_input = self.inputs[inner_iidx].type
type_output = self.outputs[inner_oidx].type
type_input = self.inner_inputs[inner_iidx].type
type_output = self.inner_outputs[inner_oidx].type
if (
type_input.dtype != type_output.dtype
or type_input.broadcastable != type_output.broadcastable
......@@ -780,8 +780,6 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
self.fgraph = FunctionGraph(inputs, outputs, clone=False)
self.inputs = self.fgraph.inputs
self.outputs = self.fgraph.outputs
self.info = info
self.truncate_gradient = truncate_gradient
self.name = name
......@@ -866,11 +864,11 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
self.n_tap_outs = info.n_mit_mot + info.n_mit_sot
# Do the missing inputs check here to have the error early.
for var in graph_inputs(self.outputs, self.inputs):
if var not in self.inputs and not isinstance(var, Constant):
for var in graph_inputs(self.inner_outputs, self.inner_inputs):
if var not in self.inner_inputs and not isinstance(var, Constant):
raise MissingInputError(f"ScanOp is missing an input: {repr(var)}")
self._cmodule_key = CLinker().cmodule_key_variables(
self.inputs, self.outputs, []
self.inner_inputs, self.inner_outputs, []
)
self._hash_inner_graph = hash(self._cmodule_key)
......@@ -963,12 +961,12 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
# the number of inputs of the inner function of scan
n_outer_ins = len(inputs) - len(self.outer_nitsot(inputs)) - 1
n_inner_ins = (
len(self.inner_seqs(self.inputs))
len(self.inner_seqs(self.inner_inputs))
+ len(self.mitmot_taps())
+ len(self.mitsot_taps())
+ len(self.inner_sitsot(self.inputs))
+ len(self.inner_shared(self.inputs))
+ len(self.inner_non_seqs(self.inputs))
+ len(self.inner_sitsot(self.inner_inputs))
+ len(self.inner_shared(self.inner_inputs))
+ len(self.inner_non_seqs(self.inner_inputs))
)
if n_outer_ins != n_inner_ins:
......@@ -984,7 +982,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
# them have the same dtype
argoffset = 0
for inner_seq, outer_seq in zip(
self.inner_seqs(self.inputs), self.outer_seqs(inputs)
self.inner_seqs(self.inner_inputs), self.outer_seqs(inputs)
):
check_broadcast(outer_seq, inner_seq)
new_inputs.append(copy_var_format(outer_seq, as_var=inner_seq))
......@@ -996,8 +994,8 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
# - variable representing an output slice of the output
ipos = 0
opos = 0
inner_mitmot = self.inner_mitmot(self.inputs)
inner_mitmot_outs = self.inner_mitmot_outs(self.outputs)
inner_mitmot = self.inner_mitmot(self.inner_inputs)
inner_mitmot_outs = self.inner_mitmot_outs(self.inner_outputs)
for idx, (itaps, otaps, _outer_mitmot) in enumerate(
zip(self.mitmot_taps(), self.mitmot_out_taps(), self.outer_mitmot(inputs))
):
......@@ -1047,12 +1045,12 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
argoffset += len(self.outer_mitmot(inputs))
# Same checks as above but for outputs of type mit_sot
ipos = 0
inner_mitsots = self.inner_mitsot(self.inputs)
inner_mitsots = self.inner_mitsot(self.inner_inputs)
for idx, (itaps, _outer_mitsot, inner_mitsot_out) in enumerate(
zip(
self.mitsot_taps(),
self.outer_mitsot(inputs),
self.inner_mitsot_outs(self.outputs),
self.inner_mitsot_outs(self.inner_outputs),
)
):
outer_mitsot = copy_var_format(_outer_mitsot, as_var=inner_mitsots[ipos])
......@@ -1102,9 +1100,9 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
# Same checks as above but for outputs of type sit_sot
for idx, (inner_sitsot, _outer_sitsot, inner_sitsot_out) in enumerate(
zip(
self.inner_sitsot(self.inputs),
self.inner_sitsot(self.inner_inputs),
self.outer_sitsot(inputs),
self.inner_sitsot_outs(self.outputs),
self.inner_sitsot_outs(self.inner_outputs),
)
):
outer_sitsot = copy_var_format(_outer_sitsot, as_var=inner_sitsot)
......@@ -1149,8 +1147,8 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
# dtype. Maybe even same type ?!
for idx, (inner_shared, inner_shared_out, _outer_shared) in enumerate(
zip(
self.inner_shared(self.inputs),
self.inner_shared_outs(self.outputs),
self.inner_shared(self.inner_inputs),
self.inner_shared_outs(self.inner_outputs),
self.outer_shared(inputs),
)
):
......@@ -1210,7 +1208,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
# type of tensor as the output, it is always a scalar int.
new_inputs += [as_tensor_variable(ons) for ons in self.outer_nitsot(inputs)]
for inner_nonseq, _outer_nonseq in zip(
self.inner_non_seqs(self.inputs), self.outer_non_seqs(inputs)
self.inner_non_seqs(self.inner_inputs), self.outer_non_seqs(inputs)
):
outer_nonseq = copy_var_format(_outer_nonseq, as_var=inner_nonseq)
new_inputs.append(outer_nonseq)
......@@ -1251,7 +1249,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
]
self.vector_outs += [
isinstance(t.type, TensorType) and t.ndim == 0
for t in self.outer_nitsot_outs(self.outputs)
for t in self.outer_nitsot_outs(self.inner_outputs)
]
outputs = [t() for t in self.output_types]
......@@ -1289,18 +1287,21 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
# Compare inner graphs
# TODO: Use `self.inner_fgraph == other.inner_fgraph`
if len(self.inputs) != len(other.inputs):
if len(self.inner_inputs) != len(other.inner_inputs):
return False
if len(self.outputs) != len(other.outputs):
if len(self.inner_outputs) != len(other.inner_outputs):
return False
for self_in, other_in in zip(self.inputs, other.inputs):
for self_in, other_in in zip(self.inner_inputs, other.inner_inputs):
if self_in.type != other_in.type:
return False
return equal_computations(
self.outputs, other.outputs, self.inputs, other.inputs
self.inner_outputs,
other.inner_outputs,
self.inner_inputs,
other.inner_inputs,
)
def __str__(self):
......@@ -1362,15 +1363,15 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
# output variable becomes an update to be performed on it, possibly
# inplace at the end of the functions's execution.
wrapped_inputs = [
In(x, borrow=False) for x in self.inputs[: self.info.n_seqs]
In(x, borrow=False) for x in self.inner_inputs[: self.n_seqs]
]
new_outputs = [x for x in self.outputs]
new_outputs = [x for x in self.inner_outputs]
input_idx = self.info.n_seqs
for mitmot_idx in range(self.info.n_mit_mot):
for inp_tap in self.info.mit_mot_in_slices[mitmot_idx]:
if inp_tap in self.info.mit_mot_out_slices[mitmot_idx]:
inp = self.inputs[input_idx]
inp = self.inner_inputs[input_idx]
# Figure out the index of the corresponding output
output_idx = sum(
......@@ -1392,18 +1393,22 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
wrapped_inp = In(
variable=inp,
value=default_val,
update=self.outputs[output_idx],
update=self.inner_outputs[output_idx],
)
wrapped_inputs.append(wrapped_inp)
else:
# Wrap the corresponding input as usual. Leave the
# output as-is.
wrapped_inputs.append(In(self.inputs[input_idx], borrow=False))
wrapped_inputs.append(
In(self.inner_inputs[input_idx], borrow=False)
)
input_idx += 1
# Wrap the inputs not associated to mitmots and wrap the remaining
# outputs
wrapped_inputs += [In(x, borrow=False) for x in self.inputs[input_idx:]]
wrapped_inputs += [
In(x, borrow=False) for x in self.inner_inputs[input_idx:]
]
wrapped_outputs = [Out(x, borrow=True) for x in new_outputs[:slices]]
wrapped_outputs += new_outputs[slices:]
......@@ -1433,11 +1438,12 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
else:
wrapped_inputs = [In(x, borrow=True) for x in self.inputs]
wrapped_outputs = [Out(x, borrow=False) for x in self.outputs[:slices]]
wrapped_outputs += self.outputs[slices:]
compilation_mode = self.mode_instance
wrapped_inputs = [In(x, borrow=True) for x in self.inner_inputs]
wrapped_outputs = [
Out(x, borrow=False) for x in self.inner_outputs[:slices]
]
wrapped_outputs += self.inner_outputs[slices:]
profile = None
if config.profile or (
......@@ -1463,11 +1469,11 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
@property
def inner_inputs(self):
return self.inputs
return self.fgraph.inputs
@property
def inner_outputs(self):
return self.outputs
return self.fgraph.outputs
def make_thunk(self, node, storage_map, compute_map, no_recycling, impl=None):
"""
......@@ -2201,7 +2207,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
# Here we build 2 variables;
# - A list `inner_ins_shapes`, such that inner_ins_shapes[i] is the
# shape of self.inputs[i]
# shape of self.inner_inputs[i]
# - A dictionary `out_equivalent` containing, for every inner input,
# an equivalent variable computed from the outer inputs.
# NOTE : For non-sequences, this equivalence is trivial. For
......@@ -2225,7 +2231,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
seqs_shape = [x[1:] for x in input_shapes[1 : 1 + info.n_seqs]]
# We disable extra infer_shape for now. See gh-3765.
# if extra_infer_shape:
# inner_seqs = self.inputs[: info.n_seqs]
# inner_seqs = self.inner_inputs[: info.n_seqs]
# outer_seqs = node.inputs[1 : 1 + info.n_seqs]
# for in_s, out_s in zip(inner_seqs, outer_seqs):
# out_equivalent[in_s] = out_s[0]
......@@ -2249,7 +2255,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
# if extra_infer_shape:
# mintap = abs(min(taps))
# corresponding_tap = node.inputs[outer_inp_idx][mintap + k]
# out_equivalent[self.inputs[inner_inp_idx]] = corresponding_tap
# out_equivalent[self.inner_inputs[inner_inp_idx]] = corresponding_tap
outer_inp_idx += 1
# shared_outs
......@@ -2260,26 +2266,28 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
# non_sequences
offset += info.n_nit_sot + info.n_shared_outs
inner_ins_shapes = seqs_shape + outs_shape + input_shapes[offset:]
assert len(inner_ins_shapes) == len(self.inputs)
assert len(inner_ins_shapes) == len(self.inner_inputs)
# Non-sequences have a direct equivalent from self.inputs in
# Non-sequences have a direct equivalent from self.inner_inputs in
# node.inputs
inner_non_sequences = self.inputs[len(seqs_shape) + len(outs_shape) :]
inner_non_sequences = self.inner_inputs[len(seqs_shape) + len(outs_shape) :]
for in_ns, out_ns in zip(inner_non_sequences, node.inputs[offset:]):
out_equivalent[in_ns] = out_ns
if info.as_while:
self_outs = self.outputs[:-1]
self_outs = self.inner_outputs[:-1]
else:
self_outs = self.outputs
self_outs = self.inner_outputs
outs_shape = infer_shape(
outs=self_outs, inputs=self.inputs, input_shapes=inner_ins_shapes
outs=self_outs, inputs=self.inner_inputs, input_shapes=inner_ins_shapes
)
# Will be used to check if outs_shape can be expressed without using
# variables in self.inputs.
# variables in self.inner_inputs.
# The shapes of node.inputs are valid.
validator = Validator(
valid=input_shapes, invalid=self.inputs, valid_equivalent=out_equivalent
valid=input_shapes,
invalid=self.inner_inputs,
valid_equivalent=out_equivalent,
)
offset = 1 + info.n_seqs
......@@ -2337,7 +2345,9 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
return node.tag.connection_pattern
# Obtain the connection pattern of the inner function.
inner_connect_pattern = io_connection_pattern(self.inputs, self.outputs)
inner_connect_pattern = io_connection_pattern(
self.inner_inputs, self.inner_outputs
)
# Initially assume no outer input is connected to any outer output
connection_pattern = [[False for output in node.outputs] for x in node.inputs]
......@@ -2415,8 +2425,8 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
if self.truncate_gradient != -1:
grad_steps = minimum(grad_steps, self.truncate_gradient)
self_inputs = self.inputs
self_outputs = self.outputs
self_inputs = self.inner_inputs
self_outputs = self.inner_outputs
# differentiable inputs
diff_inputs = (
self.inner_seqs(self_inputs)
......@@ -3144,12 +3154,12 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
def R_op(self, inputs, eval_points):
# Step 0. Prepare some shortcut variable
info = self.info
self_inputs = self.inputs
self_inputs = self.inner_inputs
rop_of_inputs = (
self_inputs[: info.n_seqs + self.n_outs]
+ self_inputs[info.n_seqs + self.n_outs + info.n_shared_outs :]
)
self_outputs = self.outputs
self_outputs = self.inner_outputs
# Step 1. Compute the R_op of the inner function
inner_eval_points = [safe_new(x, "_evalpoint") for x in rop_of_inputs]
......
......@@ -87,8 +87,8 @@ def remove_constants_and_unused_inputs_scan(fgraph, node):
st += op.n_sit_sot
st += op.n_shared_outs
op_ins = op.inputs
op_outs = op.outputs
op_ins = op.inner_inputs
op_outs = op.inner_outputs
# Corresponds to the initial states, which should stay untouched.
# We put those variables aside, and put them back at the end.
......@@ -202,7 +202,7 @@ def push_out_non_seq_scan(fgraph, node):
if not isinstance(node.op, Scan):
return False
node_inputs, node_outputs = node.op.inputs, node.op.outputs
node_inputs, node_outputs = node.op.inner_inputs, node.op.inner_outputs
local_fgraph_topo = io_toposort(node_inputs, node_outputs)
local_fgraph_outs_set = set(node_outputs)
......@@ -412,7 +412,7 @@ def push_out_seq_scan(fgraph, node):
if not isinstance(node.op, Scan):
return False
node_inputs, node_outputs = node.op.inputs, node.op.outputs
node_inputs, node_outputs = node.op.inner_inputs, node.op.inner_outputs
local_fgraph_topo = io_toposort(node_inputs, node_outputs)
local_fgraph_outs_set = set(node_outputs)
......@@ -723,8 +723,8 @@ def push_out_inner_vars(
new_scan_args = ScanArgs(
new_scan_node.inputs,
new_scan_node.outputs,
new_scan_node.op.inputs,
new_scan_node.op.outputs,
new_scan_node.op.inner_inputs,
new_scan_node.op.inner_outputs,
new_scan_node.op.info,
)
......@@ -821,7 +821,9 @@ def push_out_add_scan(fgraph, node):
# Use `ScanArgs` to parse the inputs and outputs of scan for ease of
# use
args = ScanArgs(node.inputs, node.outputs, op.inputs, op.outputs, op.info)
args = ScanArgs(
node.inputs, node.outputs, op.inner_inputs, op.inner_outputs, op.info
)
clients = {}
local_fgraph_topo = io_toposort(
......@@ -986,8 +988,8 @@ class ScanInplaceOptimizer(GlobalOptimizer):
typeConstructor = self.typeInfer(node)
new_op = Scan(
op.inputs,
op.outputs,
op.inner_inputs,
op.inner_outputs,
op.info,
mode=op.mode,
typeConstructor=typeConstructor,
......@@ -1656,7 +1658,7 @@ class ScanMerge(GlobalOptimizer):
if nodes[0].op.info.as_while:
as_while = True
condition = nodes[0].op.outputs[-1]
condition = nodes[0].op.inner_outputs[-1]
else:
as_while = False
......@@ -1676,15 +1678,15 @@ class ScanMerge(GlobalOptimizer):
return ls
for idx, nd in enumerate(nodes):
inner_ins[idx].append(rename(nd.op.inner_seqs(nd.op.inputs), idx))
inner_ins[idx].append(rename(nd.op.inner_seqs(nd.op.inner_inputs), idx))
outer_ins += rename(nd.op.outer_seqs(nd.inputs), idx)
mit_mot_out_slices = ()
mit_mot_in_slices = ()
for idx, nd in enumerate(nodes):
inner_ins[idx].append(rename(nd.op.inner_mitmot(nd.op.inputs), idx))
inner_outs[idx].append(nd.op.inner_mitmot_outs(nd.op.outputs))
inner_ins[idx].append(rename(nd.op.inner_mitmot(nd.op.inner_inputs), idx))
inner_outs[idx].append(nd.op.inner_mitmot_outs(nd.op.inner_outputs))
mit_mot_in_slices += nd.op.mitmot_taps()
mit_mot_out_slices += nd.op.mitmot_out_taps()
outer_ins += rename(nd.op.outer_mitmot(nd.inputs), idx)
......@@ -1692,40 +1694,40 @@ class ScanMerge(GlobalOptimizer):
mit_sot_in_slices = ()
for idx, nd in enumerate(nodes):
inner_ins[idx].append(rename(nd.op.inner_mitsot(nd.op.inputs), idx))
inner_outs[idx].append(nd.op.inner_mitsot_outs(nd.op.outputs))
inner_ins[idx].append(rename(nd.op.inner_mitsot(nd.op.inner_inputs), idx))
inner_outs[idx].append(nd.op.inner_mitsot_outs(nd.op.inner_outputs))
mit_sot_in_slices += nd.op.mitsot_taps()
outer_ins += rename(nd.op.outer_mitsot(nd.inputs), idx)
outer_outs += nd.op.outer_mitsot_outs(nd.outputs)
sit_sot_in_slices = ()
for idx, nd in enumerate(nodes):
inner_ins[idx].append(rename(nd.op.inner_sitsot(nd.op.inputs), idx))
inner_ins[idx].append(rename(nd.op.inner_sitsot(nd.op.inner_inputs), idx))
sit_sot_in_slices += tuple((-1,) for x in range(nd.op.n_sit_sot))
inner_outs[idx].append(nd.op.inner_sitsot_outs(nd.op.outputs))
inner_outs[idx].append(nd.op.inner_sitsot_outs(nd.op.inner_outputs))
outer_ins += rename(nd.op.outer_sitsot(nd.inputs), idx)
outer_outs += nd.op.outer_sitsot_outs(nd.outputs)
for idx, nd in enumerate(nodes):
# Shared
inner_ins[idx].append(rename(nd.op.inner_shared(nd.op.inputs), idx))
inner_ins[idx].append(rename(nd.op.inner_shared(nd.op.inner_inputs), idx))
outer_ins += rename(nd.op.outer_shared(nd.inputs), idx)
for idx, nd in enumerate(nodes):
# NitSot
inner_outs[idx].append(nd.op.inner_nitsot_outs(nd.op.outputs))
inner_outs[idx].append(nd.op.inner_nitsot_outs(nd.op.inner_outputs))
outer_ins += rename(nd.op.outer_nitsot(nd.inputs), idx)
outer_outs += nd.op.outer_nitsot_outs(nd.outputs)
for idx, nd in enumerate(nodes):
# Shared
outer_outs += nd.op.outer_shared_outs(nd.outputs)
inner_outs[idx].append(nd.op.inner_shared_outs(nd.op.outputs))
inner_outs[idx].append(nd.op.inner_shared_outs(nd.op.inner_outputs))
n_non_seqs = 0
for idx, nd in enumerate(nodes):
# Non Seqs
node_inner_non_seqs = nd.op.inner_non_seqs(nd.op.inputs)
node_inner_non_seqs = nd.op.inner_non_seqs(nd.op.inner_inputs)
n_non_seqs += len(node_inner_non_seqs)
inner_ins[idx].append(rename(node_inner_non_seqs, idx))
outer_ins += rename(nd.op.outer_non_seqs(nd.inputs), idx)
......@@ -1863,9 +1865,11 @@ class ScanMerge(GlobalOptimizer):
if not node.op.info.as_while:
return True
cond = node.op.outputs[-1]
rep_cond = rep.op.outputs[-1]
return equal_computations([cond], [rep_cond], node.op.inputs, rep.op.inputs)
cond = node.op.inner_outputs[-1]
rep_cond = rep.op.inner_outputs[-1]
return equal_computations(
[cond], [rep_cond], node.op.inner_inputs, rep.op.inner_inputs
)
def apply(self, fgraph):
# Collect all scan nodes ordered according to toposort
......@@ -1943,8 +1947,8 @@ def scan_merge_inouts(fgraph, node):
a = ScanArgs(
node.inputs,
node.outputs,
node.op.inputs,
node.op.outputs,
node.op.inner_inputs,
node.op.inner_outputs,
node.op.info,
)
......@@ -2003,8 +2007,8 @@ def scan_merge_inouts(fgraph, node):
na = ScanArgs(
outer_inputs,
outputs,
new_op.inputs,
new_op.outputs,
new_op.inner_inputs,
new_op.inner_outputs,
new_op.info,
)
remove = [node]
......@@ -2146,10 +2150,10 @@ def push_out_dot1_scan(fgraph, node):
# Note that this works when only you need X[-1] in the end
# and assumes dimshuffle are applied to vectors before calling dot
op = node.op
sitsot_ins = op.inner_sitsot(op.inputs)
sitsot_outs = op.inner_sitsot_outs(op.outputs)
sitsot_ins = op.inner_sitsot(op.inner_inputs)
sitsot_outs = op.inner_sitsot_outs(op.inner_outputs)
outer_sitsot = op.outer_sitsot_outs(node.outputs)
seqs = op.inner_seqs(op.inputs)
seqs = op.inner_seqs(op.inner_inputs)
for inp, out, outer_out in zip(sitsot_ins, sitsot_outs, outer_sitsot):
if (
......@@ -2191,23 +2195,23 @@ def push_out_dot1_scan(fgraph, node):
# First let us split all arguments according to their
# corresponding categories
inner_seqs = op.inner_seqs(op.inputs)
inner_seqs = op.inner_seqs(op.inner_inputs)
outer_seqs = op.outer_seqs(node.inputs)
inner_mitmot = op.inner_mitmot(op.inputs)
inner_mitmot = op.inner_mitmot(op.inner_inputs)
outer_mitmot = op.outer_mitmot(node.inputs)
inner_mitmot_outs = op.inner_mitmot_outs(op.outputs)
inner_mitsot = op.inner_mitsot(op.inputs)
inner_mitmot_outs = op.inner_mitmot_outs(op.inner_outputs)
inner_mitsot = op.inner_mitsot(op.inner_inputs)
outer_mitsot = op.outer_mitsot(node.inputs)
inner_mitsot_outs = op.inner_mitsot_outs(op.outputs)
inner_sitsot = op.inner_sitsot(op.inputs)
inner_mitsot_outs = op.inner_mitsot_outs(op.inner_outputs)
inner_sitsot = op.inner_sitsot(op.inner_inputs)
outer_sitsot = op.outer_sitsot(node.inputs)
inner_sitsot_outs = op.inner_sitsot_outs(op.outputs)
inner_sitsot_outs = op.inner_sitsot_outs(op.inner_outputs)
outer_nitsot = op.outer_nitsot(node.inputs)
inner_nitsot_outs = op.inner_nitsot_outs(op.outputs)
inner_shared = op.inner_shared(op.inputs)
inner_nitsot_outs = op.inner_nitsot_outs(op.inner_outputs)
inner_shared = op.inner_shared(op.inner_inputs)
outer_shared = op.outer_shared(node.inputs)
inner_shared_outs = op.inner_shared_outs(op.outputs)
inner_non_seqs = op.inner_non_seqs(op.inputs)
inner_shared_outs = op.inner_shared_outs(op.inner_outputs)
inner_non_seqs = op.inner_non_seqs(op.inner_inputs)
outer_non_seqs = op.outer_non_seqs(node.inputs)
new_info = dataclasses.replace(
......
......@@ -357,7 +357,7 @@ def scan_can_remove_outs(op, out_idxs):
second with the outputs that can not be removed.
"""
non_removable = [o for i, o in enumerate(op.outputs) if i not in out_idxs]
non_removable = [o for i, o in enumerate(op.inner_outputs) if i not in out_idxs]
required_inputs = list(graph_inputs(non_removable))
out_ins = []
......@@ -366,10 +366,10 @@ def scan_can_remove_outs(op, out_idxs):
chain(op.mit_mot_in_slices, op.mit_sot_in_slices, op.sit_sot_in_slices)
):
n_ins = len(tap)
out_ins += [op.inputs[offset : offset + n_ins]]
out_ins += [op.inner_inputs[offset : offset + n_ins]]
offset += n_ins
out_ins += [[] for k in range(op.n_nit_sot)]
out_ins += [[op.inputs[offset + k]] for k in range(op.n_shared_outs)]
out_ins += [[op.inner_inputs[offset + k]] for k in range(op.n_shared_outs)]
added = True
out_idxs_mask = [1 for idx in out_idxs]
......@@ -379,7 +379,7 @@ def scan_can_remove_outs(op, out_idxs):
if out_idxs_mask[pos] and any(x in required_inputs for x in out_ins[idx]):
# This output is required ..
out_idxs_mask[pos] = 0
required_inputs += list(graph_inputs([op.outputs[idx]]))
required_inputs += list(graph_inputs([op.inner_outputs[idx]]))
added = True
required_outs = [x for i, x in enumerate(out_idxs) if out_idxs_mask[i] == 0]
......@@ -412,7 +412,7 @@ def compress_outs(op, not_required, inputs):
as_while=op.info.as_while,
)
op_inputs = op.inputs[: op.n_seqs]
op_inputs = op.inner_inputs[: op.n_seqs]
op_outputs = []
node_inputs = inputs[: op.n_seqs + 1]
map_old_new = OrderedDict()
......@@ -434,11 +434,11 @@ def compress_outs(op, not_required, inputs):
)
# input taps
for jdx in op.mit_mot_in_slices[idx]:
op_inputs += [op.inputs[i_offset]]
op_inputs += [op.inner_inputs[i_offset]]
i_offset += 1
# output taps
for jdx in op.mit_mot_out_slices[idx]:
op_outputs += [op.outputs[o_offset]]
op_outputs += [op.inner_outputs[o_offset]]
o_offset += 1
# node inputs
node_inputs += [inputs[ni_offset + idx]]
......@@ -459,10 +459,10 @@ def compress_outs(op, not_required, inputs):
)
# input taps
for jdx in op.mit_sot_in_slices[idx]:
op_inputs += [op.inputs[i_offset]]
op_inputs += [op.inner_inputs[i_offset]]
i_offset += 1
# output taps
op_outputs += [op.outputs[o_offset]]
op_outputs += [op.inner_outputs[o_offset]]
o_offset += 1
# node inputs
node_inputs += [inputs[ni_offset + idx]]
......@@ -481,10 +481,10 @@ def compress_outs(op, not_required, inputs):
sit_sot_in_slices=info.sit_sot_in_slices + (op.sit_sot_in_slices[idx],),
)
# input taps
op_inputs += [op.inputs[i_offset]]
op_inputs += [op.inner_inputs[i_offset]]
i_offset += 1
# output taps
op_outputs += [op.outputs[o_offset]]
op_outputs += [op.inner_outputs[o_offset]]
o_offset += 1
# node inputs
node_inputs += [inputs[ni_offset + idx]]
......@@ -500,7 +500,7 @@ def compress_outs(op, not_required, inputs):
map_old_new[offset + idx] = curr_pos
curr_pos += 1
info = dataclasses.replace(info, n_nit_sot=info.n_nit_sot + 1)
op_outputs += [op.outputs[o_offset]]
op_outputs += [op.inner_outputs[o_offset]]
o_offset += 1
nit_sot_ins += [inputs[ni_offset + idx + op.n_shared_outs]]
else:
......@@ -513,9 +513,9 @@ def compress_outs(op, not_required, inputs):
map_old_new[offset + idx] = curr_pos
curr_pos += 1
info = dataclasses.replace(info, n_shared_outs=info.n_shared_outs + 1)
op_outputs += [op.outputs[o_offset]]
op_outputs += [op.inner_outputs[o_offset]]
o_offset += 1
op_inputs += [op.inputs[i_offset]]
op_inputs += [op.inner_inputs[i_offset]]
i_offset += 1
shared_ins += [inputs[ni_offset + idx]]
else:
......@@ -524,11 +524,11 @@ def compress_outs(op, not_required, inputs):
node_inputs += shared_ins
node_inputs += nit_sot_ins
# other stuff
op_inputs += op.inputs[i_offset:]
info = dataclasses.replace(info, n_non_seqs=len(op.inputs[i_offset:]))
op_inputs += op.inner_inputs[i_offset:]
info = dataclasses.replace(info, n_non_seqs=len(op.inner_inputs[i_offset:]))
node_inputs += inputs[ni_offset + op.n_shared_outs + op.n_nit_sot :]
if op.info.as_while:
op_outputs += [op.outputs[o_offset]]
op_outputs += [op.inner_outputs[o_offset]]
map_old_new[o_offset] = len(op_outputs) - 1
# map_old_new[len(op_outputs)-1] = o_offset
......@@ -706,8 +706,8 @@ class ScanArgs:
return ScanArgs(
node.inputs,
node.outputs,
node.op.inputs,
node.op.outputs,
node.op.inner_inputs,
node.op.inner_outputs,
node.op.info,
clone=clone,
)
......
......@@ -2594,7 +2594,7 @@ def test_inner_get_vector_length():
# Make sure the `size` in `scan_body` is a plain `Variable` instance
# carrying no information with which we can derive its length
size_clone = res.owner.op.inputs[1]
size_clone = res.owner.op.inner_inputs[1]
assert size_clone.owner is None
# Make sure the cloned `size` maps to the original `size_at`
......@@ -2696,9 +2696,6 @@ def test_profile_info():
assert fn.fn.call_counts == [0]
c = scalar("c", dtype="floatX")
class TestExamples:
"""Miscellaneous example-based tests with unnecessarily complicated setups and/or no background information.
......@@ -4021,9 +4018,6 @@ class TestExamples:
self._grad_mout_helper(1, None)
c = scalar("c", dtype="floatX")
@pytest.mark.parametrize(
"fn, sequences, outputs_info, non_sequences, n_steps, op_check",
[
......@@ -4050,7 +4044,7 @@ c = scalar("c", dtype="floatX")
lambda c: at.as_tensor(2.0) * c,
[],
[{}],
[c],
[scalar("c", dtype="floatX")],
3,
lambda op: op.info.n_nit_sot > 0 and op.info.n_non_seqs > 0,
),
......@@ -4105,7 +4099,7 @@ c = scalar("c", dtype="floatX")
# TODO: mit-mot (can't be created through the `scan` interface)
],
)
def test_n_non_seqs(fn, sequences, outputs_info, non_sequences, n_steps, op_check):
def test_ScanInfo_totals(fn, sequences, outputs_info, non_sequences, n_steps, op_check):
res, _ = scan(
fn,
sequences=sequences,
......@@ -4124,13 +4118,9 @@ def test_n_non_seqs(fn, sequences, outputs_info, non_sequences, n_steps, op_chec
scan_op = res.owner.op
assert isinstance(scan_op, Scan)
# from aesara.scan.utils import ScanArgs
# print(ScanArgs.from_node(res.owner))
# print(res.owner.op.info)
_ = op_check(scan_op)
assert scan_op.info.n_outer_inputs == len(res.owner.inputs)
assert scan_op.info.n_outer_outputs == len(res.owner.outputs)
assert scan_op.info.n_inner_inputs == len(res.owner.op.inputs)
assert scan_op.info.n_inner_outputs == len(res.owner.op.outputs)
assert scan_op.info.n_inner_inputs == len(res.owner.op.inner_inputs)
assert scan_op.info.n_inner_outputs == len(res.owner.op.inner_outputs)
......@@ -95,7 +95,7 @@ class TestRemoveConstantsAndUnusedInputsScan:
scan_node = scan_nodes[0]
assert len(scan_node.inputs[1:]) == len(set(scan_node.inputs[1:]))
inp = scan_node.op.inner_non_seqs(scan_node.op.inputs)
inp = scan_node.op.inner_non_seqs(scan_node.op.inner_inputs)
assert len(inp) == 1
assert len(inp) == len(set(inp))
......@@ -170,11 +170,11 @@ class TestRemoveConstantsAndUnusedInputsScan:
scan_node = scan_nodes[0]
assert len(scan_node.inputs[1:]) == len(set(scan_node.inputs[1:]))
inp = scan_node.op.inner_seqs(scan_node.op.inputs)
inp = scan_node.op.inner_seqs(scan_node.op.inner_inputs)
assert len(inp) == 1
inp = scan_node.op.outer_seqs(scan_node.inputs)
assert len(inp) == 1
inp = scan_node.op.inner_non_seqs(scan_node.op.inputs)
inp = scan_node.op.inner_non_seqs(scan_node.op.inner_inputs)
assert len(inp) == 1
inp = scan_node.op.outer_non_seqs(scan_node.inputs)
assert len(inp) == 1
......@@ -445,8 +445,8 @@ class TestPushOutNonSeqScan:
scan_node = [
node for node in f_opt.maker.fgraph.toposort() if isinstance(node.op, Scan)
][0]
assert len(scan_node.op.outputs) == 1
assert not isinstance(scan_node.op.outputs[0], Dot)
assert len(scan_node.op.inner_outputs) == 1
assert not isinstance(scan_node.op.inner_outputs[0], Dot)
# Ensure that the function compiled with the optimization produces
# the same results as the function compiled without
......@@ -491,8 +491,8 @@ class TestPushOutNonSeqScan:
][0]
# NOTE: WHEN INFER_SHAPE IS RE-ENABLED, BELOW THE SCAN MUST
# HAVE ONLY 1 OUTPUT.
assert len(scan_node.op.outputs) == 2
assert not isinstance(scan_node.op.outputs[0], Dot)
assert len(scan_node.op.inner_outputs) == 2
assert not isinstance(scan_node.op.inner_outputs[0], Dot)
# Ensure that the function compiled with the optimization produces
# the same results as the function compiled without
......@@ -537,8 +537,8 @@ class TestPushOutNonSeqScan:
scan_node = [
node for node in f_opt.maker.fgraph.toposort() if isinstance(node.op, Scan)
][0]
assert len(scan_node.op.outputs) == 2
assert not isinstance(scan_node.op.outputs[0], Dot)
assert len(scan_node.op.inner_outputs) == 2
assert not isinstance(scan_node.op.inner_outputs[0], Dot)
# Ensure that the function compiled with the optimization produces
# the same results as the function compiled without
......@@ -725,7 +725,7 @@ class TestPushOutAddScan:
node for node in f_opt.maker.fgraph.toposort() if isinstance(node.op, Scan)
][1]
for output in scan_node_grad.op.outputs:
for output in scan_node_grad.op.inner_outputs:
assert not (
isinstance(output.owner.op, Elemwise)
and any(isinstance(i, Dot) for i in output.owner.inputs)
......@@ -1466,7 +1466,7 @@ def test_alloc_inputs3():
f = function([_h0, _W1, _W2], o, mode="FAST_RUN")
scan_node = [x for x in f.maker.fgraph.toposort() if isinstance(x.op, Scan)][0]
assert len(scan_node.op.inputs) == 1
assert len(scan_node.op.inner_inputs) == 1
def test_opt_order():
......
......@@ -101,18 +101,18 @@ def test_ScanArgs():
assert scan_args.n_mit_mot == scan_op.n_mit_mot
# The `scan_args` base class always clones the inner-graph;
# here we make sure it doesn't (and that all the inputs are the same)
assert scan_args.inputs == scan_op.inputs
assert scan_args.inputs == scan_op.inner_inputs
assert scan_args.info == scan_op.info
# Check that `ScanArgs.find_among_fields` works
test_v = scan_op.inner_seqs(scan_op.inputs)[1]
test_v = scan_op.inner_seqs(scan_op.inner_inputs)[1]
field_info = scan_args.find_among_fields(test_v)
assert field_info.name == "inner_in_seqs"
assert field_info.index == 1
assert field_info.inner_index is None
assert scan_args.inner_inputs[field_info.agg_index] == test_v
test_l = scan_op.inner_non_seqs(scan_op.inputs)
test_l = scan_op.inner_non_seqs(scan_op.inner_inputs)
# We didn't index this argument, so it's a `list` (i.e. bad input)
field_info = scan_args.find_among_fields(test_l)
assert field_info is None
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论