提交 d11f3039 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Completely remove ScanInfo fields from Scan Op

上级 7cfe58c4
......@@ -418,7 +418,7 @@ def jax_funcify_Scan(op, **kwargs):
def scan(*outer_inputs):
scan_args = ScanArgs(
list(outer_inputs), [None] * op.n_outs, op.inputs, op.outputs, op.info
list(outer_inputs), [None] * op.info.n_outs, op.inputs, op.outputs, op.info
)
# `outer_inputs` is a list with the following composite form:
......
......@@ -314,12 +314,6 @@ class ScanMethodsMixin:
def outer_mitmot_outs(self, list_outputs):
return list_outputs[: self.info.n_mit_mot]
def mitmot_taps(self):
return self.info.mit_mot_in_slices
def mitmot_out_taps(self):
return self.info.mit_mot_out_slices[: self.info.n_mit_mot]
def inner_mitsot(self, list_inputs):
n_mitmot_taps = sum(len(x) for x in self.info.mit_mot_in_slices)
ntaps_upto_sit_sot = n_mitmot_taps + sum(
......@@ -342,9 +336,6 @@ class ScanMethodsMixin:
self.info.n_mit_mot : self.info.n_mit_mot + self.info.n_mit_sot
]
def mitsot_taps(self):
return self.info.mit_sot_in_slices
def inner_sitsot(self, list_inputs):
n_taps_upto_sit_sot = sum(
len(x)
......@@ -785,12 +776,6 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
self.profile = profile
self.allow_gc = allow_gc
self.strict = strict
self.__dict__.update(dataclasses.asdict(info))
self.n_mit_mot = self.info.n_mit_mot
self.n_mit_mot_outs = self.info.n_mit_mot_outs
self.n_mit_sot = self.info.n_mit_sot
self.n_sit_sot = self.info.n_sit_sot
# Clone mode_instance, altering "allow_gc" for the linker,
# and adding a message if we profile
......@@ -971,8 +956,8 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
n_outer_ins = len(inputs) - len(self.outer_nitsot(inputs)) - 1
n_inner_ins = (
len(self.inner_seqs(self.inner_inputs))
+ len(self.mitmot_taps())
+ len(self.mitsot_taps())
+ len(self.info.mit_mot_in_slices)
+ len(self.info.mit_sot_in_slices)
+ len(self.inner_sitsot(self.inner_inputs))
+ len(self.inner_shared(self.inner_inputs))
+ len(self.inner_non_seqs(self.inner_inputs))
......@@ -1006,7 +991,11 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
inner_mitmot = self.inner_mitmot(self.inner_inputs)
inner_mitmot_outs = self.inner_mitmot_outs(self.inner_outputs)
for idx, (itaps, otaps, _outer_mitmot) in enumerate(
zip(self.mitmot_taps(), self.mitmot_out_taps(), self.outer_mitmot(inputs))
zip(
self.info.mit_mot_in_slices,
self.info.mit_mot_out_slices[: self.info.n_mit_mot],
self.outer_mitmot(inputs),
)
):
outer_mitmot = copy_var_format(_outer_mitmot, as_var=inner_mitmot[ipos])
new_inputs.append(outer_mitmot)
......@@ -1057,7 +1046,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
inner_mitsots = self.inner_mitsot(self.inner_inputs)
for idx, (itaps, _outer_mitsot, inner_mitsot_out) in enumerate(
zip(
self.mitsot_taps(),
self.info.mit_sot_in_slices,
self.outer_mitsot(inputs),
self.inner_mitsot_outs(self.inner_outputs),
)
......@@ -1383,9 +1372,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
output_idx = sum(
len(m) for m in info.mit_mot_out_slices[:mitmot_idx]
)
output_idx += self.info.mit_mot_out_slices[mitmot_idx].index(
inp_tap
)
output_idx += info.mit_mot_out_slices[mitmot_idx].index(inp_tap)
preallocated_mitmot_outs.append(output_idx)
......@@ -1979,7 +1966,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
if self.mitmots_preallocated[mitmot_out_idx]:
mitmot_inp_idx = mitmot_inp_grp_offset + taps.index(out_slice)
inner_inp_idx = self.n_seqs + mitmot_inp_idx
inner_inp_idx = info.n_seqs + mitmot_inp_idx
# Verify whether the input points to the same data as
# it did before the execution of the inner function.
......@@ -2455,13 +2442,13 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
return 1 + iidx
oidx = 1 + info.n_seqs
iidx = iidx - info.n_seqs
for taps in self.mitmot_taps():
for taps in info.mit_mot_in_slices:
if len(taps) > iidx:
return oidx
else:
oidx += 1
iidx -= len(taps)
for taps in self.mitsot_taps():
for taps in info.mit_sot_in_slices:
if len(taps) > iidx:
return oidx
else:
......@@ -2475,7 +2462,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
def get_out_idx(iidx):
oidx = 0
for taps in self.mitmot_out_taps():
for taps in info.mit_mot_out_slices[: info.n_mit_mot]:
if len(taps) > iidx:
return oidx
else:
......@@ -2666,7 +2653,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
):
mintap = min(taps)
if idx < info.n_mit_mot:
outmaxtap = np.max(self.mitmot_out_taps()[idx])
outmaxtap = np.max(info.mit_mot_out_slices[: info.n_mit_mot][idx])
else:
outmaxtap = 0
seq = outs[idx]
......@@ -2695,7 +2682,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
n = n_steps.tag.test_value
else:
n = inputs[0].tag.test_value
for taps, x in zip(self.mitsot_taps(), self.outer_mitsot_outs(outs)):
for taps, x in zip(info.mit_sot_in_slices, self.outer_mitsot_outs(outs)):
mintap = np.min(taps)
if hasattr(x[::-1][:mintap], "test_value"):
assert x[::-1][:mintap].tag.test_value.shape[0] == n
......@@ -2710,7 +2697,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
assert x[::-1].tag.test_value.shape[0] == n
outer_inp_seqs += [
x[::-1][: np.min(taps)]
for taps, x in zip(self.mitsot_taps(), self.outer_mitsot_outs(outs))
for taps, x in zip(info.mit_sot_in_slices, self.outer_mitsot_outs(outs))
]
outer_inp_seqs += [x[::-1][:-1] for x in self.outer_sitsot_outs(outs)]
outer_inp_seqs += [x[::-1] for x in self.outer_nitsot_outs(outs)]
......
差异被折叠。
......@@ -361,15 +361,19 @@ def scan_can_remove_outs(op, out_idxs):
required_inputs = list(graph_inputs(non_removable))
out_ins = []
offset = op.n_seqs
offset = op.info.n_seqs
for idx, tap in enumerate(
chain(op.mit_mot_in_slices, op.mit_sot_in_slices, op.sit_sot_in_slices)
chain(
op.info.mit_mot_in_slices,
op.info.mit_sot_in_slices,
op.info.sit_sot_in_slices,
)
):
n_ins = len(tap)
out_ins += [op.inner_inputs[offset : offset + n_ins]]
offset += n_ins
out_ins += [[] for k in range(op.n_nit_sot)]
out_ins += [[op.inner_inputs[offset + k]] for k in range(op.n_shared_outs)]
out_ins += [[] for k in range(op.info.n_nit_sot)]
out_ins += [[op.inner_inputs[offset + k]] for k in range(op.info.n_shared_outs)]
added = True
out_idxs_mask = [1 for idx in out_idxs]
......@@ -400,8 +404,9 @@ def compress_outs(op, not_required, inputs):
"""
from aesara.scan.op import ScanInfo
op_info = op.info
info = ScanInfo(
n_seqs=op.info.n_seqs,
n_seqs=op_info.n_seqs,
mit_mot_in_slices=(),
mit_mot_out_slices=(),
mit_sot_in_slices=(),
......@@ -409,56 +414,58 @@ def compress_outs(op, not_required, inputs):
n_nit_sot=0,
n_shared_outs=0,
n_non_seqs=0,
as_while=op.info.as_while,
as_while=op_info.as_while,
)
op_inputs = op.inner_inputs[: op.n_seqs]
op_inputs = op.inner_inputs[: op_info.n_seqs]
op_outputs = []
node_inputs = inputs[: op.n_seqs + 1]
node_inputs = inputs[: op_info.n_seqs + 1]
map_old_new = OrderedDict()
offset = 0
ni_offset = op.n_seqs + 1
i_offset = op.n_seqs
ni_offset = op_info.n_seqs + 1
i_offset = op_info.n_seqs
o_offset = 0
curr_pos = 0
for idx in range(op.info.n_mit_mot):
for idx in range(op_info.n_mit_mot):
if offset + idx not in not_required:
map_old_new[offset + idx] = curr_pos
curr_pos += 1
info = dataclasses.replace(
info,
mit_mot_in_slices=info.mit_mot_in_slices + (op.mit_mot_in_slices[idx],),
mit_mot_in_slices=info.mit_mot_in_slices
+ (op_info.mit_mot_in_slices[idx],),
mit_mot_out_slices=info.mit_mot_out_slices
+ (op.mit_mot_out_slices[idx],),
+ (op_info.mit_mot_out_slices[idx],),
)
# input taps
for jdx in op.mit_mot_in_slices[idx]:
for jdx in op_info.mit_mot_in_slices[idx]:
op_inputs += [op.inner_inputs[i_offset]]
i_offset += 1
# output taps
for jdx in op.mit_mot_out_slices[idx]:
for jdx in op_info.mit_mot_out_slices[idx]:
op_outputs += [op.inner_outputs[o_offset]]
o_offset += 1
# node inputs
node_inputs += [inputs[ni_offset + idx]]
else:
o_offset += len(op.mit_mot_out_slices[idx])
i_offset += len(op.mit_mot_in_slices[idx])
o_offset += len(op_info.mit_mot_out_slices[idx])
i_offset += len(op_info.mit_mot_in_slices[idx])
offset += op.n_mit_mot
ni_offset += op.n_mit_mot
offset += op_info.n_mit_mot
ni_offset += op_info.n_mit_mot
for idx in range(op.info.n_mit_sot):
for idx in range(op_info.n_mit_sot):
if offset + idx not in not_required:
map_old_new[offset + idx] = curr_pos
curr_pos += 1
info = dataclasses.replace(
info,
mit_sot_in_slices=info.mit_sot_in_slices + (op.mit_sot_in_slices[idx],),
mit_sot_in_slices=info.mit_sot_in_slices
+ (op_info.mit_sot_in_slices[idx],),
)
# input taps
for jdx in op.mit_sot_in_slices[idx]:
for jdx in op_info.mit_sot_in_slices[idx]:
op_inputs += [op.inner_inputs[i_offset]]
i_offset += 1
# output taps
......@@ -468,17 +475,18 @@ def compress_outs(op, not_required, inputs):
node_inputs += [inputs[ni_offset + idx]]
else:
o_offset += 1
i_offset += len(op.mit_sot_in_slices[idx])
i_offset += len(op_info.mit_sot_in_slices[idx])
offset += op.n_mit_sot
ni_offset += op.n_mit_sot
for idx in range(op.info.n_sit_sot):
offset += op_info.n_mit_sot
ni_offset += op_info.n_mit_sot
for idx in range(op_info.n_sit_sot):
if offset + idx not in not_required:
map_old_new[offset + idx] = curr_pos
curr_pos += 1
info = dataclasses.replace(
info,
sit_sot_in_slices=info.sit_sot_in_slices + (op.sit_sot_in_slices[idx],),
sit_sot_in_slices=info.sit_sot_in_slices
+ (op_info.sit_sot_in_slices[idx],),
)
# input taps
op_inputs += [op.inner_inputs[i_offset]]
......@@ -492,23 +500,23 @@ def compress_outs(op, not_required, inputs):
o_offset += 1
i_offset += 1
offset += op.n_sit_sot
ni_offset += op.n_sit_sot
offset += op_info.n_sit_sot
ni_offset += op_info.n_sit_sot
nit_sot_ins = []
for idx in range(op.info.n_nit_sot):
for idx in range(op_info.n_nit_sot):
if offset + idx not in not_required:
map_old_new[offset + idx] = curr_pos
curr_pos += 1
info = dataclasses.replace(info, n_nit_sot=info.n_nit_sot + 1)
op_outputs += [op.inner_outputs[o_offset]]
o_offset += 1
nit_sot_ins += [inputs[ni_offset + idx + op.n_shared_outs]]
nit_sot_ins += [inputs[ni_offset + idx + op_info.n_shared_outs]]
else:
o_offset += 1
offset += op.n_nit_sot
offset += op_info.n_nit_sot
shared_ins = []
for idx in range(op.info.n_shared_outs):
for idx in range(op_info.n_shared_outs):
if offset + idx not in not_required:
map_old_new[offset + idx] = curr_pos
curr_pos += 1
......@@ -526,8 +534,8 @@ def compress_outs(op, not_required, inputs):
# other stuff
op_inputs += op.inner_inputs[i_offset:]
info = dataclasses.replace(info, n_non_seqs=len(op.inner_inputs[i_offset:]))
node_inputs += inputs[ni_offset + op.n_shared_outs + op.n_nit_sot :]
if op.info.as_while:
node_inputs += inputs[ni_offset + op_info.n_shared_outs + op_info.n_nit_sot :]
if op_info.as_while:
op_outputs += [op.inner_outputs[o_offset]]
map_old_new[o_offset] = len(op_outputs) - 1
# map_old_new[len(op_outputs)-1] = o_offset
......
......@@ -97,8 +97,8 @@ def test_ScanArgs():
# Check the properties that allow us to use
# `Scan.get_oinp_iinp_iout_oout_mappings` as-is to implement
# `ScanArgs.var_mappings`
assert scan_args.n_nit_sot == scan_op.n_nit_sot
assert scan_args.n_mit_mot == scan_op.n_mit_mot
assert scan_args.n_nit_sot == scan_op.info.n_nit_sot
assert scan_args.n_mit_mot == scan_op.info.n_mit_mot
# The `scan_args` base class always clones the inner-graph;
# here we make sure it doesn't (and that all the inputs are the same)
assert scan_args.inputs == scan_op.inner_inputs
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论