提交 d11f3039 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Completely remove ScanInfo fields from Scan Op

上级 7cfe58c4
...@@ -418,7 +418,7 @@ def jax_funcify_Scan(op, **kwargs): ...@@ -418,7 +418,7 @@ def jax_funcify_Scan(op, **kwargs):
def scan(*outer_inputs): def scan(*outer_inputs):
scan_args = ScanArgs( scan_args = ScanArgs(
list(outer_inputs), [None] * op.n_outs, op.inputs, op.outputs, op.info list(outer_inputs), [None] * op.info.n_outs, op.inputs, op.outputs, op.info
) )
# `outer_inputs` is a list with the following composite form: # `outer_inputs` is a list with the following composite form:
......
...@@ -314,12 +314,6 @@ class ScanMethodsMixin: ...@@ -314,12 +314,6 @@ class ScanMethodsMixin:
def outer_mitmot_outs(self, list_outputs): def outer_mitmot_outs(self, list_outputs):
return list_outputs[: self.info.n_mit_mot] return list_outputs[: self.info.n_mit_mot]
def mitmot_taps(self):
return self.info.mit_mot_in_slices
def mitmot_out_taps(self):
return self.info.mit_mot_out_slices[: self.info.n_mit_mot]
def inner_mitsot(self, list_inputs): def inner_mitsot(self, list_inputs):
n_mitmot_taps = sum(len(x) for x in self.info.mit_mot_in_slices) n_mitmot_taps = sum(len(x) for x in self.info.mit_mot_in_slices)
ntaps_upto_sit_sot = n_mitmot_taps + sum( ntaps_upto_sit_sot = n_mitmot_taps + sum(
...@@ -342,9 +336,6 @@ class ScanMethodsMixin: ...@@ -342,9 +336,6 @@ class ScanMethodsMixin:
self.info.n_mit_mot : self.info.n_mit_mot + self.info.n_mit_sot self.info.n_mit_mot : self.info.n_mit_mot + self.info.n_mit_sot
] ]
def mitsot_taps(self):
return self.info.mit_sot_in_slices
def inner_sitsot(self, list_inputs): def inner_sitsot(self, list_inputs):
n_taps_upto_sit_sot = sum( n_taps_upto_sit_sot = sum(
len(x) len(x)
...@@ -785,12 +776,6 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph): ...@@ -785,12 +776,6 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
self.profile = profile self.profile = profile
self.allow_gc = allow_gc self.allow_gc = allow_gc
self.strict = strict self.strict = strict
self.__dict__.update(dataclasses.asdict(info))
self.n_mit_mot = self.info.n_mit_mot
self.n_mit_mot_outs = self.info.n_mit_mot_outs
self.n_mit_sot = self.info.n_mit_sot
self.n_sit_sot = self.info.n_sit_sot
# Clone mode_instance, altering "allow_gc" for the linker, # Clone mode_instance, altering "allow_gc" for the linker,
# and adding a message if we profile # and adding a message if we profile
...@@ -971,8 +956,8 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph): ...@@ -971,8 +956,8 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
n_outer_ins = len(inputs) - len(self.outer_nitsot(inputs)) - 1 n_outer_ins = len(inputs) - len(self.outer_nitsot(inputs)) - 1
n_inner_ins = ( n_inner_ins = (
len(self.inner_seqs(self.inner_inputs)) len(self.inner_seqs(self.inner_inputs))
+ len(self.mitmot_taps()) + len(self.info.mit_mot_in_slices)
+ len(self.mitsot_taps()) + len(self.info.mit_sot_in_slices)
+ len(self.inner_sitsot(self.inner_inputs)) + len(self.inner_sitsot(self.inner_inputs))
+ len(self.inner_shared(self.inner_inputs)) + len(self.inner_shared(self.inner_inputs))
+ len(self.inner_non_seqs(self.inner_inputs)) + len(self.inner_non_seqs(self.inner_inputs))
...@@ -1006,7 +991,11 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph): ...@@ -1006,7 +991,11 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
inner_mitmot = self.inner_mitmot(self.inner_inputs) inner_mitmot = self.inner_mitmot(self.inner_inputs)
inner_mitmot_outs = self.inner_mitmot_outs(self.inner_outputs) inner_mitmot_outs = self.inner_mitmot_outs(self.inner_outputs)
for idx, (itaps, otaps, _outer_mitmot) in enumerate( for idx, (itaps, otaps, _outer_mitmot) in enumerate(
zip(self.mitmot_taps(), self.mitmot_out_taps(), self.outer_mitmot(inputs)) zip(
self.info.mit_mot_in_slices,
self.info.mit_mot_out_slices[: self.info.n_mit_mot],
self.outer_mitmot(inputs),
)
): ):
outer_mitmot = copy_var_format(_outer_mitmot, as_var=inner_mitmot[ipos]) outer_mitmot = copy_var_format(_outer_mitmot, as_var=inner_mitmot[ipos])
new_inputs.append(outer_mitmot) new_inputs.append(outer_mitmot)
...@@ -1057,7 +1046,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph): ...@@ -1057,7 +1046,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
inner_mitsots = self.inner_mitsot(self.inner_inputs) inner_mitsots = self.inner_mitsot(self.inner_inputs)
for idx, (itaps, _outer_mitsot, inner_mitsot_out) in enumerate( for idx, (itaps, _outer_mitsot, inner_mitsot_out) in enumerate(
zip( zip(
self.mitsot_taps(), self.info.mit_sot_in_slices,
self.outer_mitsot(inputs), self.outer_mitsot(inputs),
self.inner_mitsot_outs(self.inner_outputs), self.inner_mitsot_outs(self.inner_outputs),
) )
...@@ -1383,9 +1372,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph): ...@@ -1383,9 +1372,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
output_idx = sum( output_idx = sum(
len(m) for m in info.mit_mot_out_slices[:mitmot_idx] len(m) for m in info.mit_mot_out_slices[:mitmot_idx]
) )
output_idx += self.info.mit_mot_out_slices[mitmot_idx].index( output_idx += info.mit_mot_out_slices[mitmot_idx].index(inp_tap)
inp_tap
)
preallocated_mitmot_outs.append(output_idx) preallocated_mitmot_outs.append(output_idx)
...@@ -1979,7 +1966,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph): ...@@ -1979,7 +1966,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
if self.mitmots_preallocated[mitmot_out_idx]: if self.mitmots_preallocated[mitmot_out_idx]:
mitmot_inp_idx = mitmot_inp_grp_offset + taps.index(out_slice) mitmot_inp_idx = mitmot_inp_grp_offset + taps.index(out_slice)
inner_inp_idx = self.n_seqs + mitmot_inp_idx inner_inp_idx = info.n_seqs + mitmot_inp_idx
# Verify whether the input points to the same data as # Verify whether the input points to the same data as
# it did before the execution of the inner function. # it did before the execution of the inner function.
...@@ -2455,13 +2442,13 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph): ...@@ -2455,13 +2442,13 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
return 1 + iidx return 1 + iidx
oidx = 1 + info.n_seqs oidx = 1 + info.n_seqs
iidx = iidx - info.n_seqs iidx = iidx - info.n_seqs
for taps in self.mitmot_taps(): for taps in info.mit_mot_in_slices:
if len(taps) > iidx: if len(taps) > iidx:
return oidx return oidx
else: else:
oidx += 1 oidx += 1
iidx -= len(taps) iidx -= len(taps)
for taps in self.mitsot_taps(): for taps in info.mit_sot_in_slices:
if len(taps) > iidx: if len(taps) > iidx:
return oidx return oidx
else: else:
...@@ -2475,7 +2462,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph): ...@@ -2475,7 +2462,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
def get_out_idx(iidx): def get_out_idx(iidx):
oidx = 0 oidx = 0
for taps in self.mitmot_out_taps(): for taps in info.mit_mot_out_slices[: info.n_mit_mot]:
if len(taps) > iidx: if len(taps) > iidx:
return oidx return oidx
else: else:
...@@ -2666,7 +2653,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph): ...@@ -2666,7 +2653,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
): ):
mintap = min(taps) mintap = min(taps)
if idx < info.n_mit_mot: if idx < info.n_mit_mot:
outmaxtap = np.max(self.mitmot_out_taps()[idx]) outmaxtap = np.max(info.mit_mot_out_slices[: info.n_mit_mot][idx])
else: else:
outmaxtap = 0 outmaxtap = 0
seq = outs[idx] seq = outs[idx]
...@@ -2695,7 +2682,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph): ...@@ -2695,7 +2682,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
n = n_steps.tag.test_value n = n_steps.tag.test_value
else: else:
n = inputs[0].tag.test_value n = inputs[0].tag.test_value
for taps, x in zip(self.mitsot_taps(), self.outer_mitsot_outs(outs)): for taps, x in zip(info.mit_sot_in_slices, self.outer_mitsot_outs(outs)):
mintap = np.min(taps) mintap = np.min(taps)
if hasattr(x[::-1][:mintap], "test_value"): if hasattr(x[::-1][:mintap], "test_value"):
assert x[::-1][:mintap].tag.test_value.shape[0] == n assert x[::-1][:mintap].tag.test_value.shape[0] == n
...@@ -2710,7 +2697,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph): ...@@ -2710,7 +2697,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
assert x[::-1].tag.test_value.shape[0] == n assert x[::-1].tag.test_value.shape[0] == n
outer_inp_seqs += [ outer_inp_seqs += [
x[::-1][: np.min(taps)] x[::-1][: np.min(taps)]
for taps, x in zip(self.mitsot_taps(), self.outer_mitsot_outs(outs)) for taps, x in zip(info.mit_sot_in_slices, self.outer_mitsot_outs(outs))
] ]
outer_inp_seqs += [x[::-1][:-1] for x in self.outer_sitsot_outs(outs)] outer_inp_seqs += [x[::-1][:-1] for x in self.outer_sitsot_outs(outs)]
outer_inp_seqs += [x[::-1] for x in self.outer_nitsot_outs(outs)] outer_inp_seqs += [x[::-1] for x in self.outer_nitsot_outs(outs)]
......
...@@ -81,31 +81,34 @@ def remove_constants_and_unused_inputs_scan(fgraph, node): ...@@ -81,31 +81,34 @@ def remove_constants_and_unused_inputs_scan(fgraph, node):
if not isinstance(node.op, Scan): if not isinstance(node.op, Scan):
return False return False
op = node.op op = node.op
op_info = op.info
# We only need to take care of sequences and other arguments # We only need to take care of sequences and other arguments
st = op.n_seqs st = op_info.n_seqs
st += int(sum(len(x) for x in chain(op.mit_mot_in_slices, op.mit_sot_in_slices))) st += int(
st += op.n_sit_sot sum(len(x) for x in chain(op_info.mit_mot_in_slices, op_info.mit_sot_in_slices))
st += op.n_shared_outs )
st += op_info.n_sit_sot
st += op_info.n_shared_outs
op_ins = op.inner_inputs op_ins = op.inner_inputs
op_outs = op.inner_outputs op_outs = op.inner_outputs
# Corresponds to the initial states, which should stay untouched. # Corresponds to the initial states, which should stay untouched.
# We put those variables aside, and put them back at the end. # We put those variables aside, and put them back at the end.
out_stuff_inner = op_ins[op.n_seqs : st] out_stuff_inner = op_ins[op_info.n_seqs : st]
non_seqs = op_ins[st:] non_seqs = op_ins[st:]
st = ( st = (
op.n_seqs op_info.n_seqs
+ op.n_mit_mot + op_info.n_mit_mot
+ op.n_mit_sot + op_info.n_mit_sot
+ op.n_sit_sot + op_info.n_sit_sot
+ op.n_nit_sot + op_info.n_nit_sot
+ op.n_shared_outs + op_info.n_shared_outs
+ 1 + 1
) )
outer_non_seqs = node.inputs[st:] outer_non_seqs = node.inputs[st:]
out_stuff_outer = node.inputs[1 + op.n_seqs : st] out_stuff_outer = node.inputs[1 + op_info.n_seqs : st]
# To replace constants in the outer graph by clones in the inner graph # To replace constants in the outer graph by clones in the inner graph
givens = {} givens = {}
...@@ -115,7 +118,7 @@ def remove_constants_and_unused_inputs_scan(fgraph, node): ...@@ -115,7 +118,7 @@ def remove_constants_and_unused_inputs_scan(fgraph, node):
nw_outer = [node.inputs[0]] nw_outer = [node.inputs[0]]
all_ins = list(graph_inputs(op_outs)) all_ins = list(graph_inputs(op_outs))
for idx in range(op.n_seqs): for idx in range(op_info.n_seqs):
node_inp = node.inputs[idx + 1] node_inp = node.inputs[idx + 1]
if ( if (
isinstance(node_inp, TensorConstant) isinstance(node_inp, TensorConstant)
...@@ -170,7 +173,7 @@ def remove_constants_and_unused_inputs_scan(fgraph, node): ...@@ -170,7 +173,7 @@ def remove_constants_and_unused_inputs_scan(fgraph, node):
if len(nw_inner) != len(op_ins): if len(nw_inner) != len(op_ins):
op_outs = clone_replace(op_outs, replace=givens) op_outs = clone_replace(op_outs, replace=givens)
nw_info = dataclasses.replace( nw_info = dataclasses.replace(
op.info, n_seqs=nw_n_seqs, n_non_seqs=len(nw_inner_nonseq) op_info, n_seqs=nw_n_seqs, n_non_seqs=len(nw_inner_nonseq)
) )
nwScan = Scan( nwScan = Scan(
nw_inner, nw_inner,
...@@ -615,7 +618,7 @@ def push_out_seq_scan(fgraph, node): ...@@ -615,7 +618,7 @@ def push_out_seq_scan(fgraph, node):
if out in op.inner_mitsot_outs(ls): if out in op.inner_mitsot_outs(ls):
odx = op.inner_mitsot_outs(ls).index(out) odx = op.inner_mitsot_outs(ls).index(out)
inp = op.outer_mitsot(node.inputs)[odx] inp = op.outer_mitsot(node.inputs)[odx]
st = abs(np.min(op.mitsot_taps())) st = abs(np.min(op.info.mit_sot_in_slices))
y = set_subtensor(inp[st:], _y) y = set_subtensor(inp[st:], _y)
elif out in op.inner_sitsot_outs(ls): elif out in op.inner_sitsot_outs(ls):
odx = op.inner_sitsot_outs(ls).index(out) odx = op.inner_sitsot_outs(ls).index(out)
...@@ -953,7 +956,7 @@ class ScanInplaceOptimizer(GlobalOptimizer): ...@@ -953,7 +956,7 @@ class ScanInplaceOptimizer(GlobalOptimizer):
op = node.op op = node.op
# inputs corresponding to sequences and n_steps # inputs corresponding to sequences and n_steps
ls_begin = node.inputs[: 1 + op.n_seqs] ls_begin = node.inputs[: 1 + op.info.n_seqs]
ls = op.outer_mitmot(node.inputs) ls = op.outer_mitmot(node.inputs)
ls += op.outer_mitsot(node.inputs) ls += op.outer_mitsot(node.inputs)
ls += op.outer_sitsot(node.inputs) ls += op.outer_sitsot(node.inputs)
...@@ -1044,7 +1047,7 @@ class ScanInplaceOptimizer(GlobalOptimizer): ...@@ -1044,7 +1047,7 @@ class ScanInplaceOptimizer(GlobalOptimizer):
# operate inplace. # operate inplace.
out_indices = [] out_indices = []
for out_idx in range(n_outs): for out_idx in range(n_outs):
inp_idx = 1 + op.n_seqs + out_idx inp_idx = 1 + op.info.n_seqs + out_idx
inp = original_node.inputs[inp_idx] inp = original_node.inputs[inp_idx]
# If the input is from an eligible allocation node, attempt to # If the input is from an eligible allocation node, attempt to
...@@ -1153,11 +1156,16 @@ def save_mem_new_scan(fgraph, node): ...@@ -1153,11 +1156,16 @@ def save_mem_new_scan(fgraph, node):
# defining ``init_l`` for mit_mot sequences is a bit trickier but # defining ``init_l`` for mit_mot sequences is a bit trickier but
# it is safe to set it to 0 # it is safe to set it to 0
op = node.op op = node.op
c_outs = op.n_mit_mot + op.n_mit_sot + op.n_sit_sot + op.n_nit_sot op_info = op.info
c_outs = (
op_info.n_mit_mot + op_info.n_mit_sot + op_info.n_sit_sot + op_info.n_nit_sot
)
init_l = [0 for x in range(op.n_mit_mot)] init_l = [0 for x in range(op_info.n_mit_mot)]
init_l += [abs(min(v)) for v in chain(op.mit_sot_in_slices, op.sit_sot_in_slices)] init_l += [
init_l += [0 for x in range(op.n_nit_sot)] abs(min(v)) for v in chain(op_info.mit_sot_in_slices, op_info.sit_sot_in_slices)
]
init_l += [0 for x in range(op_info.n_nit_sot)]
# 2. Check the clients of each output and see for how many steps # 2. Check the clients of each output and see for how many steps
# does scan need to run # does scan need to run
...@@ -1198,8 +1206,8 @@ def save_mem_new_scan(fgraph, node): ...@@ -1198,8 +1206,8 @@ def save_mem_new_scan(fgraph, node):
# Note that for mit_mot outputs and shared outputs we can not change # Note that for mit_mot outputs and shared outputs we can not change
# the number of intermediate steps stored without affecting the # the number of intermediate steps stored without affecting the
# result of the op # result of the op
store_steps = [0 for o in range(op.n_mit_mot)] store_steps = [0 for o in range(op_info.n_mit_mot)]
store_steps += [-1 for o in node.outputs[op.n_mit_mot : c_outs]] store_steps += [-1 for o in node.outputs[op_info.n_mit_mot : c_outs]]
# Flag that says if an input has changed and we need to do something # Flag that says if an input has changed and we need to do something
# or not # or not
flag_store = False flag_store = False
...@@ -1237,7 +1245,7 @@ def save_mem_new_scan(fgraph, node): ...@@ -1237,7 +1245,7 @@ def save_mem_new_scan(fgraph, node):
break break
# 2.3.2 extract the begin/end of the first dimension # 2.3.2 extract the begin/end of the first dimension
if i >= op.n_mit_mot: if i >= op_info.n_mit_mot:
try: try:
length = shape_of[out][0] length = shape_of[out][0]
except KeyError: except KeyError:
...@@ -1339,7 +1347,7 @@ def save_mem_new_scan(fgraph, node): ...@@ -1339,7 +1347,7 @@ def save_mem_new_scan(fgraph, node):
store_steps[i] = 0 store_steps[i] = 0
break break
if i > op.n_mit_mot: if i > op_info.n_mit_mot:
length = node.inputs[0] + init_l[i] length = node.inputs[0] + init_l[i]
else: else:
try: try:
...@@ -1367,9 +1375,9 @@ def save_mem_new_scan(fgraph, node): ...@@ -1367,9 +1375,9 @@ def save_mem_new_scan(fgraph, node):
# the pre-allocation mechanism is activated. # the pre-allocation mechanism is activated.
prealloc_outs = config.scan__allow_output_prealloc prealloc_outs = config.scan__allow_output_prealloc
first_mitsot_idx = node.op.n_mit_mot first_mitsot_idx = op_info.n_mit_mot
last_sitsot_idx = ( last_sitsot_idx = (
node.op.n_mit_mot + node.op.n_mit_sot + node.op.n_sit_sot - 1 op_info.n_mit_mot + op_info.n_mit_sot + op_info.n_sit_sot - 1
) )
preallocable_output = first_mitsot_idx <= i <= last_sitsot_idx preallocable_output = first_mitsot_idx <= i <= last_sitsot_idx
...@@ -1401,18 +1409,18 @@ def save_mem_new_scan(fgraph, node): ...@@ -1401,18 +1409,18 @@ def save_mem_new_scan(fgraph, node):
# to store everything in memory ( or ar orphane and required # to store everything in memory ( or ar orphane and required
# by the inner function .. ) # by the inner function .. )
replaced_outs = [] replaced_outs = []
offset = 1 + op.n_seqs + op.n_mit_mot offset = 1 + op_info.n_seqs + op_info.n_mit_mot
for idx, _val in enumerate(store_steps[op.n_mit_mot :]): for idx, _val in enumerate(store_steps[op_info.n_mit_mot :]):
i = idx + op.n_mit_mot i = idx + op_info.n_mit_mot
if not (isinstance(_val, int) and _val <= 0 and i not in required): if not (isinstance(_val, int) and _val <= 0 and i not in required):
if idx + op.n_mit_mot in required: if idx + op_info.n_mit_mot in required:
val = 1 val = 1
else: else:
val = _val val = _val
# If the memory for this output has been pre-allocated # If the memory for this output has been pre-allocated
# before going into the scan op (by an alloc node) # before going into the scan op (by an alloc node)
if idx < op.n_mit_sot + op.n_sit_sot: if idx < op_info.n_mit_sot + op_info.n_sit_sot:
# In case the input is still an alloc node, we # In case the input is still an alloc node, we
# actually have two options: # actually have two options:
# a) the input is a set_subtensor, in that case we # a) the input is a set_subtensor, in that case we
...@@ -1442,8 +1450,8 @@ def save_mem_new_scan(fgraph, node): ...@@ -1442,8 +1450,8 @@ def save_mem_new_scan(fgraph, node):
nw_input = nw_inputs[offset + idx][:tmp] nw_input = nw_inputs[offset + idx][:tmp]
nw_inputs[offset + idx] = nw_input nw_inputs[offset + idx] = nw_input
replaced_outs.append(op.n_mit_mot + idx) replaced_outs.append(op_info.n_mit_mot + idx)
odx = op.n_mit_mot + idx odx = op_info.n_mit_mot + idx
old_outputs += [ old_outputs += [
( (
odx, odx,
...@@ -1454,12 +1462,18 @@ def save_mem_new_scan(fgraph, node): ...@@ -1454,12 +1462,18 @@ def save_mem_new_scan(fgraph, node):
) )
] ]
# If there is no memory pre-allocated for this output # If there is no memory pre-allocated for this output
elif idx < op.n_mit_sot + op.n_sit_sot + op.n_nit_sot: elif idx < op_info.n_mit_sot + op_info.n_sit_sot + op_info.n_nit_sot:
pos = op.n_mit_mot + idx + op.n_seqs + 1 + op.n_shared_outs pos = (
op_info.n_mit_mot
+ idx
+ op_info.n_seqs
+ 1
+ op_info.n_shared_outs
)
if nw_inputs[pos] == node.inputs[0]: if nw_inputs[pos] == node.inputs[0]:
nw_inputs[pos] = val nw_inputs[pos] = val
odx = op.n_mit_mot + idx odx = op_info.n_mit_mot + idx
replaced_outs.append(odx) replaced_outs.append(odx)
old_outputs += [ old_outputs += [
( (
...@@ -1473,14 +1487,14 @@ def save_mem_new_scan(fgraph, node): ...@@ -1473,14 +1487,14 @@ def save_mem_new_scan(fgraph, node):
# 3.4. Recompute inputs for everything else based on the new # 3.4. Recompute inputs for everything else based on the new
# number of steps # number of steps
if global_nsteps is not None: if global_nsteps is not None:
for idx, val in enumerate(store_steps[op.n_mit_mot :]): for idx, val in enumerate(store_steps[op_info.n_mit_mot :]):
if val == 0: if val == 0:
# val == 0 means that we want to keep all intermediate # val == 0 means that we want to keep all intermediate
# results for that state, including the initial values. # results for that state, including the initial values.
if idx < op.n_mit_sot + op.n_sit_sot: if idx < op_info.n_mit_sot + op_info.n_sit_sot:
in_idx = offset + idx in_idx = offset + idx
# Number of steps in the initial state # Number of steps in the initial state
initl = init_l[op.n_mit_mot + idx] initl = init_l[op_info.n_mit_mot + idx]
# If the initial buffer has the form # If the initial buffer has the form
# inc_subtensor(zeros(...)[...], _nw_input) # inc_subtensor(zeros(...)[...], _nw_input)
...@@ -1501,8 +1515,10 @@ def save_mem_new_scan(fgraph, node): ...@@ -1501,8 +1515,10 @@ def save_mem_new_scan(fgraph, node):
else: else:
nw_input = nw_inputs[in_idx][: (initl + nw_steps)] nw_input = nw_inputs[in_idx][: (initl + nw_steps)]
elif idx < op.n_mit_sot + op.n_sit_sot + op.n_nit_sot: elif (
in_idx = offset + idx + op.n_shared_outs idx < op_info.n_mit_sot + op_info.n_sit_sot + op_info.n_nit_sot
):
in_idx = offset + idx + op_info.n_shared_outs
if nw_inputs[in_idx] == node.inputs[0]: if nw_inputs[in_idx] == node.inputs[0]:
nw_inputs[in_idx] = nw_steps nw_inputs[in_idx] = nw_steps
...@@ -1693,8 +1709,8 @@ class ScanMerge(GlobalOptimizer): ...@@ -1693,8 +1709,8 @@ class ScanMerge(GlobalOptimizer):
for idx, nd in enumerate(nodes): for idx, nd in enumerate(nodes):
inner_ins[idx].append(rename(nd.op.inner_mitmot(nd.op.inner_inputs), idx)) inner_ins[idx].append(rename(nd.op.inner_mitmot(nd.op.inner_inputs), idx))
inner_outs[idx].append(nd.op.inner_mitmot_outs(nd.op.inner_outputs)) inner_outs[idx].append(nd.op.inner_mitmot_outs(nd.op.inner_outputs))
mit_mot_in_slices += nd.op.mitmot_taps() mit_mot_in_slices += nd.op.info.mit_mot_in_slices
mit_mot_out_slices += nd.op.mitmot_out_taps() mit_mot_out_slices += nd.op.info.mit_mot_out_slices[: nd.op.info.n_mit_mot]
outer_ins += rename(nd.op.outer_mitmot(nd.inputs), idx) outer_ins += rename(nd.op.outer_mitmot(nd.inputs), idx)
outer_outs += nd.op.outer_mitmot_outs(nd.outputs) outer_outs += nd.op.outer_mitmot_outs(nd.outputs)
...@@ -1702,14 +1718,14 @@ class ScanMerge(GlobalOptimizer): ...@@ -1702,14 +1718,14 @@ class ScanMerge(GlobalOptimizer):
for idx, nd in enumerate(nodes): for idx, nd in enumerate(nodes):
inner_ins[idx].append(rename(nd.op.inner_mitsot(nd.op.inner_inputs), idx)) inner_ins[idx].append(rename(nd.op.inner_mitsot(nd.op.inner_inputs), idx))
inner_outs[idx].append(nd.op.inner_mitsot_outs(nd.op.inner_outputs)) inner_outs[idx].append(nd.op.inner_mitsot_outs(nd.op.inner_outputs))
mit_sot_in_slices += nd.op.mitsot_taps() mit_sot_in_slices += nd.op.info.mit_sot_in_slices
outer_ins += rename(nd.op.outer_mitsot(nd.inputs), idx) outer_ins += rename(nd.op.outer_mitsot(nd.inputs), idx)
outer_outs += nd.op.outer_mitsot_outs(nd.outputs) outer_outs += nd.op.outer_mitsot_outs(nd.outputs)
sit_sot_in_slices = () sit_sot_in_slices = ()
for idx, nd in enumerate(nodes): for idx, nd in enumerate(nodes):
inner_ins[idx].append(rename(nd.op.inner_sitsot(nd.op.inner_inputs), idx)) inner_ins[idx].append(rename(nd.op.inner_sitsot(nd.op.inner_inputs), idx))
sit_sot_in_slices += tuple((-1,) for x in range(nd.op.n_sit_sot)) sit_sot_in_slices += tuple((-1,) for x in range(nd.op.info.n_sit_sot))
inner_outs[idx].append(nd.op.inner_sitsot_outs(nd.op.inner_outputs)) inner_outs[idx].append(nd.op.inner_sitsot_outs(nd.op.inner_outputs))
outer_ins += rename(nd.op.outer_sitsot(nd.inputs), idx) outer_ins += rename(nd.op.outer_sitsot(nd.inputs), idx)
outer_outs += nd.op.outer_sitsot_outs(nd.outputs) outer_outs += nd.op.outer_sitsot_outs(nd.outputs)
...@@ -1802,13 +1818,13 @@ class ScanMerge(GlobalOptimizer): ...@@ -1802,13 +1818,13 @@ class ScanMerge(GlobalOptimizer):
new_inner_outs += inner_outs[idx][gr_idx] new_inner_outs += inner_outs[idx][gr_idx]
info = ScanInfo( info = ScanInfo(
n_seqs=sum(nd.op.n_seqs for nd in nodes), n_seqs=sum(nd.op.info.n_seqs for nd in nodes),
mit_mot_in_slices=mit_mot_in_slices, mit_mot_in_slices=mit_mot_in_slices,
mit_mot_out_slices=mit_mot_out_slices, mit_mot_out_slices=mit_mot_out_slices,
mit_sot_in_slices=mit_sot_in_slices, mit_sot_in_slices=mit_sot_in_slices,
sit_sot_in_slices=sit_sot_in_slices, sit_sot_in_slices=sit_sot_in_slices,
n_nit_sot=sum(nd.op.n_nit_sot for nd in nodes), n_nit_sot=sum(nd.op.info.n_nit_sot for nd in nodes),
n_shared_outs=sum(nd.op.n_shared_outs for nd in nodes), n_shared_outs=sum(nd.op.info.n_shared_outs for nd in nodes),
n_non_seqs=n_non_seqs, n_non_seqs=n_non_seqs,
as_while=as_while, as_while=as_while,
) )
......
...@@ -361,15 +361,19 @@ def scan_can_remove_outs(op, out_idxs): ...@@ -361,15 +361,19 @@ def scan_can_remove_outs(op, out_idxs):
required_inputs = list(graph_inputs(non_removable)) required_inputs = list(graph_inputs(non_removable))
out_ins = [] out_ins = []
offset = op.n_seqs offset = op.info.n_seqs
for idx, tap in enumerate( for idx, tap in enumerate(
chain(op.mit_mot_in_slices, op.mit_sot_in_slices, op.sit_sot_in_slices) chain(
op.info.mit_mot_in_slices,
op.info.mit_sot_in_slices,
op.info.sit_sot_in_slices,
)
): ):
n_ins = len(tap) n_ins = len(tap)
out_ins += [op.inner_inputs[offset : offset + n_ins]] out_ins += [op.inner_inputs[offset : offset + n_ins]]
offset += n_ins offset += n_ins
out_ins += [[] for k in range(op.n_nit_sot)] out_ins += [[] for k in range(op.info.n_nit_sot)]
out_ins += [[op.inner_inputs[offset + k]] for k in range(op.n_shared_outs)] out_ins += [[op.inner_inputs[offset + k]] for k in range(op.info.n_shared_outs)]
added = True added = True
out_idxs_mask = [1 for idx in out_idxs] out_idxs_mask = [1 for idx in out_idxs]
...@@ -400,8 +404,9 @@ def compress_outs(op, not_required, inputs): ...@@ -400,8 +404,9 @@ def compress_outs(op, not_required, inputs):
""" """
from aesara.scan.op import ScanInfo from aesara.scan.op import ScanInfo
op_info = op.info
info = ScanInfo( info = ScanInfo(
n_seqs=op.info.n_seqs, n_seqs=op_info.n_seqs,
mit_mot_in_slices=(), mit_mot_in_slices=(),
mit_mot_out_slices=(), mit_mot_out_slices=(),
mit_sot_in_slices=(), mit_sot_in_slices=(),
...@@ -409,56 +414,58 @@ def compress_outs(op, not_required, inputs): ...@@ -409,56 +414,58 @@ def compress_outs(op, not_required, inputs):
n_nit_sot=0, n_nit_sot=0,
n_shared_outs=0, n_shared_outs=0,
n_non_seqs=0, n_non_seqs=0,
as_while=op.info.as_while, as_while=op_info.as_while,
) )
op_inputs = op.inner_inputs[: op.n_seqs] op_inputs = op.inner_inputs[: op_info.n_seqs]
op_outputs = [] op_outputs = []
node_inputs = inputs[: op.n_seqs + 1] node_inputs = inputs[: op_info.n_seqs + 1]
map_old_new = OrderedDict() map_old_new = OrderedDict()
offset = 0 offset = 0
ni_offset = op.n_seqs + 1 ni_offset = op_info.n_seqs + 1
i_offset = op.n_seqs i_offset = op_info.n_seqs
o_offset = 0 o_offset = 0
curr_pos = 0 curr_pos = 0
for idx in range(op.info.n_mit_mot): for idx in range(op_info.n_mit_mot):
if offset + idx not in not_required: if offset + idx not in not_required:
map_old_new[offset + idx] = curr_pos map_old_new[offset + idx] = curr_pos
curr_pos += 1 curr_pos += 1
info = dataclasses.replace( info = dataclasses.replace(
info, info,
mit_mot_in_slices=info.mit_mot_in_slices + (op.mit_mot_in_slices[idx],), mit_mot_in_slices=info.mit_mot_in_slices
+ (op_info.mit_mot_in_slices[idx],),
mit_mot_out_slices=info.mit_mot_out_slices mit_mot_out_slices=info.mit_mot_out_slices
+ (op.mit_mot_out_slices[idx],), + (op_info.mit_mot_out_slices[idx],),
) )
# input taps # input taps
for jdx in op.mit_mot_in_slices[idx]: for jdx in op_info.mit_mot_in_slices[idx]:
op_inputs += [op.inner_inputs[i_offset]] op_inputs += [op.inner_inputs[i_offset]]
i_offset += 1 i_offset += 1
# output taps # output taps
for jdx in op.mit_mot_out_slices[idx]: for jdx in op_info.mit_mot_out_slices[idx]:
op_outputs += [op.inner_outputs[o_offset]] op_outputs += [op.inner_outputs[o_offset]]
o_offset += 1 o_offset += 1
# node inputs # node inputs
node_inputs += [inputs[ni_offset + idx]] node_inputs += [inputs[ni_offset + idx]]
else: else:
o_offset += len(op.mit_mot_out_slices[idx]) o_offset += len(op_info.mit_mot_out_slices[idx])
i_offset += len(op.mit_mot_in_slices[idx]) i_offset += len(op_info.mit_mot_in_slices[idx])
offset += op.n_mit_mot offset += op_info.n_mit_mot
ni_offset += op.n_mit_mot ni_offset += op_info.n_mit_mot
for idx in range(op.info.n_mit_sot): for idx in range(op_info.n_mit_sot):
if offset + idx not in not_required: if offset + idx not in not_required:
map_old_new[offset + idx] = curr_pos map_old_new[offset + idx] = curr_pos
curr_pos += 1 curr_pos += 1
info = dataclasses.replace( info = dataclasses.replace(
info, info,
mit_sot_in_slices=info.mit_sot_in_slices + (op.mit_sot_in_slices[idx],), mit_sot_in_slices=info.mit_sot_in_slices
+ (op_info.mit_sot_in_slices[idx],),
) )
# input taps # input taps
for jdx in op.mit_sot_in_slices[idx]: for jdx in op_info.mit_sot_in_slices[idx]:
op_inputs += [op.inner_inputs[i_offset]] op_inputs += [op.inner_inputs[i_offset]]
i_offset += 1 i_offset += 1
# output taps # output taps
...@@ -468,17 +475,18 @@ def compress_outs(op, not_required, inputs): ...@@ -468,17 +475,18 @@ def compress_outs(op, not_required, inputs):
node_inputs += [inputs[ni_offset + idx]] node_inputs += [inputs[ni_offset + idx]]
else: else:
o_offset += 1 o_offset += 1
i_offset += len(op.mit_sot_in_slices[idx]) i_offset += len(op_info.mit_sot_in_slices[idx])
offset += op.n_mit_sot offset += op_info.n_mit_sot
ni_offset += op.n_mit_sot ni_offset += op_info.n_mit_sot
for idx in range(op.info.n_sit_sot): for idx in range(op_info.n_sit_sot):
if offset + idx not in not_required: if offset + idx not in not_required:
map_old_new[offset + idx] = curr_pos map_old_new[offset + idx] = curr_pos
curr_pos += 1 curr_pos += 1
info = dataclasses.replace( info = dataclasses.replace(
info, info,
sit_sot_in_slices=info.sit_sot_in_slices + (op.sit_sot_in_slices[idx],), sit_sot_in_slices=info.sit_sot_in_slices
+ (op_info.sit_sot_in_slices[idx],),
) )
# input taps # input taps
op_inputs += [op.inner_inputs[i_offset]] op_inputs += [op.inner_inputs[i_offset]]
...@@ -492,23 +500,23 @@ def compress_outs(op, not_required, inputs): ...@@ -492,23 +500,23 @@ def compress_outs(op, not_required, inputs):
o_offset += 1 o_offset += 1
i_offset += 1 i_offset += 1
offset += op.n_sit_sot offset += op_info.n_sit_sot
ni_offset += op.n_sit_sot ni_offset += op_info.n_sit_sot
nit_sot_ins = [] nit_sot_ins = []
for idx in range(op.info.n_nit_sot): for idx in range(op_info.n_nit_sot):
if offset + idx not in not_required: if offset + idx not in not_required:
map_old_new[offset + idx] = curr_pos map_old_new[offset + idx] = curr_pos
curr_pos += 1 curr_pos += 1
info = dataclasses.replace(info, n_nit_sot=info.n_nit_sot + 1) info = dataclasses.replace(info, n_nit_sot=info.n_nit_sot + 1)
op_outputs += [op.inner_outputs[o_offset]] op_outputs += [op.inner_outputs[o_offset]]
o_offset += 1 o_offset += 1
nit_sot_ins += [inputs[ni_offset + idx + op.n_shared_outs]] nit_sot_ins += [inputs[ni_offset + idx + op_info.n_shared_outs]]
else: else:
o_offset += 1 o_offset += 1
offset += op.n_nit_sot offset += op_info.n_nit_sot
shared_ins = [] shared_ins = []
for idx in range(op.info.n_shared_outs): for idx in range(op_info.n_shared_outs):
if offset + idx not in not_required: if offset + idx not in not_required:
map_old_new[offset + idx] = curr_pos map_old_new[offset + idx] = curr_pos
curr_pos += 1 curr_pos += 1
...@@ -526,8 +534,8 @@ def compress_outs(op, not_required, inputs): ...@@ -526,8 +534,8 @@ def compress_outs(op, not_required, inputs):
# other stuff # other stuff
op_inputs += op.inner_inputs[i_offset:] op_inputs += op.inner_inputs[i_offset:]
info = dataclasses.replace(info, n_non_seqs=len(op.inner_inputs[i_offset:])) info = dataclasses.replace(info, n_non_seqs=len(op.inner_inputs[i_offset:]))
node_inputs += inputs[ni_offset + op.n_shared_outs + op.n_nit_sot :] node_inputs += inputs[ni_offset + op_info.n_shared_outs + op_info.n_nit_sot :]
if op.info.as_while: if op_info.as_while:
op_outputs += [op.inner_outputs[o_offset]] op_outputs += [op.inner_outputs[o_offset]]
map_old_new[o_offset] = len(op_outputs) - 1 map_old_new[o_offset] = len(op_outputs) - 1
# map_old_new[len(op_outputs)-1] = o_offset # map_old_new[len(op_outputs)-1] = o_offset
......
...@@ -97,8 +97,8 @@ def test_ScanArgs(): ...@@ -97,8 +97,8 @@ def test_ScanArgs():
# Check the properties that allow us to use # Check the properties that allow us to use
# `Scan.get_oinp_iinp_iout_oout_mappings` as-is to implement # `Scan.get_oinp_iinp_iout_oout_mappings` as-is to implement
# `ScanArgs.var_mappings` # `ScanArgs.var_mappings`
assert scan_args.n_nit_sot == scan_op.n_nit_sot assert scan_args.n_nit_sot == scan_op.info.n_nit_sot
assert scan_args.n_mit_mot == scan_op.n_mit_mot assert scan_args.n_mit_mot == scan_op.info.n_mit_mot
# The `scan_args` base class always clones the inner-graph; # The `scan_args` base class always clones the inner-graph;
# here we make sure it doesn't (and that all the inputs are the same) # here we make sure it doesn't (and that all the inputs are the same)
assert scan_args.inputs == scan_op.inner_inputs assert scan_args.inputs == scan_op.inner_inputs
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论