提交 3b0e071e authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Remove Scan.inputs and Scan.outputs

上级 b979dd6e
......@@ -34,7 +34,7 @@ def array0d_range(x):
@numba_funcify.register(Scan)
def numba_funcify_Scan(op, node, **kwargs):
inner_fg = FunctionGraph(op.inputs, op.outputs)
inner_fg = FunctionGraph(op.inner_inputs, op.inner_outputs)
numba_at_inner_func = numba_basic.numba_njit(numba_funcify(inner_fg, **kwargs))
n_seqs = op.info.n_seqs
......
......@@ -1537,7 +1537,7 @@ def pydotprint(
if hasattr(scan_op.op, "_fn"):
to_print = scan_op.op.fn
else:
to_print = scan_op.op.outputs
to_print = scan_op.op.inner_outputs
pydotprint(
to_print,
new_name,
......
差异被折叠。
......@@ -87,8 +87,8 @@ def remove_constants_and_unused_inputs_scan(fgraph, node):
st += op.n_sit_sot
st += op.n_shared_outs
op_ins = op.inputs
op_outs = op.outputs
op_ins = op.inner_inputs
op_outs = op.inner_outputs
# Corresponds to the initial states, which should stay untouched.
# We put those variables aside, and put them back at the end.
......@@ -202,7 +202,7 @@ def push_out_non_seq_scan(fgraph, node):
if not isinstance(node.op, Scan):
return False
node_inputs, node_outputs = node.op.inputs, node.op.outputs
node_inputs, node_outputs = node.op.inner_inputs, node.op.inner_outputs
local_fgraph_topo = io_toposort(node_inputs, node_outputs)
local_fgraph_outs_set = set(node_outputs)
......@@ -412,7 +412,7 @@ def push_out_seq_scan(fgraph, node):
if not isinstance(node.op, Scan):
return False
node_inputs, node_outputs = node.op.inputs, node.op.outputs
node_inputs, node_outputs = node.op.inner_inputs, node.op.inner_outputs
local_fgraph_topo = io_toposort(node_inputs, node_outputs)
local_fgraph_outs_set = set(node_outputs)
......@@ -723,8 +723,8 @@ def push_out_inner_vars(
new_scan_args = ScanArgs(
new_scan_node.inputs,
new_scan_node.outputs,
new_scan_node.op.inputs,
new_scan_node.op.outputs,
new_scan_node.op.inner_inputs,
new_scan_node.op.inner_outputs,
new_scan_node.op.info,
)
......@@ -821,7 +821,9 @@ def push_out_add_scan(fgraph, node):
# Use `ScanArgs` to parse the inputs and outputs of scan for ease of
# use
args = ScanArgs(node.inputs, node.outputs, op.inputs, op.outputs, op.info)
args = ScanArgs(
node.inputs, node.outputs, op.inner_inputs, op.inner_outputs, op.info
)
clients = {}
local_fgraph_topo = io_toposort(
......@@ -986,8 +988,8 @@ class ScanInplaceOptimizer(GlobalOptimizer):
typeConstructor = self.typeInfer(node)
new_op = Scan(
op.inputs,
op.outputs,
op.inner_inputs,
op.inner_outputs,
op.info,
mode=op.mode,
typeConstructor=typeConstructor,
......@@ -1656,7 +1658,7 @@ class ScanMerge(GlobalOptimizer):
if nodes[0].op.info.as_while:
as_while = True
condition = nodes[0].op.outputs[-1]
condition = nodes[0].op.inner_outputs[-1]
else:
as_while = False
......@@ -1676,15 +1678,15 @@ class ScanMerge(GlobalOptimizer):
return ls
for idx, nd in enumerate(nodes):
inner_ins[idx].append(rename(nd.op.inner_seqs(nd.op.inputs), idx))
inner_ins[idx].append(rename(nd.op.inner_seqs(nd.op.inner_inputs), idx))
outer_ins += rename(nd.op.outer_seqs(nd.inputs), idx)
mit_mot_out_slices = ()
mit_mot_in_slices = ()
for idx, nd in enumerate(nodes):
inner_ins[idx].append(rename(nd.op.inner_mitmot(nd.op.inputs), idx))
inner_outs[idx].append(nd.op.inner_mitmot_outs(nd.op.outputs))
inner_ins[idx].append(rename(nd.op.inner_mitmot(nd.op.inner_inputs), idx))
inner_outs[idx].append(nd.op.inner_mitmot_outs(nd.op.inner_outputs))
mit_mot_in_slices += nd.op.mitmot_taps()
mit_mot_out_slices += nd.op.mitmot_out_taps()
outer_ins += rename(nd.op.outer_mitmot(nd.inputs), idx)
......@@ -1692,40 +1694,40 @@ class ScanMerge(GlobalOptimizer):
mit_sot_in_slices = ()
for idx, nd in enumerate(nodes):
inner_ins[idx].append(rename(nd.op.inner_mitsot(nd.op.inputs), idx))
inner_outs[idx].append(nd.op.inner_mitsot_outs(nd.op.outputs))
inner_ins[idx].append(rename(nd.op.inner_mitsot(nd.op.inner_inputs), idx))
inner_outs[idx].append(nd.op.inner_mitsot_outs(nd.op.inner_outputs))
mit_sot_in_slices += nd.op.mitsot_taps()
outer_ins += rename(nd.op.outer_mitsot(nd.inputs), idx)
outer_outs += nd.op.outer_mitsot_outs(nd.outputs)
sit_sot_in_slices = ()
for idx, nd in enumerate(nodes):
inner_ins[idx].append(rename(nd.op.inner_sitsot(nd.op.inputs), idx))
inner_ins[idx].append(rename(nd.op.inner_sitsot(nd.op.inner_inputs), idx))
sit_sot_in_slices += tuple((-1,) for x in range(nd.op.n_sit_sot))
inner_outs[idx].append(nd.op.inner_sitsot_outs(nd.op.outputs))
inner_outs[idx].append(nd.op.inner_sitsot_outs(nd.op.inner_outputs))
outer_ins += rename(nd.op.outer_sitsot(nd.inputs), idx)
outer_outs += nd.op.outer_sitsot_outs(nd.outputs)
for idx, nd in enumerate(nodes):
# Shared
inner_ins[idx].append(rename(nd.op.inner_shared(nd.op.inputs), idx))
inner_ins[idx].append(rename(nd.op.inner_shared(nd.op.inner_inputs), idx))
outer_ins += rename(nd.op.outer_shared(nd.inputs), idx)
for idx, nd in enumerate(nodes):
# NitSot
inner_outs[idx].append(nd.op.inner_nitsot_outs(nd.op.outputs))
inner_outs[idx].append(nd.op.inner_nitsot_outs(nd.op.inner_outputs))
outer_ins += rename(nd.op.outer_nitsot(nd.inputs), idx)
outer_outs += nd.op.outer_nitsot_outs(nd.outputs)
for idx, nd in enumerate(nodes):
# Shared
outer_outs += nd.op.outer_shared_outs(nd.outputs)
inner_outs[idx].append(nd.op.inner_shared_outs(nd.op.outputs))
inner_outs[idx].append(nd.op.inner_shared_outs(nd.op.inner_outputs))
n_non_seqs = 0
for idx, nd in enumerate(nodes):
# Non Seqs
node_inner_non_seqs = nd.op.inner_non_seqs(nd.op.inputs)
node_inner_non_seqs = nd.op.inner_non_seqs(nd.op.inner_inputs)
n_non_seqs += len(node_inner_non_seqs)
inner_ins[idx].append(rename(node_inner_non_seqs, idx))
outer_ins += rename(nd.op.outer_non_seqs(nd.inputs), idx)
......@@ -1863,9 +1865,11 @@ class ScanMerge(GlobalOptimizer):
if not node.op.info.as_while:
return True
cond = node.op.outputs[-1]
rep_cond = rep.op.outputs[-1]
return equal_computations([cond], [rep_cond], node.op.inputs, rep.op.inputs)
cond = node.op.inner_outputs[-1]
rep_cond = rep.op.inner_outputs[-1]
return equal_computations(
[cond], [rep_cond], node.op.inner_inputs, rep.op.inner_inputs
)
def apply(self, fgraph):
# Collect all scan nodes ordered according to toposort
......@@ -1943,8 +1947,8 @@ def scan_merge_inouts(fgraph, node):
a = ScanArgs(
node.inputs,
node.outputs,
node.op.inputs,
node.op.outputs,
node.op.inner_inputs,
node.op.inner_outputs,
node.op.info,
)
......@@ -2003,8 +2007,8 @@ def scan_merge_inouts(fgraph, node):
na = ScanArgs(
outer_inputs,
outputs,
new_op.inputs,
new_op.outputs,
new_op.inner_inputs,
new_op.inner_outputs,
new_op.info,
)
remove = [node]
......@@ -2146,10 +2150,10 @@ def push_out_dot1_scan(fgraph, node):
# Note that this works when only you need X[-1] in the end
# and assumes dimshuffle are applied to vectors before calling dot
op = node.op
sitsot_ins = op.inner_sitsot(op.inputs)
sitsot_outs = op.inner_sitsot_outs(op.outputs)
sitsot_ins = op.inner_sitsot(op.inner_inputs)
sitsot_outs = op.inner_sitsot_outs(op.inner_outputs)
outer_sitsot = op.outer_sitsot_outs(node.outputs)
seqs = op.inner_seqs(op.inputs)
seqs = op.inner_seqs(op.inner_inputs)
for inp, out, outer_out in zip(sitsot_ins, sitsot_outs, outer_sitsot):
if (
......@@ -2191,23 +2195,23 @@ def push_out_dot1_scan(fgraph, node):
# First let us split all arguments according to their
# corresponding categories
inner_seqs = op.inner_seqs(op.inputs)
inner_seqs = op.inner_seqs(op.inner_inputs)
outer_seqs = op.outer_seqs(node.inputs)
inner_mitmot = op.inner_mitmot(op.inputs)
inner_mitmot = op.inner_mitmot(op.inner_inputs)
outer_mitmot = op.outer_mitmot(node.inputs)
inner_mitmot_outs = op.inner_mitmot_outs(op.outputs)
inner_mitsot = op.inner_mitsot(op.inputs)
inner_mitmot_outs = op.inner_mitmot_outs(op.inner_outputs)
inner_mitsot = op.inner_mitsot(op.inner_inputs)
outer_mitsot = op.outer_mitsot(node.inputs)
inner_mitsot_outs = op.inner_mitsot_outs(op.outputs)
inner_sitsot = op.inner_sitsot(op.inputs)
inner_mitsot_outs = op.inner_mitsot_outs(op.inner_outputs)
inner_sitsot = op.inner_sitsot(op.inner_inputs)
outer_sitsot = op.outer_sitsot(node.inputs)
inner_sitsot_outs = op.inner_sitsot_outs(op.outputs)
inner_sitsot_outs = op.inner_sitsot_outs(op.inner_outputs)
outer_nitsot = op.outer_nitsot(node.inputs)
inner_nitsot_outs = op.inner_nitsot_outs(op.outputs)
inner_shared = op.inner_shared(op.inputs)
inner_nitsot_outs = op.inner_nitsot_outs(op.inner_outputs)
inner_shared = op.inner_shared(op.inner_inputs)
outer_shared = op.outer_shared(node.inputs)
inner_shared_outs = op.inner_shared_outs(op.outputs)
inner_non_seqs = op.inner_non_seqs(op.inputs)
inner_shared_outs = op.inner_shared_outs(op.inner_outputs)
inner_non_seqs = op.inner_non_seqs(op.inner_inputs)
outer_non_seqs = op.outer_non_seqs(node.inputs)
new_info = dataclasses.replace(
......
......@@ -357,7 +357,7 @@ def scan_can_remove_outs(op, out_idxs):
second with the outputs that can not be removed.
"""
non_removable = [o for i, o in enumerate(op.outputs) if i not in out_idxs]
non_removable = [o for i, o in enumerate(op.inner_outputs) if i not in out_idxs]
required_inputs = list(graph_inputs(non_removable))
out_ins = []
......@@ -366,10 +366,10 @@ def scan_can_remove_outs(op, out_idxs):
chain(op.mit_mot_in_slices, op.mit_sot_in_slices, op.sit_sot_in_slices)
):
n_ins = len(tap)
out_ins += [op.inputs[offset : offset + n_ins]]
out_ins += [op.inner_inputs[offset : offset + n_ins]]
offset += n_ins
out_ins += [[] for k in range(op.n_nit_sot)]
out_ins += [[op.inputs[offset + k]] for k in range(op.n_shared_outs)]
out_ins += [[op.inner_inputs[offset + k]] for k in range(op.n_shared_outs)]
added = True
out_idxs_mask = [1 for idx in out_idxs]
......@@ -379,7 +379,7 @@ def scan_can_remove_outs(op, out_idxs):
if out_idxs_mask[pos] and any(x in required_inputs for x in out_ins[idx]):
# This output is required ..
out_idxs_mask[pos] = 0
required_inputs += list(graph_inputs([op.outputs[idx]]))
required_inputs += list(graph_inputs([op.inner_outputs[idx]]))
added = True
required_outs = [x for i, x in enumerate(out_idxs) if out_idxs_mask[i] == 0]
......@@ -412,7 +412,7 @@ def compress_outs(op, not_required, inputs):
as_while=op.info.as_while,
)
op_inputs = op.inputs[: op.n_seqs]
op_inputs = op.inner_inputs[: op.n_seqs]
op_outputs = []
node_inputs = inputs[: op.n_seqs + 1]
map_old_new = OrderedDict()
......@@ -434,11 +434,11 @@ def compress_outs(op, not_required, inputs):
)
# input taps
for jdx in op.mit_mot_in_slices[idx]:
op_inputs += [op.inputs[i_offset]]
op_inputs += [op.inner_inputs[i_offset]]
i_offset += 1
# output taps
for jdx in op.mit_mot_out_slices[idx]:
op_outputs += [op.outputs[o_offset]]
op_outputs += [op.inner_outputs[o_offset]]
o_offset += 1
# node inputs
node_inputs += [inputs[ni_offset + idx]]
......@@ -459,10 +459,10 @@ def compress_outs(op, not_required, inputs):
)
# input taps
for jdx in op.mit_sot_in_slices[idx]:
op_inputs += [op.inputs[i_offset]]
op_inputs += [op.inner_inputs[i_offset]]
i_offset += 1
# output taps
op_outputs += [op.outputs[o_offset]]
op_outputs += [op.inner_outputs[o_offset]]
o_offset += 1
# node inputs
node_inputs += [inputs[ni_offset + idx]]
......@@ -481,10 +481,10 @@ def compress_outs(op, not_required, inputs):
sit_sot_in_slices=info.sit_sot_in_slices + (op.sit_sot_in_slices[idx],),
)
# input taps
op_inputs += [op.inputs[i_offset]]
op_inputs += [op.inner_inputs[i_offset]]
i_offset += 1
# output taps
op_outputs += [op.outputs[o_offset]]
op_outputs += [op.inner_outputs[o_offset]]
o_offset += 1
# node inputs
node_inputs += [inputs[ni_offset + idx]]
......@@ -500,7 +500,7 @@ def compress_outs(op, not_required, inputs):
map_old_new[offset + idx] = curr_pos
curr_pos += 1
info = dataclasses.replace(info, n_nit_sot=info.n_nit_sot + 1)
op_outputs += [op.outputs[o_offset]]
op_outputs += [op.inner_outputs[o_offset]]
o_offset += 1
nit_sot_ins += [inputs[ni_offset + idx + op.n_shared_outs]]
else:
......@@ -513,9 +513,9 @@ def compress_outs(op, not_required, inputs):
map_old_new[offset + idx] = curr_pos
curr_pos += 1
info = dataclasses.replace(info, n_shared_outs=info.n_shared_outs + 1)
op_outputs += [op.outputs[o_offset]]
op_outputs += [op.inner_outputs[o_offset]]
o_offset += 1
op_inputs += [op.inputs[i_offset]]
op_inputs += [op.inner_inputs[i_offset]]
i_offset += 1
shared_ins += [inputs[ni_offset + idx]]
else:
......@@ -524,11 +524,11 @@ def compress_outs(op, not_required, inputs):
node_inputs += shared_ins
node_inputs += nit_sot_ins
# other stuff
op_inputs += op.inputs[i_offset:]
info = dataclasses.replace(info, n_non_seqs=len(op.inputs[i_offset:]))
op_inputs += op.inner_inputs[i_offset:]
info = dataclasses.replace(info, n_non_seqs=len(op.inner_inputs[i_offset:]))
node_inputs += inputs[ni_offset + op.n_shared_outs + op.n_nit_sot :]
if op.info.as_while:
op_outputs += [op.outputs[o_offset]]
op_outputs += [op.inner_outputs[o_offset]]
map_old_new[o_offset] = len(op_outputs) - 1
# map_old_new[len(op_outputs)-1] = o_offset
......@@ -706,8 +706,8 @@ class ScanArgs:
return ScanArgs(
node.inputs,
node.outputs,
node.op.inputs,
node.op.outputs,
node.op.inner_inputs,
node.op.inner_outputs,
node.op.info,
clone=clone,
)
......
......@@ -2594,7 +2594,7 @@ def test_inner_get_vector_length():
# Make sure the `size` in `scan_body` is a plain `Variable` instance
# carrying no information with which we can derive its length
size_clone = res.owner.op.inputs[1]
size_clone = res.owner.op.inner_inputs[1]
assert size_clone.owner is None
# Make sure the cloned `size` maps to the original `size_at`
......@@ -2696,9 +2696,6 @@ def test_profile_info():
assert fn.fn.call_counts == [0]
c = scalar("c", dtype="floatX")
class TestExamples:
"""Miscellaneous example-based tests with unnecessarily complicated setups and/or no background information.
......@@ -4021,9 +4018,6 @@ class TestExamples:
self._grad_mout_helper(1, None)
c = scalar("c", dtype="floatX")
@pytest.mark.parametrize(
"fn, sequences, outputs_info, non_sequences, n_steps, op_check",
[
......@@ -4050,7 +4044,7 @@ c = scalar("c", dtype="floatX")
lambda c: at.as_tensor(2.0) * c,
[],
[{}],
[c],
[scalar("c", dtype="floatX")],
3,
lambda op: op.info.n_nit_sot > 0 and op.info.n_non_seqs > 0,
),
......@@ -4105,7 +4099,7 @@ c = scalar("c", dtype="floatX")
# TODO: mit-mot (can't be created through the `scan` interface)
],
)
def test_n_non_seqs(fn, sequences, outputs_info, non_sequences, n_steps, op_check):
def test_ScanInfo_totals(fn, sequences, outputs_info, non_sequences, n_steps, op_check):
res, _ = scan(
fn,
sequences=sequences,
......@@ -4124,13 +4118,9 @@ def test_n_non_seqs(fn, sequences, outputs_info, non_sequences, n_steps, op_chec
scan_op = res.owner.op
assert isinstance(scan_op, Scan)
# from aesara.scan.utils import ScanArgs
# print(ScanArgs.from_node(res.owner))
# print(res.owner.op.info)
_ = op_check(scan_op)
assert scan_op.info.n_outer_inputs == len(res.owner.inputs)
assert scan_op.info.n_outer_outputs == len(res.owner.outputs)
assert scan_op.info.n_inner_inputs == len(res.owner.op.inputs)
assert scan_op.info.n_inner_outputs == len(res.owner.op.outputs)
assert scan_op.info.n_inner_inputs == len(res.owner.op.inner_inputs)
assert scan_op.info.n_inner_outputs == len(res.owner.op.inner_outputs)
......@@ -95,7 +95,7 @@ class TestRemoveConstantsAndUnusedInputsScan:
scan_node = scan_nodes[0]
assert len(scan_node.inputs[1:]) == len(set(scan_node.inputs[1:]))
inp = scan_node.op.inner_non_seqs(scan_node.op.inputs)
inp = scan_node.op.inner_non_seqs(scan_node.op.inner_inputs)
assert len(inp) == 1
assert len(inp) == len(set(inp))
......@@ -170,11 +170,11 @@ class TestRemoveConstantsAndUnusedInputsScan:
scan_node = scan_nodes[0]
assert len(scan_node.inputs[1:]) == len(set(scan_node.inputs[1:]))
inp = scan_node.op.inner_seqs(scan_node.op.inputs)
inp = scan_node.op.inner_seqs(scan_node.op.inner_inputs)
assert len(inp) == 1
inp = scan_node.op.outer_seqs(scan_node.inputs)
assert len(inp) == 1
inp = scan_node.op.inner_non_seqs(scan_node.op.inputs)
inp = scan_node.op.inner_non_seqs(scan_node.op.inner_inputs)
assert len(inp) == 1
inp = scan_node.op.outer_non_seqs(scan_node.inputs)
assert len(inp) == 1
......@@ -445,8 +445,8 @@ class TestPushOutNonSeqScan:
scan_node = [
node for node in f_opt.maker.fgraph.toposort() if isinstance(node.op, Scan)
][0]
assert len(scan_node.op.outputs) == 1
assert not isinstance(scan_node.op.outputs[0], Dot)
assert len(scan_node.op.inner_outputs) == 1
assert not isinstance(scan_node.op.inner_outputs[0], Dot)
# Ensure that the function compiled with the optimization produces
# the same results as the function compiled without
......@@ -491,8 +491,8 @@ class TestPushOutNonSeqScan:
][0]
# NOTE: WHEN INFER_SHAPE IS RE-ENABLED, BELOW THE SCAN MUST
# HAVE ONLY 1 OUTPUT.
assert len(scan_node.op.outputs) == 2
assert not isinstance(scan_node.op.outputs[0], Dot)
assert len(scan_node.op.inner_outputs) == 2
assert not isinstance(scan_node.op.inner_outputs[0], Dot)
# Ensure that the function compiled with the optimization produces
# the same results as the function compiled without
......@@ -537,8 +537,8 @@ class TestPushOutNonSeqScan:
scan_node = [
node for node in f_opt.maker.fgraph.toposort() if isinstance(node.op, Scan)
][0]
assert len(scan_node.op.outputs) == 2
assert not isinstance(scan_node.op.outputs[0], Dot)
assert len(scan_node.op.inner_outputs) == 2
assert not isinstance(scan_node.op.inner_outputs[0], Dot)
# Ensure that the function compiled with the optimization produces
# the same results as the function compiled without
......@@ -725,7 +725,7 @@ class TestPushOutAddScan:
node for node in f_opt.maker.fgraph.toposort() if isinstance(node.op, Scan)
][1]
for output in scan_node_grad.op.outputs:
for output in scan_node_grad.op.inner_outputs:
assert not (
isinstance(output.owner.op, Elemwise)
and any(isinstance(i, Dot) for i in output.owner.inputs)
......@@ -1466,7 +1466,7 @@ def test_alloc_inputs3():
f = function([_h0, _W1, _W2], o, mode="FAST_RUN")
scan_node = [x for x in f.maker.fgraph.toposort() if isinstance(x.op, Scan)][0]
assert len(scan_node.op.inputs) == 1
assert len(scan_node.op.inner_inputs) == 1
def test_opt_order():
......
......@@ -101,18 +101,18 @@ def test_ScanArgs():
assert scan_args.n_mit_mot == scan_op.n_mit_mot
# The `scan_args` base class always clones the inner-graph;
# here we make sure it doesn't (and that all the inputs are the same)
assert scan_args.inputs == scan_op.inputs
assert scan_args.inputs == scan_op.inner_inputs
assert scan_args.info == scan_op.info
# Check that `ScanArgs.find_among_fields` works
test_v = scan_op.inner_seqs(scan_op.inputs)[1]
test_v = scan_op.inner_seqs(scan_op.inner_inputs)[1]
field_info = scan_args.find_among_fields(test_v)
assert field_info.name == "inner_in_seqs"
assert field_info.index == 1
assert field_info.inner_index is None
assert scan_args.inner_inputs[field_info.agg_index] == test_v
test_l = scan_op.inner_non_seqs(scan_op.inputs)
test_l = scan_op.inner_non_seqs(scan_op.inner_inputs)
# We didn't index this argument, so it's a `list` (i.e. bad input)
field_info = scan_args.find_among_fields(test_l)
assert field_info is None
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论