提交 d11f3039 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Completely remove ScanInfo fields from Scan Op

上级 7cfe58c4
...@@ -418,7 +418,7 @@ def jax_funcify_Scan(op, **kwargs): ...@@ -418,7 +418,7 @@ def jax_funcify_Scan(op, **kwargs):
def scan(*outer_inputs): def scan(*outer_inputs):
scan_args = ScanArgs( scan_args = ScanArgs(
list(outer_inputs), [None] * op.n_outs, op.inputs, op.outputs, op.info list(outer_inputs), [None] * op.info.n_outs, op.inputs, op.outputs, op.info
) )
# `outer_inputs` is a list with the following composite form: # `outer_inputs` is a list with the following composite form:
......
...@@ -314,12 +314,6 @@ class ScanMethodsMixin: ...@@ -314,12 +314,6 @@ class ScanMethodsMixin:
def outer_mitmot_outs(self, list_outputs): def outer_mitmot_outs(self, list_outputs):
return list_outputs[: self.info.n_mit_mot] return list_outputs[: self.info.n_mit_mot]
def mitmot_taps(self):
return self.info.mit_mot_in_slices
def mitmot_out_taps(self):
return self.info.mit_mot_out_slices[: self.info.n_mit_mot]
def inner_mitsot(self, list_inputs): def inner_mitsot(self, list_inputs):
n_mitmot_taps = sum(len(x) for x in self.info.mit_mot_in_slices) n_mitmot_taps = sum(len(x) for x in self.info.mit_mot_in_slices)
ntaps_upto_sit_sot = n_mitmot_taps + sum( ntaps_upto_sit_sot = n_mitmot_taps + sum(
...@@ -342,9 +336,6 @@ class ScanMethodsMixin: ...@@ -342,9 +336,6 @@ class ScanMethodsMixin:
self.info.n_mit_mot : self.info.n_mit_mot + self.info.n_mit_sot self.info.n_mit_mot : self.info.n_mit_mot + self.info.n_mit_sot
] ]
def mitsot_taps(self):
return self.info.mit_sot_in_slices
def inner_sitsot(self, list_inputs): def inner_sitsot(self, list_inputs):
n_taps_upto_sit_sot = sum( n_taps_upto_sit_sot = sum(
len(x) len(x)
...@@ -785,12 +776,6 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph): ...@@ -785,12 +776,6 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
self.profile = profile self.profile = profile
self.allow_gc = allow_gc self.allow_gc = allow_gc
self.strict = strict self.strict = strict
self.__dict__.update(dataclasses.asdict(info))
self.n_mit_mot = self.info.n_mit_mot
self.n_mit_mot_outs = self.info.n_mit_mot_outs
self.n_mit_sot = self.info.n_mit_sot
self.n_sit_sot = self.info.n_sit_sot
# Clone mode_instance, altering "allow_gc" for the linker, # Clone mode_instance, altering "allow_gc" for the linker,
# and adding a message if we profile # and adding a message if we profile
...@@ -971,8 +956,8 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph): ...@@ -971,8 +956,8 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
n_outer_ins = len(inputs) - len(self.outer_nitsot(inputs)) - 1 n_outer_ins = len(inputs) - len(self.outer_nitsot(inputs)) - 1
n_inner_ins = ( n_inner_ins = (
len(self.inner_seqs(self.inner_inputs)) len(self.inner_seqs(self.inner_inputs))
+ len(self.mitmot_taps()) + len(self.info.mit_mot_in_slices)
+ len(self.mitsot_taps()) + len(self.info.mit_sot_in_slices)
+ len(self.inner_sitsot(self.inner_inputs)) + len(self.inner_sitsot(self.inner_inputs))
+ len(self.inner_shared(self.inner_inputs)) + len(self.inner_shared(self.inner_inputs))
+ len(self.inner_non_seqs(self.inner_inputs)) + len(self.inner_non_seqs(self.inner_inputs))
...@@ -1006,7 +991,11 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph): ...@@ -1006,7 +991,11 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
inner_mitmot = self.inner_mitmot(self.inner_inputs) inner_mitmot = self.inner_mitmot(self.inner_inputs)
inner_mitmot_outs = self.inner_mitmot_outs(self.inner_outputs) inner_mitmot_outs = self.inner_mitmot_outs(self.inner_outputs)
for idx, (itaps, otaps, _outer_mitmot) in enumerate( for idx, (itaps, otaps, _outer_mitmot) in enumerate(
zip(self.mitmot_taps(), self.mitmot_out_taps(), self.outer_mitmot(inputs)) zip(
self.info.mit_mot_in_slices,
self.info.mit_mot_out_slices[: self.info.n_mit_mot],
self.outer_mitmot(inputs),
)
): ):
outer_mitmot = copy_var_format(_outer_mitmot, as_var=inner_mitmot[ipos]) outer_mitmot = copy_var_format(_outer_mitmot, as_var=inner_mitmot[ipos])
new_inputs.append(outer_mitmot) new_inputs.append(outer_mitmot)
...@@ -1057,7 +1046,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph): ...@@ -1057,7 +1046,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
inner_mitsots = self.inner_mitsot(self.inner_inputs) inner_mitsots = self.inner_mitsot(self.inner_inputs)
for idx, (itaps, _outer_mitsot, inner_mitsot_out) in enumerate( for idx, (itaps, _outer_mitsot, inner_mitsot_out) in enumerate(
zip( zip(
self.mitsot_taps(), self.info.mit_sot_in_slices,
self.outer_mitsot(inputs), self.outer_mitsot(inputs),
self.inner_mitsot_outs(self.inner_outputs), self.inner_mitsot_outs(self.inner_outputs),
) )
...@@ -1383,9 +1372,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph): ...@@ -1383,9 +1372,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
output_idx = sum( output_idx = sum(
len(m) for m in info.mit_mot_out_slices[:mitmot_idx] len(m) for m in info.mit_mot_out_slices[:mitmot_idx]
) )
output_idx += self.info.mit_mot_out_slices[mitmot_idx].index( output_idx += info.mit_mot_out_slices[mitmot_idx].index(inp_tap)
inp_tap
)
preallocated_mitmot_outs.append(output_idx) preallocated_mitmot_outs.append(output_idx)
...@@ -1979,7 +1966,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph): ...@@ -1979,7 +1966,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
if self.mitmots_preallocated[mitmot_out_idx]: if self.mitmots_preallocated[mitmot_out_idx]:
mitmot_inp_idx = mitmot_inp_grp_offset + taps.index(out_slice) mitmot_inp_idx = mitmot_inp_grp_offset + taps.index(out_slice)
inner_inp_idx = self.n_seqs + mitmot_inp_idx inner_inp_idx = info.n_seqs + mitmot_inp_idx
# Verify whether the input points to the same data as # Verify whether the input points to the same data as
# it did before the execution of the inner function. # it did before the execution of the inner function.
...@@ -2455,13 +2442,13 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph): ...@@ -2455,13 +2442,13 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
return 1 + iidx return 1 + iidx
oidx = 1 + info.n_seqs oidx = 1 + info.n_seqs
iidx = iidx - info.n_seqs iidx = iidx - info.n_seqs
for taps in self.mitmot_taps(): for taps in info.mit_mot_in_slices:
if len(taps) > iidx: if len(taps) > iidx:
return oidx return oidx
else: else:
oidx += 1 oidx += 1
iidx -= len(taps) iidx -= len(taps)
for taps in self.mitsot_taps(): for taps in info.mit_sot_in_slices:
if len(taps) > iidx: if len(taps) > iidx:
return oidx return oidx
else: else:
...@@ -2475,7 +2462,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph): ...@@ -2475,7 +2462,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
def get_out_idx(iidx): def get_out_idx(iidx):
oidx = 0 oidx = 0
for taps in self.mitmot_out_taps(): for taps in info.mit_mot_out_slices[: info.n_mit_mot]:
if len(taps) > iidx: if len(taps) > iidx:
return oidx return oidx
else: else:
...@@ -2666,7 +2653,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph): ...@@ -2666,7 +2653,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
): ):
mintap = min(taps) mintap = min(taps)
if idx < info.n_mit_mot: if idx < info.n_mit_mot:
outmaxtap = np.max(self.mitmot_out_taps()[idx]) outmaxtap = np.max(info.mit_mot_out_slices[: info.n_mit_mot][idx])
else: else:
outmaxtap = 0 outmaxtap = 0
seq = outs[idx] seq = outs[idx]
...@@ -2695,7 +2682,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph): ...@@ -2695,7 +2682,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
n = n_steps.tag.test_value n = n_steps.tag.test_value
else: else:
n = inputs[0].tag.test_value n = inputs[0].tag.test_value
for taps, x in zip(self.mitsot_taps(), self.outer_mitsot_outs(outs)): for taps, x in zip(info.mit_sot_in_slices, self.outer_mitsot_outs(outs)):
mintap = np.min(taps) mintap = np.min(taps)
if hasattr(x[::-1][:mintap], "test_value"): if hasattr(x[::-1][:mintap], "test_value"):
assert x[::-1][:mintap].tag.test_value.shape[0] == n assert x[::-1][:mintap].tag.test_value.shape[0] == n
...@@ -2710,7 +2697,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph): ...@@ -2710,7 +2697,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
assert x[::-1].tag.test_value.shape[0] == n assert x[::-1].tag.test_value.shape[0] == n
outer_inp_seqs += [ outer_inp_seqs += [
x[::-1][: np.min(taps)] x[::-1][: np.min(taps)]
for taps, x in zip(self.mitsot_taps(), self.outer_mitsot_outs(outs)) for taps, x in zip(info.mit_sot_in_slices, self.outer_mitsot_outs(outs))
] ]
outer_inp_seqs += [x[::-1][:-1] for x in self.outer_sitsot_outs(outs)] outer_inp_seqs += [x[::-1][:-1] for x in self.outer_sitsot_outs(outs)]
outer_inp_seqs += [x[::-1] for x in self.outer_nitsot_outs(outs)] outer_inp_seqs += [x[::-1] for x in self.outer_nitsot_outs(outs)]
......
差异被折叠。
...@@ -361,15 +361,19 @@ def scan_can_remove_outs(op, out_idxs): ...@@ -361,15 +361,19 @@ def scan_can_remove_outs(op, out_idxs):
required_inputs = list(graph_inputs(non_removable)) required_inputs = list(graph_inputs(non_removable))
out_ins = [] out_ins = []
offset = op.n_seqs offset = op.info.n_seqs
for idx, tap in enumerate( for idx, tap in enumerate(
chain(op.mit_mot_in_slices, op.mit_sot_in_slices, op.sit_sot_in_slices) chain(
op.info.mit_mot_in_slices,
op.info.mit_sot_in_slices,
op.info.sit_sot_in_slices,
)
): ):
n_ins = len(tap) n_ins = len(tap)
out_ins += [op.inner_inputs[offset : offset + n_ins]] out_ins += [op.inner_inputs[offset : offset + n_ins]]
offset += n_ins offset += n_ins
out_ins += [[] for k in range(op.n_nit_sot)] out_ins += [[] for k in range(op.info.n_nit_sot)]
out_ins += [[op.inner_inputs[offset + k]] for k in range(op.n_shared_outs)] out_ins += [[op.inner_inputs[offset + k]] for k in range(op.info.n_shared_outs)]
added = True added = True
out_idxs_mask = [1 for idx in out_idxs] out_idxs_mask = [1 for idx in out_idxs]
...@@ -400,8 +404,9 @@ def compress_outs(op, not_required, inputs): ...@@ -400,8 +404,9 @@ def compress_outs(op, not_required, inputs):
""" """
from aesara.scan.op import ScanInfo from aesara.scan.op import ScanInfo
op_info = op.info
info = ScanInfo( info = ScanInfo(
n_seqs=op.info.n_seqs, n_seqs=op_info.n_seqs,
mit_mot_in_slices=(), mit_mot_in_slices=(),
mit_mot_out_slices=(), mit_mot_out_slices=(),
mit_sot_in_slices=(), mit_sot_in_slices=(),
...@@ -409,56 +414,58 @@ def compress_outs(op, not_required, inputs): ...@@ -409,56 +414,58 @@ def compress_outs(op, not_required, inputs):
n_nit_sot=0, n_nit_sot=0,
n_shared_outs=0, n_shared_outs=0,
n_non_seqs=0, n_non_seqs=0,
as_while=op.info.as_while, as_while=op_info.as_while,
) )
op_inputs = op.inner_inputs[: op.n_seqs] op_inputs = op.inner_inputs[: op_info.n_seqs]
op_outputs = [] op_outputs = []
node_inputs = inputs[: op.n_seqs + 1] node_inputs = inputs[: op_info.n_seqs + 1]
map_old_new = OrderedDict() map_old_new = OrderedDict()
offset = 0 offset = 0
ni_offset = op.n_seqs + 1 ni_offset = op_info.n_seqs + 1
i_offset = op.n_seqs i_offset = op_info.n_seqs
o_offset = 0 o_offset = 0
curr_pos = 0 curr_pos = 0
for idx in range(op.info.n_mit_mot): for idx in range(op_info.n_mit_mot):
if offset + idx not in not_required: if offset + idx not in not_required:
map_old_new[offset + idx] = curr_pos map_old_new[offset + idx] = curr_pos
curr_pos += 1 curr_pos += 1
info = dataclasses.replace( info = dataclasses.replace(
info, info,
mit_mot_in_slices=info.mit_mot_in_slices + (op.mit_mot_in_slices[idx],), mit_mot_in_slices=info.mit_mot_in_slices
+ (op_info.mit_mot_in_slices[idx],),
mit_mot_out_slices=info.mit_mot_out_slices mit_mot_out_slices=info.mit_mot_out_slices
+ (op.mit_mot_out_slices[idx],), + (op_info.mit_mot_out_slices[idx],),
) )
# input taps # input taps
for jdx in op.mit_mot_in_slices[idx]: for jdx in op_info.mit_mot_in_slices[idx]:
op_inputs += [op.inner_inputs[i_offset]] op_inputs += [op.inner_inputs[i_offset]]
i_offset += 1 i_offset += 1
# output taps # output taps
for jdx in op.mit_mot_out_slices[idx]: for jdx in op_info.mit_mot_out_slices[idx]:
op_outputs += [op.inner_outputs[o_offset]] op_outputs += [op.inner_outputs[o_offset]]
o_offset += 1 o_offset += 1
# node inputs # node inputs
node_inputs += [inputs[ni_offset + idx]] node_inputs += [inputs[ni_offset + idx]]
else: else:
o_offset += len(op.mit_mot_out_slices[idx]) o_offset += len(op_info.mit_mot_out_slices[idx])
i_offset += len(op.mit_mot_in_slices[idx]) i_offset += len(op_info.mit_mot_in_slices[idx])
offset += op.n_mit_mot offset += op_info.n_mit_mot
ni_offset += op.n_mit_mot ni_offset += op_info.n_mit_mot
for idx in range(op.info.n_mit_sot): for idx in range(op_info.n_mit_sot):
if offset + idx not in not_required: if offset + idx not in not_required:
map_old_new[offset + idx] = curr_pos map_old_new[offset + idx] = curr_pos
curr_pos += 1 curr_pos += 1
info = dataclasses.replace( info = dataclasses.replace(
info, info,
mit_sot_in_slices=info.mit_sot_in_slices + (op.mit_sot_in_slices[idx],), mit_sot_in_slices=info.mit_sot_in_slices
+ (op_info.mit_sot_in_slices[idx],),
) )
# input taps # input taps
for jdx in op.mit_sot_in_slices[idx]: for jdx in op_info.mit_sot_in_slices[idx]:
op_inputs += [op.inner_inputs[i_offset]] op_inputs += [op.inner_inputs[i_offset]]
i_offset += 1 i_offset += 1
# output taps # output taps
...@@ -468,17 +475,18 @@ def compress_outs(op, not_required, inputs): ...@@ -468,17 +475,18 @@ def compress_outs(op, not_required, inputs):
node_inputs += [inputs[ni_offset + idx]] node_inputs += [inputs[ni_offset + idx]]
else: else:
o_offset += 1 o_offset += 1
i_offset += len(op.mit_sot_in_slices[idx]) i_offset += len(op_info.mit_sot_in_slices[idx])
offset += op.n_mit_sot offset += op_info.n_mit_sot
ni_offset += op.n_mit_sot ni_offset += op_info.n_mit_sot
for idx in range(op.info.n_sit_sot): for idx in range(op_info.n_sit_sot):
if offset + idx not in not_required: if offset + idx not in not_required:
map_old_new[offset + idx] = curr_pos map_old_new[offset + idx] = curr_pos
curr_pos += 1 curr_pos += 1
info = dataclasses.replace( info = dataclasses.replace(
info, info,
sit_sot_in_slices=info.sit_sot_in_slices + (op.sit_sot_in_slices[idx],), sit_sot_in_slices=info.sit_sot_in_slices
+ (op_info.sit_sot_in_slices[idx],),
) )
# input taps # input taps
op_inputs += [op.inner_inputs[i_offset]] op_inputs += [op.inner_inputs[i_offset]]
...@@ -492,23 +500,23 @@ def compress_outs(op, not_required, inputs): ...@@ -492,23 +500,23 @@ def compress_outs(op, not_required, inputs):
o_offset += 1 o_offset += 1
i_offset += 1 i_offset += 1
offset += op.n_sit_sot offset += op_info.n_sit_sot
ni_offset += op.n_sit_sot ni_offset += op_info.n_sit_sot
nit_sot_ins = [] nit_sot_ins = []
for idx in range(op.info.n_nit_sot): for idx in range(op_info.n_nit_sot):
if offset + idx not in not_required: if offset + idx not in not_required:
map_old_new[offset + idx] = curr_pos map_old_new[offset + idx] = curr_pos
curr_pos += 1 curr_pos += 1
info = dataclasses.replace(info, n_nit_sot=info.n_nit_sot + 1) info = dataclasses.replace(info, n_nit_sot=info.n_nit_sot + 1)
op_outputs += [op.inner_outputs[o_offset]] op_outputs += [op.inner_outputs[o_offset]]
o_offset += 1 o_offset += 1
nit_sot_ins += [inputs[ni_offset + idx + op.n_shared_outs]] nit_sot_ins += [inputs[ni_offset + idx + op_info.n_shared_outs]]
else: else:
o_offset += 1 o_offset += 1
offset += op.n_nit_sot offset += op_info.n_nit_sot
shared_ins = [] shared_ins = []
for idx in range(op.info.n_shared_outs): for idx in range(op_info.n_shared_outs):
if offset + idx not in not_required: if offset + idx not in not_required:
map_old_new[offset + idx] = curr_pos map_old_new[offset + idx] = curr_pos
curr_pos += 1 curr_pos += 1
...@@ -526,8 +534,8 @@ def compress_outs(op, not_required, inputs): ...@@ -526,8 +534,8 @@ def compress_outs(op, not_required, inputs):
# other stuff # other stuff
op_inputs += op.inner_inputs[i_offset:] op_inputs += op.inner_inputs[i_offset:]
info = dataclasses.replace(info, n_non_seqs=len(op.inner_inputs[i_offset:])) info = dataclasses.replace(info, n_non_seqs=len(op.inner_inputs[i_offset:]))
node_inputs += inputs[ni_offset + op.n_shared_outs + op.n_nit_sot :] node_inputs += inputs[ni_offset + op_info.n_shared_outs + op_info.n_nit_sot :]
if op.info.as_while: if op_info.as_while:
op_outputs += [op.inner_outputs[o_offset]] op_outputs += [op.inner_outputs[o_offset]]
map_old_new[o_offset] = len(op_outputs) - 1 map_old_new[o_offset] = len(op_outputs) - 1
# map_old_new[len(op_outputs)-1] = o_offset # map_old_new[len(op_outputs)-1] = o_offset
......
...@@ -97,8 +97,8 @@ def test_ScanArgs(): ...@@ -97,8 +97,8 @@ def test_ScanArgs():
# Check the properties that allow us to use # Check the properties that allow us to use
# `Scan.get_oinp_iinp_iout_oout_mappings` as-is to implement # `Scan.get_oinp_iinp_iout_oout_mappings` as-is to implement
# `ScanArgs.var_mappings` # `ScanArgs.var_mappings`
assert scan_args.n_nit_sot == scan_op.n_nit_sot assert scan_args.n_nit_sot == scan_op.info.n_nit_sot
assert scan_args.n_mit_mot == scan_op.n_mit_mot assert scan_args.n_mit_mot == scan_op.info.n_mit_mot
# The `scan_args` base class always clones the inner-graph; # The `scan_args` base class always clones the inner-graph;
# here we make sure it doesn't (and that all the inputs are the same) # here we make sure it doesn't (and that all the inputs are the same)
assert scan_args.inputs == scan_op.inner_inputs assert scan_args.inputs == scan_op.inner_inputs
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论