提交 0b97f658 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Stop inner-graph output reordering in Scan

In order to use shared updates to pre-allocate the storage for mit-mot input and output loops, `Scan` would need to remove the corresponding mit-mot outputs from its inner-`FunctionGraph` before compilation and it would expect the `Function` compilation pipeline to add them back at the end of the remaining outputs. Now, `Scan`'s inner-`FunctionGraph`s maintain the same form at every point, and no special logic is needed to compensate for post-compilation changes in the order/location of inputs and outputs.
上级 27d2bfe3
This source diff could not be displayed because it is too large. You can view the blob instead.
......@@ -861,30 +861,18 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
self.n_outs = info.n_mit_mot + info.n_mit_sot + info.n_sit_sot
self.n_tap_outs = info.n_mit_mot + info.n_mit_sot
# TODO: These can be moved to thunk/function compilation
(
self.preallocated_mitmot_outs,
_,
self.mitmots_preallocated,
) = self._mitmot_preallocations()
self.n_outer_inputs = info.n_outer_inputs
self.n_outer_outputs = info.n_outer_outputs
features = []
self.fgraph = FunctionGraph(inputs, outputs, clone=False)
if config.scan__allow_output_prealloc:
# This feature will prevent mitsot, sitsot and nitsot outputs from
# being computed inplace (to allow their preallocation).
mitsot_start = info.n_mit_mot_outs - len(self.preallocated_mitmot_outs)
nitsot_end = mitsot_start + info.n_mit_sot + info.n_sit_sot + info.n_nit_sot
features.append(NoOutputFromInplace(range(mitsot_start, nitsot_end)))
self.fgraph = FunctionGraph(
inputs,
outputs,
clone=False,
features=features,
)
_ = self.prepare_fgraph(self.fgraph)
if any(node.op.destroy_map for node in self.fgraph.apply_nodes):
raise InconsistencyError(
......@@ -1360,23 +1348,21 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
)
)
@property
def fn(self):
"""Lazily compile the inner function graph."""
if getattr(self, "_fn", None) is not None:
return self._fn
def prepare_fgraph(self, fgraph):
"""Update and wrap `fgraph`'s inputs and outputs in preparation for compilation."""
# If a shared variable is the result of a ViewOp it is a clear
# indication that we need to copy that value after the perform of
# scan is done
slices = (
self.info.n_mit_mot_outs
+ self.info.n_mit_sot
+ self.info.n_sit_sot
+ self.info.n_nit_sot
)
info = self.info
slices = info.n_mit_mot_outs + info.n_mit_sot + info.n_sit_sot + info.n_nit_sot
# Setting `fgraph.update_mapping` will indicate to the `Function`
# construction pipeline that it needn't append the updates to the
# `FunctionGraph` outputs itself, because they're already in the given
# `FunctionGraph`'s outputs. This also prevents us from needing to
# remove those outputs here just to compensate for an overly rigid
# `Function` pipeline.
update_mapping = {}
fgraph = self.fgraph.clone()
preallocated_mitmot_outs = []
if config.scan__allow_output_prealloc:
......@@ -1384,32 +1370,31 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
# input and an output, wrap the input such that the corresponding
# output variable becomes an update to be performed on it, possibly
# inplace at the end of the function's execution.
wrapped_inputs = [
In(x, borrow=False) for x in fgraph.inputs[: self.info.n_seqs]
]
new_outputs = [x for x in fgraph.outputs]
wrapped_inputs = [In(x, borrow=False) for x in fgraph.inputs[: info.n_seqs]]
input_idx = self.info.n_seqs
for mitmot_idx in range(self.info.n_mit_mot):
for inp_tap in self.info.mit_mot_in_slices[mitmot_idx]:
if inp_tap in self.info.mit_mot_out_slices[mitmot_idx]:
input_idx = info.n_seqs
for mitmot_idx in range(info.n_mit_mot):
for inp_tap in info.mit_mot_in_slices[mitmot_idx]:
if inp_tap in info.mit_mot_out_slices[mitmot_idx]:
inp = fgraph.inputs[input_idx]
# Figure out the index of the corresponding output
output_idx = sum(
len(m) for m in self.info.mit_mot_out_slices[:mitmot_idx]
len(m) for m in info.mit_mot_out_slices[:mitmot_idx]
)
output_idx += self.info.mit_mot_out_slices[mitmot_idx].index(
inp_tap
)
# Make it so the input is automatically updated to the
# output value, possibly inplace, at the end of the
# function execution. Also, since an update is
# defined, a default value must also be (this is
# verified by DebugMode). Use an array of size 0 but
# the right ndim and dtype (use a shape of 1 on
# broadcastable dimensions, 0 on the others).
preallocated_mitmot_outs.append(output_idx)
# Make it so that the input is automatically updated to
# the output value, possibly inplace, at the end of the
# function execution. Also, since an update is defined,
# a default value must also be defined (this is verified by
# DebugMode). Use an array of size 0 with the correct
# ndim and dtype (use a shape of 1 on broadcastable
# dimensions, and 0 on the others).
default_shape = [1 if _b else 0 for _b in inp.broadcastable]
default_val = inp.type.value_zeros(default_shape)
wrapped_inp = In(
......@@ -1417,10 +1402,9 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
value=default_val,
update=fgraph.outputs[output_idx],
)
update_mapping[output_idx] = input_idx
wrapped_inputs.append(wrapped_inp)
else:
# Wrap the corresponding input as usual. Leave the
# output as-is.
wrapped_inputs.append(
In(fgraph.inputs[input_idx], borrow=False)
)
......@@ -1429,22 +1413,57 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
# Wrap the inputs not associated to mitmots and wrap the remaining
# outputs
wrapped_inputs += [In(x, borrow=False) for x in fgraph.inputs[input_idx:]]
wrapped_outputs = [Out(x, borrow=True) for x in new_outputs[:slices]]
wrapped_outputs += new_outputs[slices:]
wrapped_outputs = [Out(x, borrow=True) for x in fgraph.outputs[:slices]]
wrapped_outputs += fgraph.outputs[slices:]
# Remove now useless outputs from the output list and start from
# the end to avoid altering the indices of the other outputs to be
# deleted.
for p in self.preallocated_mitmot_outs[::-1]:
fgraph.remove_output(p, reason="scan_prealloc")
del wrapped_outputs[p]
protected_outs = tuple(
i
for i in range(
info.n_mit_mot_outs
+ info.n_mit_sot
+ info.n_sit_sot
+ info.n_nit_sot
)
if i not in preallocated_mitmot_outs
)
fgraph.attach_feature(NoOutputFromInplace(protected_outs))
else:
wrapped_inputs = [In(x, borrow=True) for x in fgraph.inputs]
wrapped_outputs = [Out(x, borrow=False) for x in fgraph.outputs[:slices]]
wrapped_outputs += fgraph.outputs[slices:]
fgraph.update_mapping = update_mapping
from aesara.compile.function.types import Supervisor
from aesara.graph.destroyhandler import DestroyHandler
for node in fgraph.apply_nodes:
if node.op.destroy_map:
fgraph.attach_feature(DestroyHandler())
break
fgraph.attach_feature(
Supervisor(
inp
for spec, inp in zip(wrapped_inputs, fgraph.inputs)
if not (
getattr(spec, "mutable", None)
or (hasattr(fgraph, "destroyers") and fgraph.has_destroyers([inp]))
)
)
)
return wrapped_inputs, wrapped_outputs
@property
def fn(self):
"""Lazily compile the inner function graph."""
if getattr(self, "_fn", None) is not None:
return self._fn
wrapped_inputs, wrapped_outputs = self.prepare_fgraph(self.fgraph)
profile = None
if config.profile or (
isinstance(self.profile, (str, bool, (int,))) and self.profile
......@@ -1463,7 +1482,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
accept_inplace=False,
profile=profile,
on_unused_input="ignore",
fgraph=fgraph,
fgraph=self.fgraph,
)
return self._fn
......@@ -1559,10 +1578,6 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
inner_input_storage = [s.storage for s in self.fn.input_storage]
inner_output_storage = [s.storage for s in self.fn.output_storage]
inner_input_needs_update = tuple(
inp.update is not None for inp in self.fn.maker.expanded_inputs
)
outer_output_dtypes = tuple(
getattr(out, "dtype", None) for out in node.outputs
)
......@@ -1607,8 +1622,6 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
cython_outs_is_tensor,
inner_input_storage,
inner_output_storage,
getattr(self.fn.fn, "need_update_inputs", True),
inner_input_needs_update,
cython_destroy_map,
inputs,
outputs,
......@@ -1715,10 +1728,8 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
seqs.append(seq)
# 2. Allocate memory for the outputs. Construct the list:
# store_steps -- map containing the length of each output
# pos -- map containing the current position of each
# output
# The length of each output
store_steps = [
arg.shape[0]
for arg in inputs[self.seqs_arg_offset : self.shared_arg_offset]
......@@ -1763,6 +1774,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
output_storage[idx][0] = None
return
# The current position of each output
pos = [
(-self.mintaps[idx]) % store_steps[idx]
for idx in range(self.n_outs + info.n_nit_sot)
......@@ -1851,7 +1863,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
for idx in range(info.n_mit_mot_outs):
if not self.mitmots_preallocated[idx]:
inner_output_storage[offset].storage[0] = None
offset += 1
offset += 1
# 4.2. Collect slices for mitsots, sitsots and nitsots
if i != 0:
......@@ -1950,37 +1962,25 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
pdx = offset + info.n_shared_outs
cond = inner_output_storage[pdx].storage[0] == 0
# 5.2. By calling fn() directly instead of calling the aesara
# function, it is possible that the updates have not been
# performed. Perform the updates if needed.
offset_out = len(inner_output_storage) - 1
if getattr(fn, "need_update_inputs", True):
# Update the inputs that have an update function
for inp, storage in zip(
self.fn.maker.expanded_inputs[::-1], self.fn.input_storage[::-1]
):
if inp.update is not None:
storage.data = inner_output_storage[offset_out].data
offset_out -= 1
t_fn += dt_fn
offset_out = 0
# 5.3 Copy over the values for mit_mot outputs
mitmot_inp_offset = 0
mitmot_inp_grp_offset = 0
mitmot_out_idx = 0
for j, taps in enumerate(info.mit_mot_in_slices):
for k in info.mit_mot_out_slices[j]:
for mitmot_grp_idx, taps in enumerate(info.mit_mot_in_slices):
for out_slice in info.mit_mot_out_slices[mitmot_grp_idx]:
if self.mitmots_preallocated[mitmot_out_idx]:
# This output tap has been preallocated.
inp_idx = mitmot_inp_offset + taps.index(k)
mitmot_inp_idx = mitmot_inp_grp_offset + taps.index(out_slice)
inner_inp_idx = self.n_seqs + mitmot_inp_idx
# Verify whether the input points to the same data as
# it did before the execution of the inner function.
old_var = old_mitmot_input_storage[inp_idx]
new_var = inner_input_storage[info.n_seqs + inp_idx].storage[0]
old_var = old_mitmot_input_storage[mitmot_inp_idx]
new_var = inner_input_storage[inner_inp_idx].storage[0]
if old_var is new_var:
old_data = old_mitmot_input_data[inp_idx]
old_data = old_mitmot_input_data[mitmot_inp_idx]
same_data = new_var.data == old_data
else:
same_data = False
......@@ -1990,21 +1990,20 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
# nothing needs to be done. Otherwise, recover the value
# and store it in `outs` as usual
if not same_data:
output_storage[j][0][k + pos[j]] = inner_input_storage[
info.n_seqs + inp_idx
].storage[0]
output_storage[mitmot_grp_idx][0][
out_slice + pos[mitmot_grp_idx]
] = inner_input_storage[inner_inp_idx].storage[0]
else:
# This output tap has not been preallocated, recover
# its value as usual
output_storage[j][0][k + pos[j]] = inner_output_storage[
offset_out
].storage[0]
offset_out += 1
output_storage[mitmot_grp_idx][0][
out_slice + pos[mitmot_grp_idx]
] = inner_output_storage[offset_out].storage[0]
offset_out += 1
mitmot_out_idx += 1
mitmot_inp_offset += len(taps)
mitmot_inp_grp_offset += len(taps)
# 5.4 Copy over the values for mit_sot/sit_sot outputs
begin = info.n_mit_mot
......
......@@ -59,7 +59,7 @@ from aesara.scan.utils import InnerFunctionError
def get_version():
return 0.313
return 0.314
@cython.boundscheck(False)
def perform(
......@@ -81,8 +81,6 @@ def perform(
numpy.ndarray[numpy.int32_t,ndim=1] outs_is_tensor,
list inner_input_storage,
list inner_output_storage,
bint need_update_inputs,
tuple inner_input_needs_update,
numpy.ndarray[numpy.int32_t,ndim=1] destroy_map,
list outer_inputs,
list outer_outputs,
......@@ -138,10 +136,6 @@ def perform(
The storage locations for the inner-function's inputs.
inner_output_storage
The storage locations for the inner-function's outputs.
need_update_inputs
A boolean indicating whether or not inner inputs need to be updated.
inner_input_needs_update
A tuple of booleans indicating which inner inputs need to be updated.
fnct: Function
The compiled Aesara inner-function object.
destroy_map
......@@ -173,8 +167,13 @@ def perform(
n_shared_outs)
cdef unsigned int offset_out
cdef unsigned int lenpos = n_outs + n_nit_sot
# TODO: See how this is being converted and whether or not we can remove
# fixed allocations caused by this.
cdef int pos[500] # put a maximum of 500 outputs
cdef unsigned int len_store_steps = n_mit_mot + n_mit_sot + n_sit_sot + n_nit_sot
# The length of each output
# TODO: See how this is being converted and whether or not we can remove
# fixed allocations caused by this.
cdef int store_steps[500]
cdef unsigned int l
cdef unsigned int offset
......@@ -214,9 +213,8 @@ def perform(
outer_inputs[1+idx].shape,
n_steps,
))
# 2. Allocate memory for the outputs. Construct the list:
# store_steps -- map containing the length of each output
# pos -- map containing the current position of each output
for idx in range(n_mit_mot + n_mit_sot + n_sit_sot):
store_steps[<unsigned int>idx] = outer_inputs[<unsigned int>(idx+n_seqs+1)].shape[0]
......@@ -329,7 +327,7 @@ def perform(
for idx in range(n_mit_mot_outs):
if not mitmots_preallocated[<unsigned int>idx]:
inner_output_storage[<unsigned int>offset][0] = None
offset += 1
offset += 1
# 4.2. Collect slices for mitsots, sitsots and nitsots
if i != 0:
......@@ -400,17 +398,6 @@ def perform(
pdx = offset + n_shared_outs
cond = inner_output_storage[pdx][0] == 0
# 5.2. By calling fn() directly instead of calling the aesara
# function, it is possible that the updates have not been
# performed. Perform the updates if needed.
if need_update_inputs:
offset_out = len(inner_output_storage) - 1
for needs_update, storage in zip(inner_input_needs_update[::-1],
inner_input_storage[::-1]):
if needs_update:
storage[0] = inner_output_storage[offset_out][0]
offset_out -= 1
offset_out = 0
# 5.3 Copy over the values for mit_mot outputs
......@@ -421,11 +408,12 @@ def perform(
if mitmots_preallocated[<unsigned int>mitmot_out_idx]:
# This output tap has been preallocated.
inp_idx = (mitmot_inp_offset + tap_array[j].index(k))
inner_inp_idx = n_seqs + inp_idx
# Verify whether the input points to the same data as
# it did before the execution of the inner function.
old_var = old_mitmot_input_storage[inp_idx]
new_var = inner_input_storage[n_seqs + inp_idx][0]
new_var = inner_input_storage[inner_inp_idx][0]
if old_var is new_var:
old_data = old_mitmot_input_data[inp_idx]
same_data = (new_var.data == old_data)
......@@ -437,15 +425,15 @@ def perform(
# modified inplace and nothing needs to be done.
if not same_data:
outer_outputs[j][0][<unsigned int>(k + pos[j])] = \
inner_input_storage[<unsigned int>(n_seqs + inp_idx)][0]
inner_input_storage[<unsigned int>(inner_inp_idx)][0]
else:
# This output tap has not been preallocated, recover
# its value as usual
outer_outputs[j][0][<unsigned int>(k + pos[j])] = \
inner_output_storage[<unsigned int>offset_out][0]
offset_out += 1
offset_out += 1
mitmot_out_idx += 1
mitmot_inp_offset += tap_array_len[j]
......
......@@ -23,7 +23,7 @@ if not config.cxx:
_logger = logging.getLogger("aesara.scan.scan_perform")
version = 0.313 # must match constant returned in function get_version()
version = 0.314 # must match constant returned in function get_version()
need_reload = False
scan_perform: Optional[ModuleType] = None
......
......@@ -843,9 +843,14 @@ def test_random_state_transfer():
np.testing.assert_array_almost_equal(f1(), f2(), decimal=6)
def test_gradient_scan():
# Test for a crash when using MRG inside scan and taking the gradient
# See https://groups.google.com/d/msg/theano-dev/UbcYyU5m-M8/UO9UgXqnQP0J
@pytest.mark.parametrize(
"mode",
[
"FAST_RUN",
"FAST_COMPILE",
],
)
def test_gradient_scan(mode):
aesara_rng = MRG_RandomStream(10)
w = shared(np.ones(1, dtype="float32"))
......@@ -855,8 +860,12 @@ def test_gradient_scan():
x = vector(dtype="float32")
values, updates = scan(one_step, outputs_info=x, n_steps=10)
gw = grad(at_sum(values[-1]), w)
f = function([x], gw)
f(np.arange(1, dtype="float32"))
f = function([x], gw, mode=mode)
assert np.allclose(
f(np.arange(1, dtype=np.float32)),
np.array([0.13928187], dtype=np.float32),
rtol=1e6,
)
def test_simple_shared_mrg_random():
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论