提交 a643b5b2 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Clean up FunctionMaker interface

This removes some unused steps involving a missing `SymbolicInputKits` and changes `FunctionMaker._check_unused_inputs` to a static method without the preceding underscore.
上级 c54ef356
...@@ -28,7 +28,7 @@ from aesara.compile.function.types import ( ...@@ -28,7 +28,7 @@ from aesara.compile.function.types import (
from aesara.compile.mode import Mode, register_mode from aesara.compile.mode import Mode, register_mode
from aesara.compile.ops import OutputGuard, _output_guard from aesara.compile.ops import OutputGuard, _output_guard
from aesara.configdefaults import config from aesara.configdefaults import config
from aesara.graph.basic import Variable, graph_inputs, io_toposort from aesara.graph.basic import Variable, io_toposort
from aesara.graph.destroyhandler import DestroyHandler from aesara.graph.destroyhandler import DestroyHandler
from aesara.graph.features import BadOptimization from aesara.graph.features import BadOptimization
from aesara.graph.fg import InconsistencyError from aesara.graph.fg import InconsistencyError
...@@ -2041,20 +2041,10 @@ class _Maker(FunctionMaker): # inheritance buys a few helper functions ...@@ -2041,20 +2041,10 @@ class _Maker(FunctionMaker): # inheritance buys a few helper functions
inputs = [self.wrap_in(i) for i in inputs] inputs = [self.wrap_in(i) for i in inputs]
outputs = [self.wrap_out(o) for o in outputs] outputs = [self.wrap_out(o) for o in outputs]
_inputs = list(
graph_inputs(
[o.variable for o in outputs]
+ [i.update for i in inputs if getattr(i, "update", False)]
)
)
# Check if some input variables are unused # Check if some input variables are unused
self._check_unused_inputs(inputs, outputs, on_unused_input) self.check_unused_inputs(inputs, outputs, on_unused_input)
# Make a list of (SymbolicInput|SymblicInputKits, indices, indices = [[input, None, [input]] for input in inputs]
# [SymbolicInput,...]), one tuple for each input. (See
# Function.indices for more details)
indices = [[input] + self.expand_in(input, _inputs) for input in inputs]
# make the fgraph # make the fgraph
for i in range(mode.stability_patience): for i in range(mode.stability_patience):
......
...@@ -653,7 +653,7 @@ class Function: ...@@ -653,7 +653,7 @@ class Function:
) )
# Re initialize Outs and swap update and variable in Ins # Re initialize Outs and swap update and variable in Ins
# By doing this, we can pass FunctionMaker._check_unused_inputs() # By doing this, we can pass FunctionMaker.check_unused_inputs()
outs = list(map(SymbolicOutput, fg_cpy.outputs[: len(maker.outputs)])) outs = list(map(SymbolicOutput, fg_cpy.outputs[: len(maker.outputs)]))
for out_ori, out_cpy in zip(maker.outputs, outs): for out_ori, out_cpy in zip(maker.outputs, outs):
out_cpy.borrow = out_ori.borrow out_cpy.borrow = out_ori.borrow
...@@ -1335,20 +1335,6 @@ class FunctionMaker: ...@@ -1335,20 +1335,6 @@ class FunctionMaker:
"instance" "instance"
) )
@staticmethod
def expand_in(sinput, rinputs):
# For SymbolicInputKits, this extracts a list of SymbolicInput
# instances and corresponding indices such that these
# SymbolicInputs are representative of some of the Variable
# instances in inputs. For SymbolicInput, this returns None
# as the list of indices and a list with just the
# SymbolicInput.
# if isinstance(sinput, SymbolicInputKit):
# return sinput.complete(rinputs)
# elif isinstance(sinput, SymbolicInput):
if isinstance(sinput, SymbolicInput):
return [None, [sinput]]
@staticmethod @staticmethod
def wrap_out(output): def wrap_out(output):
if isinstance(output, SymbolicOutput): if isinstance(output, SymbolicOutput):
...@@ -1358,6 +1344,59 @@ class FunctionMaker: ...@@ -1358,6 +1344,59 @@ class FunctionMaker:
else: else:
raise TypeError(f"Unknown output type: {type(output)} ({output})") raise TypeError(f"Unknown output type: {type(output)} ({output})")
@staticmethod
def check_unused_inputs(inputs, outputs, on_unused_input):
if on_unused_input is None:
on_unused_input = config.on_unused_input
if on_unused_input == "ignore":
return
# There should be two categories of variables in inputs:
# - variables that have to be provided (used_inputs)
# - shared variables that will be updated
used_inputs = list(
ancestors(
(
[o.variable for o in outputs]
+ [i.update for i in inputs if getattr(i, "update", False)]
),
blockers=[i.variable for i in inputs],
)
)
msg = (
"aesara.function was asked to create a function computing "
"outputs given certain inputs, but the provided input "
"variable at index %i is not part of the computational graph "
"needed to compute the outputs: %s.\n%s"
)
warn_msg = (
"To make this warning into an error, you can pass the "
"parameter on_unused_input='raise' to aesara.function. "
"To disable it completely, use on_unused_input='ignore'."
)
err_msg = (
"To make this error into a warning, you can pass the "
"parameter on_unused_input='warn' to aesara.function. "
"To disable it completely, use on_unused_input='ignore'."
)
for i in inputs:
if (i.variable not in used_inputs) and (i.update is None):
if on_unused_input == "warn":
warnings.warn(
msg % (inputs.index(i), i.variable, warn_msg), stacklevel=6
)
elif on_unused_input == "raise":
raise UnusedInputError(msg % (inputs.index(i), i.variable, err_msg))
else:
raise ValueError(
"Invalid value for keyword on_unused_input of aesara.function: "
f"'{on_unused_input}'.\n"
"Valid values are 'raise', 'warn', and 'ignore'."
)
def __init__( def __init__(
self, self,
inputs, inputs,
...@@ -1407,20 +1446,11 @@ class FunctionMaker: ...@@ -1407,20 +1446,11 @@ class FunctionMaker:
# Wrap them in In or Out instances if needed. # Wrap them in In or Out instances if needed.
inputs = [self.wrap_in(i) for i in inputs] inputs = [self.wrap_in(i) for i in inputs]
outputs = [self.wrap_out(o) for o in outputs] outputs = [self.wrap_out(o) for o in outputs]
_inputs = list(
graph_inputs(
[o.variable for o in outputs]
+ [i.update for i in inputs if getattr(i, "update", False)]
)
)
# Check if some input variables are unused # Check if some input variables are unused
self._check_unused_inputs(inputs, outputs, on_unused_input) self.check_unused_inputs(inputs, outputs, on_unused_input)
# Make a list of (SymbolicInput|SymblicInputKits, indices, indices = [[input, None, [input]] for input in inputs]
# [SymbolicInput,...]), one tuple for each input. (See
# Function.indices for more details)
indices = [[input] + self.expand_in(input, _inputs) for input in inputs]
if fgraph is None: if fgraph is None:
need_opt = True need_opt = True
...@@ -1537,58 +1567,6 @@ class FunctionMaker: ...@@ -1537,58 +1567,6 @@ class FunctionMaker:
for i in self.inputs for i in self.inputs
] ]
def _check_unused_inputs(self, inputs, outputs, on_unused_input):
if on_unused_input is None:
on_unused_input = config.on_unused_input
if on_unused_input == "ignore":
return
# There should be two categories of variables in inputs:
# - variables that have to be provided (used_inputs)
# - shared variables that will be updated
used_inputs = list(
ancestors(
(
[o.variable for o in outputs]
+ [i.update for i in inputs if getattr(i, "update", False)]
),
blockers=[i.variable for i in inputs],
)
)
msg = (
"aesara.function was asked to create a function computing "
"outputs given certain inputs, but the provided input "
"variable at index %i is not part of the computational graph "
"needed to compute the outputs: %s.\n%s"
)
warn_msg = (
"To make this warning into an error, you can pass the "
"parameter on_unused_input='raise' to aesara.function. "
"To disable it completely, use on_unused_input='ignore'."
)
err_msg = (
"To make this error into a warning, you can pass the "
"parameter on_unused_input='warn' to aesara.function. "
"To disable it completely, use on_unused_input='ignore'."
)
for i in inputs:
if (i.variable not in used_inputs) and (i.update is None):
if on_unused_input == "warn":
warnings.warn(
msg % (inputs.index(i), i.variable, warn_msg), stacklevel=6
)
elif on_unused_input == "raise":
raise UnusedInputError(msg % (inputs.index(i), i.variable, err_msg))
else:
raise ValueError(
"Invalid value for keyword on_unused_input of aesara.function: "
f"'{on_unused_input}'.\n"
"Valid values are 'raise', 'warn', and 'ignore'."
)
def create(self, input_storage=None, trustme=False, storage_map=None): def create(self, input_storage=None, trustme=False, storage_map=None):
""" """
Create a function. Create a function.
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论