提交 e15c09df authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Ricardo Vieira

Generalize the inner-FunctionGraph construction process

上级 762c4c5b
......@@ -2,7 +2,7 @@
from collections import OrderedDict
from copy import copy
from functools import partial
from typing import List, Optional, Sequence, cast
from typing import Dict, List, Optional, Sequence, Tuple, cast
import pytensor.tensor as at
from pytensor import function
......@@ -81,6 +81,81 @@ def infer_shape(outs, inputs, input_shapes):
return ret
def construct_nominal_fgraph(
    inputs: Sequence[Variable], outputs: Sequence[Variable]
) -> Tuple[
    FunctionGraph,
    Sequence[Variable],
    Dict[Variable, Variable],
    Dict[Variable, Variable],
]:
    """Construct an inner-`FunctionGraph` with ordered nominal inputs.

    The ``outputs`` graph is cloned so that the resulting ``FunctionGraph``
    is independent of the caller's graph, and every explicit input is
    replaced by a position-indexed `NominalVariable` so that structurally
    equal inner-graphs can be merged.

    Parameters
    ----------
    inputs
        Explicit inputs of the inner-graph.  Each must be a `Variable`
        that is neither a `Constant` nor a `SharedVariable`; otherwise a
        `TypeError` is raised.
    outputs
        Outputs of the inner-graph.  Any `SharedVariable` reachable from
        ``outputs`` is implicitly collected; any other non-constant
        ancestor missing from ``inputs`` raises a `MissingInputError`.

    Returns
    -------
    A four-tuple of the new `FunctionGraph`, the list of collected
    `SharedVariable`\\s, and the ``update_d``/``update_expr`` mappings from
    `rebuild_collect_shared` (both asserted empty here, and returned for
    interface completeness).
    """
    # Fresh, unowned clones of the explicit inputs; used below as the
    # replacement targets so the rebuild does not touch the caller's graph.
    dummy_inputs = []
    for n, inp in enumerate(inputs):
        if (
            not isinstance(inp, Variable)
            or isinstance(inp, Constant)
            or isinstance(inp, SharedVariable)
        ):
            raise TypeError(
                f"Inputs and outputs must be non-Constant/shared Variable instances; got {inp}"
            )

        dummy_inputs.append(inp.type())

    dummy_shared_inputs = []
    shared_inputs = []
    for var in graph_inputs(outputs, inputs):
        if isinstance(var, SharedVariable):
            # To correctly support shared variables the inner-graph should
            # not see them; otherwise, there will be problems with
            # gradients.
            # That's why we collect the shared variables and replace them
            # with dummies.
            shared_inputs.append(var)
            dummy_shared_inputs.append(var.type())
        elif var not in inputs and not isinstance(var, Constant):
            raise MissingInputError(f"OpFromGraph is missing an input: {var}")

    # Map every original input (explicit and shared) to its dummy clone.
    replacements = dict(zip(inputs + shared_inputs, dummy_inputs + dummy_shared_inputs))

    new = rebuild_collect_shared(
        cast(Sequence[Variable], outputs),
        inputs=inputs + shared_inputs,
        replace=replacements,
        copy_inputs_over=False,
    )
    (
        local_inputs,
        local_outputs,
        (clone_d, update_d, update_expr, new_shared_inputs),
    ) = new

    assert len(local_inputs) == len(inputs) + len(shared_inputs)
    assert len(local_outputs) == len(outputs)
    # No updates/givens are supported here, and all shared variables were
    # already replaced by dummies above, so nothing new may be collected.
    assert not update_d
    assert not update_expr
    assert not new_shared_inputs

    fgraph = FunctionGraph(local_inputs, local_outputs, clone=False)

    # The inputs need to be `NominalVariable`s so that we can merge
    # inner-graphs
    nominal_local_inputs = tuple(
        NominalVariable(n, var.type) for n, var in enumerate(local_inputs)
    )

    fgraph.replace_all(zip(local_inputs, nominal_local_inputs))

    # `replace_all` rewires clients but the nominal variables must also
    # take the old inputs' positions in `fgraph.inputs`; swap them in
    # place and drop the stale client entries so the old inputs are fully
    # forgotten by the graph.
    for i, inp in enumerate(fgraph.inputs):
        nom_inp = nominal_local_inputs[i]
        fgraph.inputs[i] = nom_inp
        fgraph.clients.pop(inp, None)
        fgraph.add_input(nom_inp)

    return fgraph, shared_inputs, update_d, update_expr
class OpFromGraph(Op, HasInnerGraph):
r"""
This creates an `Op` from inputs and outputs lists of variables.
......@@ -338,76 +413,15 @@ class OpFromGraph(Op, HasInnerGraph):
f"Inputs and outputs must be Variable instances; got {out}"
)
dummy_inputs = []
for n, inp in enumerate(inputs):
if (
not isinstance(inp, Variable)
or isinstance(inp, Constant)
or isinstance(inp, SharedVariable)
):
raise TypeError(
f"Inputs and outputs must be non-Constant/shared Variable instances; got {inp}"
)
dummy_inputs.append(inp.type())
if "updates" in kwargs or "givens" in kwargs:
raise NotImplementedError("Updates and givens are not supported")
self.is_inline = inline
dummy_shared_inputs = []
self.shared_inputs = []
for var in graph_inputs(outputs, inputs):
if isinstance(var, SharedVariable):
# To correctly support shared variables the inner-graph should
# not see them; otherwise, there will be problems with
# gradients.
# That's why we collect the shared variables and replace them
# with dummies.
self.shared_inputs.append(var)
dummy_shared_inputs.append(var.type())
elif var not in inputs and not isinstance(var, Constant):
raise MissingInputError(f"OpFromGraph is missing an input: {var}")
replacements = dict(
zip(inputs + self.shared_inputs, dummy_inputs + dummy_shared_inputs)
self.fgraph, self.shared_inputs, _, _ = construct_nominal_fgraph(
inputs, outputs
)
new = rebuild_collect_shared(
cast(Sequence[Variable], outputs),
inputs=inputs + self.shared_inputs,
replace=replacements,
copy_inputs_over=False,
)
(
local_inputs,
local_outputs,
(clone_d, update_d, update_expr, shared_inputs),
) = new
assert len(local_inputs) == len(inputs) + len(self.shared_inputs)
assert len(local_outputs) == len(outputs)
assert not update_d
assert not update_expr
assert not shared_inputs
self.fgraph = FunctionGraph(local_inputs, local_outputs, clone=False)
# The inputs need to be `NominalVariable`s so that we can merge
# inner-graphs
nominal_local_inputs = tuple(
NominalVariable(n, var.type) for n, var in enumerate(local_inputs)
)
self.fgraph.replace_all(zip(local_inputs, nominal_local_inputs))
for i, inp in enumerate(self.fgraph.inputs):
nom_inp = nominal_local_inputs[i]
self.fgraph.inputs[i] = nom_inp
self.fgraph.clients.pop(inp, None)
self.fgraph.add_input(nom_inp)
self.kwargs = kwargs
self.input_types = [inp.type for inp in inputs]
self.output_types = [out.type for out in outputs]
......
......@@ -55,8 +55,7 @@ import numpy as np
import pytensor
from pytensor import tensor as at
from pytensor.compile import SharedVariable
from pytensor.compile.builders import infer_shape
from pytensor.compile.builders import construct_nominal_fgraph, infer_shape
from pytensor.compile.function.pfunc import pfunc
from pytensor.compile.io import In, Out
from pytensor.compile.mode import Mode, get_default_mode, get_mode
......@@ -65,17 +64,13 @@ from pytensor.configdefaults import config
from pytensor.gradient import DisconnectedType, NullType, Rop, grad, grad_undefined
from pytensor.graph.basic import (
Apply,
Constant,
NominalVariable,
Variable,
clone_replace,
equal_computations,
graph_inputs,
io_connection_pattern,
replace_nominals_with_dummies,
)
from pytensor.graph.features import NoOutputFromInplace
from pytensor.graph.fg import FunctionGraph
from pytensor.graph.op import HasInnerGraph, Op
from pytensor.graph.utils import InconsistencyError, MissingInputError
from pytensor.link.c.basic import CLinker
......@@ -755,22 +750,12 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
If ``True``, all the shared variables used in the inner-graph must be provided.
"""
inputs, outputs = replace_nominals_with_dummies(inputs, outputs)
self.fgraph, shared_inputs, _, _ = construct_nominal_fgraph(inputs, outputs)
input_replacements = []
for n, v in enumerate(inputs):
if not isinstance(v, (SharedVariable, Constant)):
input_replacements.append((v, NominalVariable(n, v.type)))
assert not isinstance(v, NominalVariable)
outputs = clone_replace(outputs, replace=input_replacements)
if input_replacements:
_, inputs_ = zip(*input_replacements)
inputs = list(inputs_)
else:
inputs = []
# The shared variables should have been removed, so, if there are
# any, it's because the user didn't specify an input.
if shared_inputs:
raise MissingInputError(f"Scan is missing inputs: {shared_inputs}")
self.info = info
self.truncate_gradient = truncate_gradient
......@@ -782,7 +767,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
# Clone mode_instance, altering "allow_gc" for the linker,
# and adding a message if we profile
if self.name:
message = self.name + " sub profile"
message = f"{self.name} sub profile"
else:
message = "Scan sub profile"
......@@ -805,7 +790,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
while idx < info.n_mit_mot_outs:
# Not that for mit_mot there are several output slices per
# output sequence
o = outputs[idx]
o = self.fgraph.outputs[idx]
self.output_types.append(
# TODO: What can we actually say about the shape of this
# added dimension?
......@@ -818,7 +803,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
# mit_sot / sit_sot / nit_sot
end = idx + info.n_mit_sot + info.n_sit_sot + info.n_nit_sot
for o in outputs[idx:end]:
for o in self.fgraph.outputs[idx:end]:
self.output_types.append(
# TODO: What can we actually say about the shape of this
# added dimension?
......@@ -826,7 +811,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
)
# shared outputs + possibly the ending condition
for o in outputs[end:]:
for o in self.fgraph.outputs[end:]:
self.output_types.append(o.type)
if info.as_while:
......@@ -862,8 +847,6 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
self.n_outer_inputs = info.n_outer_inputs
self.n_outer_outputs = info.n_outer_outputs
self.fgraph = FunctionGraph(inputs, outputs, clone=False)
_ = self.prepare_fgraph(self.fgraph)
if any(node.op.destroy_map for node in self.fgraph.apply_nodes):
......@@ -871,10 +854,6 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
"Inner-graphs must not contain in-place operations."
)
# Do the missing inputs check here to have the error early.
for var in graph_inputs(self.inner_outputs, self.inner_inputs):
if var not in self.inner_inputs and not isinstance(var, Constant):
raise MissingInputError(f"ScanOp is missing an input: {repr(var)}")
self._cmodule_key = CLinker().cmodule_key_variables(
self.inner_inputs, self.inner_outputs, []
)
......
......@@ -586,10 +586,6 @@ class TestScan:
assert np.allclose(pytensor_values, v_out)
def test_oinp_iinp_iout_oout_mappings(self):
"""
Test the mapping produces by
ScanOp.get_oinp_iinp_iout_oout_mappings()
"""
rng = RandomStream(123)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论