提交 1a3af4b2 作者: Ricardo Vieira 提交者: Ricardo Vieira

Reduce overhead of Function call

上级 a0c64b5f
...@@ -393,6 +393,8 @@ class Function: ...@@ -393,6 +393,8 @@ class Function:
assert len(self.input_storage) == len(self.maker.fgraph.inputs) assert len(self.input_storage) == len(self.maker.fgraph.inputs)
assert len(self.output_storage) == len(self.maker.fgraph.outputs) assert len(self.output_storage) == len(self.maker.fgraph.outputs)
self.has_defaults = any(refeed for _, refeed, _ in self.defaults)
# Group indexes of inputs that are potentially aliased to each other # Group indexes of inputs that are potentially aliased to each other
# Note: Historically, we only worried about aliasing inputs if they belonged to the same type, # Note: Historically, we only worried about aliasing inputs if they belonged to the same type,
# even though there could be two distinct types that use the same kinds of underlying objects. # even though there could be two distinct types that use the same kinds of underlying objects.
...@@ -540,14 +542,40 @@ class Function: ...@@ -540,14 +542,40 @@ class Function:
self._value = ValueAttribute() self._value = ValueAttribute()
self._container = ContainerAttribute() self._container = ContainerAttribute()
# TODO: Get rid of all this `expanded_inputs` nonsense update_storage = [
assert len(self.maker.expanded_inputs) == len(self.input_storage) container
for inp, container in zip(
self.maker.expanded_inputs, input_storage, strict=True
)
if inp.update is not None
]
# Updates are the last inner outputs that are not returned by Function.__call__
self.n_returned_outputs = len(self.output_storage) - len(update_storage)
# Function.__call__ is responsible for updating the inputs, unless the vm promises to do it itself
self.update_input_storage: tuple[int, Container] = ()
if getattr(vm, "need_update_inputs", True):
self.update_input_storage = tuple(
zip(
range(self.n_returned_outputs, len(output_storage)),
update_storage,
strict=True,
)
)
# This is used only when `vm.need_update_inputs` is `False`, because # In every function call we place inputs in the input_storage, and the vm places outputs in the output_storage
# we're using one of the VM objects and it is putting updates back into # After the call, we want to erase (some of) these references, to allow Python to GC them if unused
# the input containers all by itself. # Required input containers are the non-default inputs, must always be provided again, so we GC them
self.n_returned_outputs = len(self.output_storage) - sum( self.clear_input_storage_data = tuple(
inp.update is not None for inp in self.maker.expanded_inputs container.storage for container in input_storage if container.required
)
# This is only done when `vm.allow_gc` is True, which can change at runtime.
self.clear_output_storage_data = tuple(
container.storage
for container, variable in zip(
self.output_storage, self.maker.fgraph.outputs, strict=True
)
if variable.owner is not None # Not a constant output
) )
for node in self.maker.fgraph.apply_nodes: for node in self.maker.fgraph.apply_nodes:
...@@ -747,7 +775,7 @@ class Function: ...@@ -747,7 +775,7 @@ class Function:
elif isinstance(profile, str): elif isinstance(profile, str):
profile = pytensor.compile.profiling.ProfileStats(message=profile) profile = pytensor.compile.profiling.ProfileStats(message=profile)
f_cpy = maker.__class__( f_cpy = type(maker)(
inputs=ins, inputs=ins,
outputs=outs, outputs=outs,
fgraph=fg_cpy, fgraph=fg_cpy,
...@@ -765,6 +793,8 @@ class Function: ...@@ -765,6 +793,8 @@ class Function:
# check that. # check that.
accept_inplace=True, accept_inplace=True,
no_fgraph_prep=True, no_fgraph_prep=True,
output_keys=maker.output_keys,
name=name,
).create(input_storage, storage_map=new_storage_map) ).create(input_storage, storage_map=new_storage_map)
for in_ori, in_cpy, ori, cpy in zip( for in_ori, in_cpy, ori, cpy in zip(
...@@ -797,8 +827,6 @@ class Function: ...@@ -797,8 +827,6 @@ class Function:
f_cpy.trust_input = self.trust_input f_cpy.trust_input = self.trust_input
f_cpy.unpack_single = self.unpack_single f_cpy.unpack_single = self.unpack_single
f_cpy.name = name
f_cpy.maker.fgraph.name = name
return f_cpy return f_cpy
def _restore_defaults(self): def _restore_defaults(self):
...@@ -808,7 +836,7 @@ class Function: ...@@ -808,7 +836,7 @@ class Function:
value = value.storage[0] value = value.storage[0]
self[i] = value self[i] = value
def __call__(self, *args, **kwargs): def __call__(self, *args, output_subset=None, **kwargs):
""" """
Evaluates value of a function on given arguments. Evaluates value of a function on given arguments.
...@@ -836,20 +864,21 @@ class Function: ...@@ -836,20 +864,21 @@ class Function:
List of outputs on indices/keys from ``output_subset`` or all of them, List of outputs on indices/keys from ``output_subset`` or all of them,
if ``output_subset`` is not passed. if ``output_subset`` is not passed.
""" """
trust_input = self.trust_input
input_storage = self.input_storage input_storage = self.input_storage
vm = self.vm
profile = self.profile profile = self.profile
if profile: if profile:
t0 = time.perf_counter() t0 = time.perf_counter()
output_subset = kwargs.pop("output_subset", None)
if output_subset is not None: if output_subset is not None:
warnings.warn("output_subset is deprecated.", FutureWarning) warnings.warn("output_subset is deprecated.", FutureWarning)
if self.output_keys is not None: if self.output_keys is not None:
output_subset = [self.output_keys.index(key) for key in output_subset] output_subset = [self.output_keys.index(key) for key in output_subset]
# Reinitialize each container's 'provided' counter # Reinitialize each container's 'provided' counter
if self.trust_input: if trust_input:
for arg_container, arg in zip(input_storage, args, strict=False): for arg_container, arg in zip(input_storage, args, strict=False):
arg_container.storage[0] = arg arg_container.storage[0] = arg
else: else:
...@@ -908,7 +937,7 @@ class Function: ...@@ -908,7 +937,7 @@ class Function:
for k, arg in kwargs.items(): for k, arg in kwargs.items():
self[k] = arg self[k] = arg
if not self.trust_input: if not trust_input:
# Collect aliased inputs among the storage space # Collect aliased inputs among the storage space
for potential_group in self._potential_aliased_input_groups: for potential_group in self._potential_aliased_input_groups:
args_share_memory: list[list[int]] = [] args_share_memory: list[list[int]] = []
...@@ -960,11 +989,7 @@ class Function: ...@@ -960,11 +989,7 @@ class Function:
if profile: if profile:
t0_fn = time.perf_counter() t0_fn = time.perf_counter()
try: try:
outputs = ( outputs = vm() if output_subset is None else vm(output_subset=output_subset)
self.vm()
if output_subset is None
else self.vm(output_subset=output_subset)
)
except Exception: except Exception:
self._restore_defaults() self._restore_defaults()
if hasattr(self.vm, "position_of_error"): if hasattr(self.vm, "position_of_error"):
...@@ -991,39 +1016,23 @@ class Function: ...@@ -991,39 +1016,23 @@ class Function:
# Retrieve the values that were computed # Retrieve the values that were computed
if outputs is None: if outputs is None:
outputs = [x.data for x in self.output_storage] outputs = [x.storage[0] for x in self.output_storage]
# Remove internal references to required inputs. # Set updates and filter them out from the returned outputs
# These cannot be re-used anyway. for i, input_storage in self.update_input_storage:
for arg_container in input_storage: input_storage.storage[0] = outputs[i]
if arg_container.required: outputs = outputs[: self.n_returned_outputs]
arg_container.storage[0] = None
# Remove input and output values from storage data
# if we are allowing garbage collection, remove the for storage_data in self.clear_input_storage_data:
# output reference from the internal storage cells storage_data[0] = None
if getattr(self.vm, "allow_gc", False): if getattr(vm, "allow_gc", False):
# strict=False because we are in a hot loop for storage_data in self.clear_output_storage_data:
for o_container, o_variable in zip( storage_data[0] = None
self.output_storage, self.maker.fgraph.outputs, strict=False
):
if o_variable.owner is not None:
# this node is the variable of computation
# WARNING: This circumvents the 'readonly' attribute in x
o_container.storage[0] = None
if getattr(self.vm, "need_update_inputs", True):
# Update the inputs that have an update function
# strict=False because we are in a hot loop
for input, storage in reversed(
list(zip(self.maker.expanded_inputs, input_storage, strict=False))
):
if input.update is not None:
storage.data = outputs.pop()
else:
outputs = outputs[: self.n_returned_outputs]
# Put default values back in the storage # Put default values back in the storage
self._restore_defaults() if self.has_defaults:
self._restore_defaults()
if profile: if profile:
dt_call = time.perf_counter() - t0 dt_call = time.perf_counter() - t0
...@@ -1031,33 +1040,29 @@ class Function: ...@@ -1031,33 +1040,29 @@ class Function:
self.maker.mode.call_time += dt_call self.maker.mode.call_time += dt_call
profile.fct_callcount += 1 profile.fct_callcount += 1
profile.fct_call_time += dt_call profile.fct_call_time += dt_call
if hasattr(self.vm, "update_profile"): if hasattr(vm, "update_profile"):
self.vm.update_profile(profile) vm.update_profile(profile)
if profile.ignore_first_call: if profile.ignore_first_call:
profile.reset() profile.reset()
profile.ignore_first_call = False profile.ignore_first_call = False
if self.return_none: if self.return_none:
return None return None
elif self.unpack_single and len(outputs) == 1 and output_subset is None:
return outputs[0]
else:
if self.output_keys is not None:
assert len(self.output_keys) == len(outputs)
if output_subset is None: if output_subset is not None:
# strict=False because we are in a hot loop outputs = [outputs[i] for i in output_subset]
return dict(zip(self.output_keys, outputs, strict=False))
else:
return {
self.output_keys[index]: outputs[index]
for index in output_subset
}
if output_subset is None: if self.output_keys is None:
return outputs if self.unpack_single:
[out] = outputs
return out
else: else:
return [outputs[i] for i in output_subset] return outputs
else:
output_keys = self.output_keys
if output_subset is not None:
output_keys = [output_keys[i] for i in output_subset]
return dict(zip(output_keys, outputs, strict=True))
value = property( value = property(
lambda self: self._value, lambda self: self._value,
...@@ -1077,9 +1082,10 @@ class Function: ...@@ -1077,9 +1082,10 @@ class Function:
# 1.no allow_gc return False # 1.no allow_gc return False
# 2.has allow_gc, if allow_gc is False, return True # 2.has allow_gc, if allow_gc is False, return True
if not getattr(self.vm, "allow_gc", True): if not getattr(self.vm, "allow_gc", True):
for key in self.vm.storage_map: storage_map = self.vm.storage_map
if not isinstance(key, Constant): for key, value in storage_map.items():
self.vm.storage_map[key][0] = None if key.owner is not None: # Not a constant
value[0] = None
for node in self.nodes_with_inner_function: for node in self.nodes_with_inner_function:
if hasattr(node.fn, "free"): if hasattr(node.fn, "free"):
...@@ -1091,10 +1097,6 @@ class Function: ...@@ -1091,10 +1097,6 @@ class Function:
""" """
return [i.variable for i in self.maker.inputs if i.implicit] return [i.variable for i in self.maker.inputs if i.implicit]
def sync_shared(self):
# NOTE: sync was needed on old gpu backend
pass
def dprint(self, **kwargs): def dprint(self, **kwargs):
"""Debug print itself """Debug print itself
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论