提交 82f6a14f authored 作者: ricardoV94's avatar ricardoV94 提交者: Ricardo Vieira

Cleanup Function.__call__

上级 f0a9ec25
...@@ -326,8 +326,8 @@ class Function: ...@@ -326,8 +326,8 @@ class Function:
def __init__( def __init__(
self, self,
vm: "VM", vm: "VM",
input_storage, input_storage: list[Container],
output_storage, output_storage: list[Container],
indices, indices,
outputs, outputs,
defaults, defaults,
...@@ -372,7 +372,6 @@ class Function: ...@@ -372,7 +372,6 @@ class Function:
name name
A string name. A string name.
""" """
# TODO: Rename to `vm`
self.vm = vm self.vm = vm
self.input_storage = input_storage self.input_storage = input_storage
self.output_storage = output_storage self.output_storage = output_storage
...@@ -388,31 +387,49 @@ class Function: ...@@ -388,31 +387,49 @@ class Function:
self.nodes_with_inner_function = [] self.nodes_with_inner_function = []
self.output_keys = output_keys self.output_keys = output_keys
# See if we have any mutable / borrow inputs assert len(self.input_storage) == len(self.maker.fgraph.inputs)
# TODO: this only need to be set if there is more than one input assert len(self.output_storage) == len(self.maker.fgraph.outputs)
self._check_for_aliased_inputs = False
for i in maker.inputs: # Group indexes of inputs that are potentially aliased to each other
# If the input is a shared variable, the memory region is # Note: Historically, we only worried about aliasing inputs if they belonged to the same type,
# under PyTensor control and so we don't need to check if it # even though there could be two distinct types that use the same kinds of underlying objects.
# is aliased as we never do that. potential_aliased_input_groups = []
if ( for inp in maker.inputs:
isinstance(i, In) # If the input is a shared variable, the memory region is under PyTensor control
and not i.shared # and can't be aliased.
and (getattr(i, "borrow", False) or getattr(i, "mutable", False)) if not (
isinstance(inp, In)
and inp.borrow
and not inp.shared
and hasattr(inp.variable.type, "may_share_memory")
): ):
self._check_for_aliased_inputs = True continue
for group in potential_aliased_input_groups:
# If one is super of the other, that means one could be replaced by the other
if any(
inp.variable.type.is_super(other_inp.variable.type)
or other_inp.variable.type.is_super(inp.variable.type)
for other_inp in group
):
group.append(inp)
break break
else: # no break
# Input makes a new group
potential_aliased_input_groups.append([inp])
# Potential aliased inputs are those that belong to the same group
self._potential_aliased_input_groups: tuple[tuple[int, ...], ...] = tuple(
tuple(maker.inputs.index(inp) for inp in group)
for group in potential_aliased_input_groups
if len(group) > 1
)
# We will be popping stuff off this `containers` object. It is a copy. # We will be popping stuff off this `containers` object. It is a copy.
containers = list(self.input_storage) containers = list(self.input_storage)
finder = {} finder = {}
inv_finder = {} inv_finder = {}
def distribute(indices, cs, value):
input.distribute(value, indices, cs)
for c in cs:
c.provided += 1
# Store the list of names of named inputs. # Store the list of names of named inputs.
named_inputs = [] named_inputs = []
# Count the number of un-named inputs. # Count the number of un-named inputs.
...@@ -777,6 +794,13 @@ class Function: ...@@ -777,6 +794,13 @@ class Function:
f_cpy.maker.fgraph.name = name f_cpy.maker.fgraph.name = name
return f_cpy return f_cpy
def _restore_defaults(self):
for i, (required, refeed, value) in enumerate(self.defaults):
if refeed:
if isinstance(value, Container):
value = value.storage[0]
self[i] = value
def __call__(self, *args, **kwargs): def __call__(self, *args, **kwargs):
""" """
Evaluates value of a function on given arguments. Evaluates value of a function on given arguments.
...@@ -805,15 +829,10 @@ class Function: ...@@ -805,15 +829,10 @@ class Function:
List of outputs on indices/keys from ``output_subset`` or all of them, List of outputs on indices/keys from ``output_subset`` or all of them,
if ``output_subset`` is not passed. if ``output_subset`` is not passed.
""" """
input_storage = self.input_storage
def restore_defaults():
for i, (required, refeed, value) in enumerate(self.defaults):
if refeed:
if isinstance(value, Container):
value = value.storage[0]
self[i] = value
profile = self.profile profile = self.profile
if profile:
t0 = time.perf_counter() t0 = time.perf_counter()
output_subset = kwargs.pop("output_subset", None) output_subset = kwargs.pop("output_subset", None)
...@@ -822,35 +841,31 @@ class Function: ...@@ -822,35 +841,31 @@ class Function:
# Reinitialize each container's 'provided' counter # Reinitialize each container's 'provided' counter
if self.trust_input: if self.trust_input:
i = 0 for arg_container, arg in zip(input_storage, args, strict=False):
for arg in args: arg_container.storage[0] = arg
s = self.input_storage[i]
s.storage[0] = arg
i += 1
else: else:
for c in self.input_storage: for arg_container in input_storage:
c.provided = 0 arg_container.provided = 0
if len(args) + len(kwargs) > len(self.input_storage): if len(args) + len(kwargs) > len(input_storage):
raise TypeError("Too many parameter passed to pytensor function") raise TypeError("Too many parameter passed to pytensor function")
# Set positional arguments # Set positional arguments
i = 0 for arg_container, arg in zip(input_storage, args, strict=False):
for arg in args: # See discussion about None as input
# TODO: provide a option for skipping the filter if we really
# want speed.
s = self.input_storage[i]
# see this emails for a discuation about None as input
# https://groups.google.com/group/theano-dev/browse_thread/thread/920a5e904e8a8525/4f1b311a28fc27e5 # https://groups.google.com/group/theano-dev/browse_thread/thread/920a5e904e8a8525/4f1b311a28fc27e5
if arg is None: if arg is None:
s.storage[0] = arg arg_container.storage[0] = arg
else: else:
try: try:
s.storage[0] = s.type.filter( arg_container.storage[0] = arg_container.type.filter(
arg, strict=s.strict, allow_downcast=s.allow_downcast arg,
strict=arg_container.strict,
allow_downcast=arg_container.allow_downcast,
) )
except Exception as e: except Exception as e:
i = input_storage.index(arg_container)
function_name = "pytensor function" function_name = "pytensor function"
argument_name = "argument" argument_name = "argument"
if self.name: if self.name:
...@@ -875,84 +890,65 @@ class Function: ...@@ -875,84 +890,65 @@ class Function:
+ function_name + function_name
+ f" at index {int(i)} (0-based). {where}" + f" at index {int(i)} (0-based). {where}"
) + e.args ) + e.args
restore_defaults() self._restore_defaults()
raise raise
s.provided += 1 arg_container.provided += 1
i += 1
# Set keyword arguments # Set keyword arguments
if kwargs: # for speed, skip the items for empty kwargs if kwargs: # for speed, skip the items for empty kwargs
for k, arg in kwargs.items(): for k, arg in kwargs.items():
self[k] = arg self[k] = arg
if ( if not self.trust_input:
not self.trust_input
and
# The getattr is only needed for old pickle
getattr(self, "_check_for_aliased_inputs", True)
):
# Collect aliased inputs among the storage space # Collect aliased inputs among the storage space
args_share_memory = [] for potential_group in self._potential_aliased_input_groups:
for i in range(len(self.input_storage)): args_share_memory: list[list[int]] = []
i_var = self.maker.inputs[i].variable for i in potential_group:
i_val = self.input_storage[i].storage[0] i_type = self.maker.inputs[i].variable.type
if hasattr(i_var.type, "may_share_memory"): i_val = input_storage[i].storage[0]
is_aliased = False
for j in range(len(args_share_memory)): # Check if value is aliased with any of the values in one of the groups
group_j = zip( for j_group in args_share_memory:
[
self.maker.inputs[k].variable
for k in args_share_memory[j]
],
[
self.input_storage[k].storage[0]
for k in args_share_memory[j]
],
)
if any( if any(
( i_type.may_share_memory(input_storage[j].storage[0], i_val)
var.type is i_var.type for j in j_group
and var.type.may_share_memory(val, i_val)
)
for (var, val) in group_j
): ):
is_aliased = True j_group.append(i)
args_share_memory[j].append(i)
break break
else: # no break
if not is_aliased: # Create a new group
args_share_memory.append([i]) args_share_memory.append([i])
# Check for groups of more than one argument that share memory # Check for groups of more than one argument that share memory
for group in args_share_memory: for group in args_share_memory:
if len(group) > 1: if len(group) > 1:
# copy all but the first # copy all but the first
for j in group[1:]: for i in group[1:]:
self.input_storage[j].storage[0] = copy.copy( input_storage[i].storage[0] = copy.copy(
self.input_storage[j].storage[0] input_storage[i].storage[0]
) )
# Check if inputs are missing, or if inputs were set more than once, or # Check if inputs are missing, or if inputs were set more than once, or
# if we tried to provide inputs that are supposed to be implicit. # if we tried to provide inputs that are supposed to be implicit.
if not self.trust_input: for arg_container in input_storage:
for c in self.input_storage: if arg_container.required and not arg_container.provided:
if c.required and not c.provided: self._restore_defaults()
restore_defaults()
raise TypeError( raise TypeError(
f"Missing required input: {getattr(self.inv_finder[c], 'variable', self.inv_finder[c])}" f"Missing required input: {getattr(self.inv_finder[arg_container], 'variable', self.inv_finder[arg_container])}"
) )
if c.provided > 1: if arg_container.provided > 1:
restore_defaults() self._restore_defaults()
raise TypeError( raise TypeError(
f"Multiple values for input: {getattr(self.inv_finder[c], 'variable', self.inv_finder[c])}" f"Multiple values for input: {getattr(self.inv_finder[arg_container], 'variable', self.inv_finder[arg_container])}"
) )
if c.implicit and c.provided > 0: if arg_container.implicit and arg_container.provided > 0:
restore_defaults() self._restore_defaults()
raise TypeError( raise TypeError(
f"Tried to provide value for implicit input: {getattr(self.inv_finder[c], 'variable', self.inv_finder[c])}" f"Tried to provide value for implicit input: {getattr(self.inv_finder[arg_container], 'variable', self.inv_finder[arg_container])}"
) )
# Do the actual work # Do the actual work
if profile:
t0_fn = time.perf_counter() t0_fn = time.perf_counter()
try: try:
outputs = ( outputs = (
...@@ -961,7 +957,7 @@ class Function: ...@@ -961,7 +957,7 @@ class Function:
else self.vm(output_subset=output_subset) else self.vm(output_subset=output_subset)
) )
except Exception: except Exception:
restore_defaults() self._restore_defaults()
if hasattr(self.vm, "position_of_error"): if hasattr(self.vm, "position_of_error"):
# this is a new vm-provided function or c linker # this is a new vm-provided function or c linker
# they need this because the exception manipulation # they need this because the exception manipulation
...@@ -979,26 +975,24 @@ class Function: ...@@ -979,26 +975,24 @@ class Function:
# old-style linkers raise their own exceptions # old-style linkers raise their own exceptions
raise raise
if profile:
dt_fn = time.perf_counter() - t0_fn dt_fn = time.perf_counter() - t0_fn
self.maker.mode.fn_time += dt_fn self.maker.mode.fn_time += dt_fn
if profile:
profile.vm_call_time += dt_fn profile.vm_call_time += dt_fn
# Retrieve the values that were computed # Retrieve the values that were computed
if outputs is None: if outputs is None:
outputs = [x.data for x in self.output_storage] outputs = [x.data for x in self.output_storage]
assert len(outputs) == len(self.output_storage)
# Remove internal references to required inputs. # Remove internal references to required inputs.
# These cannot be re-used anyway. # These cannot be re-used anyway.
for c in self.input_storage: for arg_container in input_storage:
if c.required: if arg_container.required:
c.storage[0] = None arg_container.storage[0] = None
# if we are allowing garbage collection, remove the # if we are allowing garbage collection, remove the
# output reference from the internal storage cells # output reference from the internal storage cells
if getattr(self.vm, "allow_gc", False): if getattr(self.vm, "allow_gc", False):
assert len(self.output_storage) == len(self.maker.fgraph.outputs)
for o_container, o_variable in zip( for o_container, o_variable in zip(
self.output_storage, self.maker.fgraph.outputs self.output_storage, self.maker.fgraph.outputs
): ):
...@@ -1007,12 +1001,10 @@ class Function: ...@@ -1007,12 +1001,10 @@ class Function:
# WARNING: This circumvents the 'readonly' attribute in x # WARNING: This circumvents the 'readonly' attribute in x
o_container.storage[0] = None o_container.storage[0] = None
# TODO: Get rid of this and `expanded_inputs`, since all the VMs now
# perform the updates themselves
if getattr(self.vm, "need_update_inputs", True): if getattr(self.vm, "need_update_inputs", True):
# Update the inputs that have an update function # Update the inputs that have an update function
for input, storage in reversed( for input, storage in reversed(
list(zip(self.maker.expanded_inputs, self.input_storage)) list(zip(self.maker.expanded_inputs, input_storage))
): ):
if input.update is not None: if input.update is not None:
storage.data = outputs.pop() storage.data = outputs.pop()
...@@ -1020,17 +1012,12 @@ class Function: ...@@ -1020,17 +1012,12 @@ class Function:
outputs = outputs[: self.n_returned_outputs] outputs = outputs[: self.n_returned_outputs]
# Put default values back in the storage # Put default values back in the storage
restore_defaults() self._restore_defaults()
#
# NOTE: This logic needs to be replicated in
# scan.
# grep for 'PROFILE_CODE'
#
if profile:
dt_call = time.perf_counter() - t0 dt_call = time.perf_counter() - t0
pytensor.compile.profiling.total_fct_exec_time += dt_call pytensor.compile.profiling.total_fct_exec_time += dt_call
self.maker.mode.call_time += dt_call self.maker.mode.call_time += dt_call
if profile:
profile.fct_callcount += 1 profile.fct_callcount += 1
profile.fct_call_time += dt_call profile.fct_call_time += dt_call
if hasattr(self.vm, "update_profile"): if hasattr(self.vm, "update_profile"):
...@@ -1038,6 +1025,7 @@ class Function: ...@@ -1038,6 +1025,7 @@ class Function:
if profile.ignore_first_call: if profile.ignore_first_call:
profile.reset() profile.reset()
profile.ignore_first_call = False profile.ignore_first_call = False
if self.return_none: if self.return_none:
return None return None
elif self.unpack_single and len(outputs) == 1 and output_subset is None: elif self.unpack_single and len(outputs) == 1 and output_subset is None:
......
...@@ -128,9 +128,6 @@ class DisconnectedType(Type): ...@@ -128,9 +128,6 @@ class DisconnectedType(Type):
" a symbolic placeholder." " a symbolic placeholder."
) )
def may_share_memory(a, b):
return False
def value_eq(a, b, force_same_dtype=True): def value_eq(a, b, force_same_dtype=True):
raise AssertionError( raise AssertionError(
"If you're assigning to a DisconnectedType you're" "If you're assigning to a DisconnectedType you're"
......
...@@ -26,9 +26,6 @@ class NullType(Type): ...@@ -26,9 +26,6 @@ class NullType(Type):
def filter_variable(self, other, allow_convert=True): def filter_variable(self, other, allow_convert=True):
raise ValueError("No values may be assigned to a NullType") raise ValueError("No values may be assigned to a NullType")
def may_share_memory(a, b):
return False
def values_eq(self, a, b, force_same_dtype=True): def values_eq(self, a, b, force_same_dtype=True):
raise ValueError("NullType has no values to compare") raise ValueError("NullType has no values to compare")
......
...@@ -48,10 +48,7 @@ class Type(MetaObject, Generic[D]): ...@@ -48,10 +48,7 @@ class Type(MetaObject, Generic[D]):
unique element (i.e. it uses `self.__eq__`). unique element (i.e. it uses `self.__eq__`).
""" """
if self == otype: return self == otype
return True
return False
def is_super(self, otype: "Type") -> bool | None: def is_super(self, otype: "Type") -> bool | None:
"""Determine if `self` is a supertype of `otype`. """Determine if `self` is a supertype of `otype`.
......
...@@ -303,13 +303,6 @@ class ScalarType(CType, HasDataType, HasShape): ...@@ -303,13 +303,6 @@ class ScalarType(CType, HasDataType, HasShape):
dtype = self.dtype dtype = self.dtype
return type(self)(dtype) return type(self)(dtype)
@staticmethod
def may_share_memory(a, b):
# This class represent basic c type, represented in python
# with numpy.scalar. They are read only. So from python, they
# can never share memory.
return False
def filter(self, data, strict=False, allow_downcast=None): def filter(self, data, strict=False, allow_downcast=None):
py_type = self.dtype_specs()[0] py_type = self.dtype_specs()[0]
if strict and not isinstance(data, py_type): if strict and not isinstance(data, py_type):
......
...@@ -126,12 +126,6 @@ class NoneTypeT(Generic): ...@@ -126,12 +126,6 @@ class NoneTypeT(Generic):
else: else:
raise TypeError("Expected None!") raise TypeError("Expected None!")
@staticmethod
def may_share_memory(a, b):
# None never share memory between object, in the sense of DebugMode.
# Python None are singleton
return False
none_type_t = NoneTypeT() none_type_t = NoneTypeT()
......
...@@ -730,6 +730,8 @@ class TestFunction: ...@@ -730,6 +730,8 @@ class TestFunction:
s1 = shared(b) s1 = shared(b)
s2 = shared(b) s2 = shared(b)
x1 = vector() x1 = vector()
x2 = vector(shape=(3,))
x3 = vector(shape=(1,))
# Assert cases we should not check for aliased inputs # Assert cases we should not check for aliased inputs
for d in [ for d in [
...@@ -737,27 +739,29 @@ class TestFunction: ...@@ -737,27 +739,29 @@ class TestFunction:
dict(outputs=[s1 + 1, s2 + 3]), dict(outputs=[s1 + 1, s2 + 3]),
dict(outputs=[s1 + 1], updates=[(s2, s2 + 3)]), dict(outputs=[s1 + 1], updates=[(s2, s2 + 3)]),
dict(inputs=[x1], outputs=[x1 + 1], updates=[(s2, s2 + 3)]), dict(inputs=[x1], outputs=[x1 + 1], updates=[(s2, s2 + 3)]),
dict(
inputs=[In(x1, mutable=True)], outputs=[x1 + 1], updates=[(s2, s2 + 3)]
),
dict(
inputs=[In(x2, mutable=True), In(x3, mutable=True)],
outputs=[x2 + 2, x3 + 3],
),
]: ]:
if "inputs" not in d: if "inputs" not in d:
d["inputs"] = [] d["inputs"] = []
f = function(**d) f = function(**d)
assert not f._check_for_aliased_inputs, d assert not f._potential_aliased_input_groups, d
# Assert cases we should check for aliased inputs # Assert cases we should check for aliased inputs
for d in [ for d in [
dict( dict(
inputs=[In(x1, borrow=True)], inputs=[In(x1, mutable=True), In(x2, mutable=True)],
outputs=[x1 + 1], outputs=[x1 + 1, x2 + 2],
updates=[(s2, s2 + 3)],
),
dict(
inputs=[In(x1, borrow=True, mutable=True)],
outputs=[x1 + 1],
updates=[(s2, s2 + 3)], updates=[(s2, s2 + 3)],
), ),
dict( dict(
inputs=[In(x1, mutable=True)], inputs=[In(x1, mutable=True), In(x3, mutable=True)],
outputs=[x1 + 1], outputs=[x1 + 1, x3 + 3],
updates=[(s2, s2 + 3)], updates=[(s2, s2 + 3)],
), ),
]: ]:
...@@ -765,7 +769,7 @@ class TestFunction: ...@@ -765,7 +769,7 @@ class TestFunction:
d["inputs"] = [] d["inputs"] = []
f = function(**d) f = function(**d)
assert f._check_for_aliased_inputs, d assert f._potential_aliased_input_groups, d
def test_output_dictionary(self): def test_output_dictionary(self):
# Tests that function works when outputs is a dictionary # Tests that function works when outputs is a dictionary
...@@ -879,7 +883,7 @@ class TestPicklefunction: ...@@ -879,7 +883,7 @@ class TestPicklefunction:
f = function( f = function(
[ [
x, x,
In(a, value=1.0, name="a"), In(a, value=1.0, name="a", mutable=True),
In(s, value=0.0, update=s + a * x, mutable=True), In(s, value=0.0, update=s + a * x, mutable=True),
], ],
s + a * x, s + a * x,
...@@ -901,7 +905,12 @@ class TestPicklefunction: ...@@ -901,7 +905,12 @@ class TestPicklefunction:
assert x not in g.container assert x not in g.container
assert x not in g.value assert x not in g.value
assert len(f.defaults) == len(g.defaults) assert len(f.defaults) == len(g.defaults)
assert f._check_for_aliased_inputs is g._check_for_aliased_inputs # Shared variable is the first input
assert (
f._potential_aliased_input_groups
== g._potential_aliased_input_groups
== ((1, 2),)
)
assert f.name == g.name assert f.name == g.name
assert f.maker.fgraph.name == g.maker.fgraph.name assert f.maker.fgraph.name == g.maker.fgraph.name
# print(f"{f.defaults = }") # print(f"{f.defaults = }")
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论