提交 93502185 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Lazily compute Scan inner-graph functions

上级 1886a580
...@@ -229,7 +229,7 @@ N.B.: ...@@ -229,7 +229,7 @@ N.B.:
for s in scan_ops: for s in scan_ops:
# prepare a dict which maps the scan op's inner inputs # prepare a dict which maps the scan op's inner inputs
# to its outer inputs. # to its outer inputs.
if hasattr(s.owner.op, "fn"): if hasattr(s.owner.op, "_fn"):
# If the op was compiled, print the optimized version. # If the op was compiled, print the optimized version.
inner_inputs = s.owner.op.fn.maker.fgraph.inputs inner_inputs = s.owner.op.fn.maker.fgraph.inputs
else: else:
...@@ -255,7 +255,7 @@ N.B.: ...@@ -255,7 +255,7 @@ N.B.:
scan_inner_to_outer_inputs=inner_to_outer_inputs, scan_inner_to_outer_inputs=inner_to_outer_inputs,
used_ids=used_ids, used_ids=used_ids,
) )
if hasattr(s.owner.op, "fn"): if hasattr(s.owner.op, "_fn"):
# If the op was compiled, print the optimized version. # If the op was compiled, print the optimized version.
outputs = s.owner.op.fn.maker.fgraph.outputs outputs = s.owner.op.fn.maker.fgraph.outputs
else: else:
...@@ -1130,7 +1130,7 @@ def pydotprint( ...@@ -1130,7 +1130,7 @@ def pydotprint(
else: else:
new_name = basename + "_" + str(idx) new_name = basename + "_" + str(idx)
new_name = os.path.join(path, new_name + ext) new_name = os.path.join(path, new_name + ext)
if hasattr(scan_op.op, "fn"): if hasattr(scan_op.op, "_fn"):
to_print = scan_op.op.fn to_print = scan_op.op.fn
else: else:
to_print = scan_op.op.outputs to_print = scan_op.op.outputs
......
...@@ -666,16 +666,47 @@ class Scan(Op, ScanMethodsMixin): ...@@ -666,16 +666,47 @@ class Scan(Op, ScanMethodsMixin):
) )
self._hash_inner_graph = hash(self._cmodule_key) self._hash_inner_graph = hash(self._cmodule_key)
(
self.preallocated_mitmot_outs,
self.mitmots_preallocated,
) = self._mitmot_preallocations()
def _mitmot_preallocations(self):
if config.scan__allow_output_prealloc:
preallocated_mitmot_outs = []
input_idx = self.n_seqs
for mitmot_idx in range(self.n_mit_mot):
for inp_tap in self.tap_array[mitmot_idx]:
if inp_tap in self.mit_mot_out_slices[mitmot_idx]:
# Figure out the index of the corresponding output
output_idx = sum(
[len(m) for m in self.mit_mot_out_slices[:mitmot_idx]]
)
output_idx += self.mit_mot_out_slices[mitmot_idx].index(inp_tap)
preallocated_mitmot_outs.append(output_idx)
input_idx += 1
preallocated_mitmot_outs.sort()
else:
# Output preallocation is not activated. Mark every mitmot output
# tap as not being preallocated
preallocated_mitmot_outs = []
# Store the list of mitmot output taps that have been altered so they
# can be preallocated
mitmots_preallocated = [
i in preallocated_mitmot_outs for i in range(self.n_mit_mot_outs)
]
return preallocated_mitmot_outs, mitmots_preallocated
def __setstate__(self, d): def __setstate__(self, d):
self.__dict__.update(d) self.__dict__.update(d)
if hasattr(self, "fn"): if getattr(self, "_fn", None) is not None:
if not hasattr(self, "thunk_mit_mot_out_slices"):
# The thunk has been compiled before mit_mot preallocation
# feature was implemented. Mark every mit_mot output tap as
# not having been preallocated
self.mitmots_preallocated = [False] * self.n_mit_mot_outs
if not hasattr(self, "outs_is_tensor"): if not hasattr(self, "outs_is_tensor"):
# The thunk has been compiled before the analysis, at # The thunk has been compiled before the analysis, at
# compilation time, of the location of the inputs and outputs. # compilation time, of the location of the inputs and outputs.
...@@ -1195,45 +1226,12 @@ class Scan(Op, ScanMethodsMixin): ...@@ -1195,45 +1226,12 @@ class Scan(Op, ScanMethodsMixin):
) )
) )
def make_thunk(self, node, storage_map, compute_map, no_recycling, impl=None): @property
""" def fn(self):
"""Lazily compile the inner function graph."""
if getattr(self, "_fn", None) is not None:
return self._fn
Parameters
----------
node
Something previously returned by self.make_node.
storage_map
dict variable -> one-element-list where a computed
value for this variable may be found.
compute_map
dict variable -> one-element-list where a boolean
value will be found. The boolean indicates whether the
variable's storage_map container contains a valid value (True)
or if it has not been computed yet (False).
no_recycling
List of variables for which it is forbidden to reuse memory
allocated by a previous call.
impl
Use 'py' if we want python execution.
Notes
-----
If the thunk consults the storage_map on every call, it is safe
for it to ignore the no_recycling argument, because elements of the
no_recycling list will have a value of None in the storage map. If
the thunk can potentially cache return values (like CLinker does),
then it must not do so for variables in the no_recycling list.
"""
# Before building the thunk, validate that the inner graph is
# coherent
self.validate_inner_graph()
# Setting up all my variables in what I believe is a more Cython
# friendly form
node_input_storage = [storage_map[r] for r in node.inputs]
node_output_storage = [storage_map[r] for r in node.outputs]
# If a shared variable is the result of a ViewOp it is a clear # If a shared variable is the result of a ViewOp it is a clear
# indication that we need to copy that value after the perform of # indication that we need to copy that value after the perform of
# scan is done # scan is done
...@@ -1247,7 +1245,6 @@ class Scan(Op, ScanMethodsMixin): ...@@ -1247,7 +1245,6 @@ class Scan(Op, ScanMethodsMixin):
# inplace at the end of the functions's execution. # inplace at the end of the functions's execution.
wrapped_inputs = [In(x, borrow=False) for x in self.inputs[: self.n_seqs]] wrapped_inputs = [In(x, borrow=False) for x in self.inputs[: self.n_seqs]]
new_outputs = [x for x in self.outputs] new_outputs = [x for x in self.outputs]
preallocated_mitmot_outs = []
input_idx = self.n_seqs input_idx = self.n_seqs
for mitmot_idx in range(self.n_mit_mot): for mitmot_idx in range(self.n_mit_mot):
...@@ -1276,7 +1273,6 @@ class Scan(Op, ScanMethodsMixin): ...@@ -1276,7 +1273,6 @@ class Scan(Op, ScanMethodsMixin):
update=self.outputs[output_idx], update=self.outputs[output_idx],
) )
wrapped_inputs.append(wrapped_inp) wrapped_inputs.append(wrapped_inp)
preallocated_mitmot_outs.append(output_idx)
else: else:
# Wrap the corresponding input as usual. Leave the # Wrap the corresponding input as usual. Leave the
# output as-is. # output as-is.
...@@ -1292,16 +1288,9 @@ class Scan(Op, ScanMethodsMixin): ...@@ -1292,16 +1288,9 @@ class Scan(Op, ScanMethodsMixin):
# Remove now useless outputs from the output list (start from the # Remove now useless outputs from the output list (start from the
# end to avoid altering the indices of the other outputs to be # end to avoid altering the indices of the other outputs to be
# deleted. # deleted.
preallocated_mitmot_outs.sort() for p in self.preallocated_mitmot_outs[::-1]:
for p in preallocated_mitmot_outs[::-1]:
del wrapped_outputs[p] del wrapped_outputs[p]
# Store the list of mitmot output taps that have been altered
# so they can be preallocated
self.mitmots_preallocated = [
i in preallocated_mitmot_outs for i in range(self.n_mit_mot_outs)
]
# Add an optimization to the compilation mode to attach a feature # Add an optimization to the compilation mode to attach a feature
# to the function graph just before the inplace optimizations are # to the function graph just before the inplace optimizations are
# applied (inplace optimizations start at position 50 so the # applied (inplace optimizations start at position 50 so the
...@@ -1309,16 +1298,13 @@ class Scan(Op, ScanMethodsMixin): ...@@ -1309,16 +1298,13 @@ class Scan(Op, ScanMethodsMixin):
# so that it runs before them). This feature will prevent mitsot, # so that it runs before them). This feature will prevent mitsot,
# sitsot and nitsot outputs from being computed inplace (to allow # sitsot and nitsot outputs from being computed inplace (to allow
# their preallocation). # their preallocation).
mitsot_start = self.n_mit_mot_outs - len(preallocated_mitmot_outs) mitsot_start = self.n_mit_mot_outs - len(self.preallocated_mitmot_outs)
nitsot_end = mitsot_start + self.n_mit_sot + self.n_sit_sot + self.n_nit_sot nitsot_end = mitsot_start + self.n_mit_sot + self.n_sit_sot + self.n_nit_sot
feature = NoOutputFromInplace(mitsot_start, nitsot_end) feature = NoOutputFromInplace(mitsot_start, nitsot_end)
opt = AddFeatureOptimizer(feature) opt = AddFeatureOptimizer(feature)
compilation_mode = self.mode_instance.register((opt, 49.9)) compilation_mode = self.mode_instance.register((opt, 49.9))
else: else:
# Output preallocation is not activated. Mark every mitmot output
# tap as not being preallocated
self.mitmots_preallocated = [False] * self.n_mit_mot_outs
wrapped_inputs = [In(x, borrow=True) for x in self.inputs] wrapped_inputs = [In(x, borrow=True) for x in self.inputs]
wrapped_outputs = [Out(x, borrow=False) for x in self.outputs[:slices]] wrapped_outputs = [Out(x, borrow=False) for x in self.outputs[:slices]]
...@@ -1336,17 +1322,57 @@ class Scan(Op, ScanMethodsMixin): ...@@ -1336,17 +1322,57 @@ class Scan(Op, ScanMethodsMixin):
profile = ScanProfileStats(name=self.name) profile = ScanProfileStats(name=self.name)
elif self.profile: elif self.profile:
profile = self.profile profile = self.profile
# make_thunk can be called many times on the same op
# we do not want to recompile the inner fct every time. self._fn = function(
if not getattr(self, "fn", None): wrapped_inputs,
self.fn = function( wrapped_outputs,
wrapped_inputs, mode=compilation_mode,
wrapped_outputs, name=self.name,
mode=compilation_mode, profile=profile,
name=self.name, on_unused_input="ignore",
profile=profile, )
on_unused_input="ignore",
) return self._fn
def make_thunk(self, node, storage_map, compute_map, no_recycling, impl=None):
"""
Parameters
----------
node
Something previously returned by self.make_node.
storage_map
dict variable -> one-element-list where a computed
value for this variable may be found.
compute_map
dict variable -> one-element-list where a boolean
value will be found. The boolean indicates whether the
variable's storage_map container contains a valid value (True)
or if it has not been computed yet (False).
no_recycling
List of variables for which it is forbidden to reuse memory
allocated by a previous call.
impl
Use 'py' if we want python execution.
Notes
-----
If the thunk consults the storage_map on every call, it is safe
for it to ignore the no_recycling argument, because elements of the
no_recycling list will have a value of None in the storage map. If
the thunk can potentially cache return values (like CLinker does),
then it must not do so for variables in the no_recycling list.
"""
# Before building the thunk, validate that the inner graph is
# coherent
self.validate_inner_graph()
# Setting up all my variables in what I believe is a more Cython
# friendly form
node_input_storage = [storage_map[r] for r in node.inputs]
node_output_storage = [storage_map[r] for r in node.outputs]
# Analyse the compile inner function to determine which inputs and # Analyse the compile inner function to determine which inputs and
# outputs are on the gpu and speed up some checks during the execution # outputs are on the gpu and speed up some checks during the execution
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论