Commit 93502185 authored by Brandon T. Willard, committed by Brandon T. Willard

Lazily compute Scan inner-graph functions

Parent 1886a580
......@@ -229,7 +229,7 @@ N.B.:
for s in scan_ops:
# prepare a dict which maps the scan op's inner inputs
# to its outer inputs.
if hasattr(s.owner.op, "fn"):
if hasattr(s.owner.op, "_fn"):
# If the op was compiled, print the optimized version.
inner_inputs = s.owner.op.fn.maker.fgraph.inputs
else:
......@@ -255,7 +255,7 @@ N.B.:
scan_inner_to_outer_inputs=inner_to_outer_inputs,
used_ids=used_ids,
)
if hasattr(s.owner.op, "fn"):
if hasattr(s.owner.op, "_fn"):
# If the op was compiled, print the optimized version.
outputs = s.owner.op.fn.maker.fgraph.outputs
else:
......@@ -1130,7 +1130,7 @@ def pydotprint(
else:
new_name = basename + "_" + str(idx)
new_name = os.path.join(path, new_name + ext)
if hasattr(scan_op.op, "fn"):
if hasattr(scan_op.op, "_fn"):
to_print = scan_op.op.fn
else:
to_print = scan_op.op.outputs
......
......@@ -666,16 +666,47 @@ class Scan(Op, ScanMethodsMixin):
)
self._hash_inner_graph = hash(self._cmodule_key)
(
self.preallocated_mitmot_outs,
self.mitmots_preallocated,
) = self._mitmot_preallocations()
def _mitmot_preallocations(self):
    """Determine which mit-mot output taps can have their storage preallocated.

    A mit-mot output tap can be preallocated when the same tap also
    appears among the corresponding mit-mot input taps (the input slot
    can then be reused as the output's storage).

    Returns
    -------
    tuple(list of int, list of bool)
        ``preallocated_mitmot_outs``: sorted indices (into the mit-mot
        outputs) of the taps that will be preallocated; empty when
        ``config.scan__allow_output_prealloc`` is off.
        ``mitmots_preallocated``: boolean mask of length
        ``self.n_mit_mot_outs`` — ``True`` for each preallocated output.
    """
    if config.scan__allow_output_prealloc:
        preallocated_mitmot_outs = []
        # NOTE: the original code also tracked an `input_idx` counter here,
        # but it was never read; it has been removed as dead code.
        for mitmot_idx in range(self.n_mit_mot):
            for inp_tap in self.tap_array[mitmot_idx]:
                if inp_tap in self.mit_mot_out_slices[mitmot_idx]:
                    # Figure out the index of the corresponding output:
                    # count all output taps of the preceding mit-mots,
                    # then the position of this tap within the current one.
                    output_idx = sum(
                        len(m) for m in self.mit_mot_out_slices[:mitmot_idx]
                    )
                    output_idx += self.mit_mot_out_slices[mitmot_idx].index(inp_tap)
                    preallocated_mitmot_outs.append(output_idx)
        preallocated_mitmot_outs.sort()
    else:
        # Output preallocation is not activated. Mark every mit-mot output
        # tap as not being preallocated.
        preallocated_mitmot_outs = []

    # Boolean mask over all mit-mot outputs so the perform/make_thunk code
    # can test membership in O(1) per output.
    mitmots_preallocated = [
        i in preallocated_mitmot_outs for i in range(self.n_mit_mot_outs)
    ]

    return preallocated_mitmot_outs, mitmots_preallocated
def __setstate__(self, d):
self.__dict__.update(d)
if hasattr(self, "fn"):
if not hasattr(self, "thunk_mit_mot_out_slices"):
# The thunk has been compiled before mit_mot preallocation
# feature was implemented. Mark every mit_mot output tap as
# not having been preallocated
self.mitmots_preallocated = [False] * self.n_mit_mot_outs
if getattr(self, "_fn", None) is not None:
if not hasattr(self, "outs_is_tensor"):
# The thunk has been compiled before the analysis, at
# compilation time, of the location of the inputs and outputs.
......@@ -1195,45 +1226,12 @@ class Scan(Op, ScanMethodsMixin):
)
)
def make_thunk(self, node, storage_map, compute_map, no_recycling, impl=None):
"""
@property
def fn(self):
"""Lazily compile the inner function graph."""
if getattr(self, "_fn", None) is not None:
return self._fn
Parameters
----------
node
Something previously returned by self.make_node.
storage_map
dict variable -> one-element-list where a computed
value for this variable may be found.
compute_map
dict variable -> one-element-list where a boolean
value will be found. The boolean indicates whether the
variable's storage_map container contains a valid value (True)
or if it has not been computed yet (False).
no_recycling
List of variables for which it is forbidden to reuse memory
allocated by a previous call.
impl
Use 'py' if we want python execution.
Notes
-----
If the thunk consults the storage_map on every call, it is safe
for it to ignore the no_recycling argument, because elements of the
no_recycling list will have a value of None in the storage map. If
the thunk can potentially cache return values (like CLinker does),
then it must not do so for variables in the no_recycling list.
"""
# Before building the thunk, validate that the inner graph is
# coherent
self.validate_inner_graph()
# Setting up all my variables in what I believe is a more Cython
# friendly form
node_input_storage = [storage_map[r] for r in node.inputs]
node_output_storage = [storage_map[r] for r in node.outputs]
# If a shared variable is the result of a ViewOp it is a clear
# indication that we need to copy that value after the perform of
# scan is done
......@@ -1247,7 +1245,6 @@ class Scan(Op, ScanMethodsMixin):
# inplace at the end of the function's execution.
wrapped_inputs = [In(x, borrow=False) for x in self.inputs[: self.n_seqs]]
new_outputs = [x for x in self.outputs]
preallocated_mitmot_outs = []
input_idx = self.n_seqs
for mitmot_idx in range(self.n_mit_mot):
......@@ -1276,7 +1273,6 @@ class Scan(Op, ScanMethodsMixin):
update=self.outputs[output_idx],
)
wrapped_inputs.append(wrapped_inp)
preallocated_mitmot_outs.append(output_idx)
else:
# Wrap the corresponding input as usual. Leave the
# output as-is.
......@@ -1292,16 +1288,9 @@ class Scan(Op, ScanMethodsMixin):
# Remove now-useless outputs from the output list (start from the
# end to avoid altering the indices of the other outputs to be
# deleted).
preallocated_mitmot_outs.sort()
for p in preallocated_mitmot_outs[::-1]:
for p in self.preallocated_mitmot_outs[::-1]:
del wrapped_outputs[p]
# Store the list of mitmot output taps that have been altered
# so they can be preallocated
self.mitmots_preallocated = [
i in preallocated_mitmot_outs for i in range(self.n_mit_mot_outs)
]
# Add an optimization to the compilation mode to attach a feature
# to the function graph just before the inplace optimizations are
# applied (inplace optimizations start at position 50 so the
......@@ -1309,16 +1298,13 @@ class Scan(Op, ScanMethodsMixin):
# so that it runs before them). This feature will prevent mitsot,
# sitsot and nitsot outputs from being computed inplace (to allow
# their preallocation).
mitsot_start = self.n_mit_mot_outs - len(preallocated_mitmot_outs)
mitsot_start = self.n_mit_mot_outs - len(self.preallocated_mitmot_outs)
nitsot_end = mitsot_start + self.n_mit_sot + self.n_sit_sot + self.n_nit_sot
feature = NoOutputFromInplace(mitsot_start, nitsot_end)
opt = AddFeatureOptimizer(feature)
compilation_mode = self.mode_instance.register((opt, 49.9))
else:
# Output preallocation is not activated. Mark every mitmot output
# tap as not being preallocated
self.mitmots_preallocated = [False] * self.n_mit_mot_outs
wrapped_inputs = [In(x, borrow=True) for x in self.inputs]
wrapped_outputs = [Out(x, borrow=False) for x in self.outputs[:slices]]
......@@ -1336,17 +1322,57 @@ class Scan(Op, ScanMethodsMixin):
profile = ScanProfileStats(name=self.name)
elif self.profile:
profile = self.profile
# make_thunk can be called many times on the same op
# we do not want to recompile the inner fct every time.
if not getattr(self, "fn", None):
self.fn = function(
wrapped_inputs,
wrapped_outputs,
mode=compilation_mode,
name=self.name,
profile=profile,
on_unused_input="ignore",
)
self._fn = function(
wrapped_inputs,
wrapped_outputs,
mode=compilation_mode,
name=self.name,
profile=profile,
on_unused_input="ignore",
)
return self._fn
def make_thunk(self, node, storage_map, compute_map, no_recycling, impl=None):
"""
Parameters
----------
node
Something previously returned by self.make_node.
storage_map
dict variable -> one-element-list where a computed
value for this variable may be found.
compute_map
dict variable -> one-element-list where a boolean
value will be found. The boolean indicates whether the
variable's storage_map container contains a valid value (True)
or if it has not been computed yet (False).
no_recycling
List of variables for which it is forbidden to reuse memory
allocated by a previous call.
impl
Use 'py' if we want python execution.
Notes
-----
If the thunk consults the storage_map on every call, it is safe
for it to ignore the no_recycling argument, because elements of the
no_recycling list will have a value of None in the storage map. If
the thunk can potentially cache return values (like CLinker does),
then it must not do so for variables in the no_recycling list.
"""
# Before building the thunk, validate that the inner graph is
# coherent
self.validate_inner_graph()
# Setting up all my variables in what I believe is a more Cython
# friendly form
node_input_storage = [storage_map[r] for r in node.inputs]
node_output_storage = [storage_map[r] for r in node.outputs]
# Analyse the compiled inner function to determine which inputs and
# outputs are on the GPU and speed up some checks during the execution
......
Markdown is supported
0%
You are about to add 0 people to this discussion. Please proceed with caution.
Please finish editing this comment first!
Register or sign in to post a comment