Commit e15c09df, authored by Brandon T. Willard, committed by Ricardo Vieira

Generalize the inner-FunctionGraph construction process

Parent 762c4c5b
pytensor/compile/builders.py

@@ -2,7 +2,7 @@
 from collections import OrderedDict
 from copy import copy
 from functools import partial
-from typing import List, Optional, Sequence, cast
+from typing import Dict, List, Optional, Sequence, Tuple, cast

 import pytensor.tensor as at
 from pytensor import function
@@ -81,6 +81,81 @@ def infer_shape(outs, inputs, input_shapes):
     return ret


+def construct_nominal_fgraph(
+    inputs: Sequence[Variable], outputs: Sequence[Variable]
+) -> Tuple[
+    FunctionGraph,
+    Sequence[Variable],
+    Dict[Variable, Variable],
+    Dict[Variable, Variable],
+]:
+    """Construct an inner-`FunctionGraph` with ordered nominal inputs."""
+    dummy_inputs = []
+    for n, inp in enumerate(inputs):
+        if (
+            not isinstance(inp, Variable)
+            or isinstance(inp, Constant)
+            or isinstance(inp, SharedVariable)
+        ):
+            raise TypeError(
+                f"Inputs and outputs must be non-Constant/shared Variable instances; got {inp}"
+            )
+
+        dummy_inputs.append(inp.type())
+
+    dummy_shared_inputs = []
+    shared_inputs = []
+    for var in graph_inputs(outputs, inputs):
+        if isinstance(var, SharedVariable):
+            # To correctly support shared variables the inner-graph should
+            # not see them; otherwise, there will be problems with
+            # gradients.
+            # That's why we collect the shared variables and replace them
+            # with dummies.
+            shared_inputs.append(var)
+            dummy_shared_inputs.append(var.type())
+        elif var not in inputs and not isinstance(var, Constant):
+            raise MissingInputError(f"OpFromGraph is missing an input: {var}")
+
+    replacements = dict(
+        zip(inputs + shared_inputs, dummy_inputs + dummy_shared_inputs)
+    )
+
+    new = rebuild_collect_shared(
+        cast(Sequence[Variable], outputs),
+        inputs=inputs + shared_inputs,
+        replace=replacements,
+        copy_inputs_over=False,
+    )
+    (
+        local_inputs,
+        local_outputs,
+        (clone_d, update_d, update_expr, new_shared_inputs),
+    ) = new
+
+    assert len(local_inputs) == len(inputs) + len(shared_inputs)
+    assert len(local_outputs) == len(outputs)
+    assert not update_d
+    assert not update_expr
+    assert not new_shared_inputs
+
+    fgraph = FunctionGraph(local_inputs, local_outputs, clone=False)
+
+    # The inputs need to be `NominalVariable`s so that we can merge
+    # inner-graphs
+    nominal_local_inputs = tuple(
+        NominalVariable(n, var.type) for n, var in enumerate(local_inputs)
+    )
+
+    fgraph.replace_all(zip(local_inputs, nominal_local_inputs))
+
+    for i, inp in enumerate(fgraph.inputs):
+        nom_inp = nominal_local_inputs[i]
+        fgraph.inputs[i] = nom_inp
+        fgraph.clients.pop(inp, None)
+        fgraph.add_input(nom_inp)
+
+    return fgraph, shared_inputs, update_d, update_expr
+
+
 class OpFromGraph(Op, HasInnerGraph):
     r"""
     This creates an `Op` from inputs and outputs lists of variables.
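
Note: the helper added above is self-contained, so it can be exercised on its own. A minimal sketch of how it might be called, based only on the signature introduced in this commit (the variable names are illustrative):

```python
import pytensor.tensor as at
from pytensor.compile.builders import construct_nominal_fgraph

x = at.vector("x")
y = at.vector("y")

# Build the inner graph for `x * y + 1`, with `x` and `y` as its inputs.
fgraph, shared_inputs, update_d, update_expr = construct_nominal_fgraph(
    [x, y], [x * y + 1]
)

# The inner graph no longer references `x` and `y` directly: its inputs
# are position-keyed `NominalVariable`s, and this graph uses no shared
# variables, so nothing was collected for replacement.
assert len(fgraph.inputs) == 2
assert shared_inputs == []
```
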
@@ -338,76 +413,15 @@ class OpFromGraph(Op, HasInnerGraph):
                     f"Inputs and outputs must be Variable instances; got {out}"
                 )

-        dummy_inputs = []
-        for n, inp in enumerate(inputs):
-            if (
-                not isinstance(inp, Variable)
-                or isinstance(inp, Constant)
-                or isinstance(inp, SharedVariable)
-            ):
-                raise TypeError(
-                    f"Inputs and outputs must be non-Constant/shared Variable instances; got {inp}"
-                )
-            dummy_inputs.append(inp.type())
-
if "updates" in kwargs or "givens" in kwargs: if "updates" in kwargs or "givens" in kwargs:
raise NotImplementedError("Updates and givens are not supported") raise NotImplementedError("Updates and givens are not supported")
self.is_inline = inline self.is_inline = inline
-        dummy_shared_inputs = []
-        self.shared_inputs = []
-        for var in graph_inputs(outputs, inputs):
-            if isinstance(var, SharedVariable):
-                # To correctly support shared variables the inner-graph should
-                # not see them; otherwise, there will be problems with
-                # gradients.
-                # That's why we collect the shared variables and replace them
-                # with dummies.
-                self.shared_inputs.append(var)
-                dummy_shared_inputs.append(var.type())
-            elif var not in inputs and not isinstance(var, Constant):
-                raise MissingInputError(f"OpFromGraph is missing an input: {var}")
-
-        replacements = dict(
-            zip(inputs + self.shared_inputs, dummy_inputs + dummy_shared_inputs)
-        )
-
-        new = rebuild_collect_shared(
-            cast(Sequence[Variable], outputs),
-            inputs=inputs + self.shared_inputs,
-            replace=replacements,
-            copy_inputs_over=False,
-        )
-        (
-            local_inputs,
-            local_outputs,
-            (clone_d, update_d, update_expr, shared_inputs),
-        ) = new
-
-        assert len(local_inputs) == len(inputs) + len(self.shared_inputs)
-        assert len(local_outputs) == len(outputs)
-        assert not update_d
-        assert not update_expr
-        assert not shared_inputs
-
-        self.fgraph = FunctionGraph(local_inputs, local_outputs, clone=False)
-
-        # The inputs need to be `NominalVariable`s so that we can merge
-        # inner-graphs
-        nominal_local_inputs = tuple(
-            NominalVariable(n, var.type) for n, var in enumerate(local_inputs)
-        )
-
-        self.fgraph.replace_all(zip(local_inputs, nominal_local_inputs))
-
-        for i, inp in enumerate(self.fgraph.inputs):
-            nom_inp = nominal_local_inputs[i]
-            self.fgraph.inputs[i] = nom_inp
-            self.fgraph.clients.pop(inp, None)
-            self.fgraph.add_input(nom_inp)
+        self.fgraph, self.shared_inputs, _, _ = construct_nominal_fgraph(
+            inputs, outputs
+        )

         self.kwargs = kwargs
         self.input_types = [inp.type for inp in inputs]
         self.output_types = [out.type for out in outputs]
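
The comment in the new helper says the inputs must be `NominalVariable`s "so that we can merge inner-graphs". That works because nominal variables are keyed purely by their position and type, not by identity of the original inputs. A small sketch of the property this relies on, assuming the interning behavior of `pytensor.graph.basic.NominalVariable`:

```python
import pytensor.tensor as at
from pytensor.graph.basic import NominalVariable

t = at.vector().type  # a TensorType instance

# Constructing the "same" nominal variable twice yields the identical
# object, so two inner graphs whose inputs occupy positions 0..n-1 with
# matching types end up sharing input variables and can be compared or
# merged structurally.
assert NominalVariable(0, t) is NominalVariable(0, t)
assert NominalVariable(0, t) is not NominalVariable(1, t)
```
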
pytensor/scan/op.py

@@ -55,8 +55,7 @@
 import numpy as np
 import pytensor
 from pytensor import tensor as at
-from pytensor.compile import SharedVariable
-from pytensor.compile.builders import infer_shape
+from pytensor.compile.builders import construct_nominal_fgraph, infer_shape
 from pytensor.compile.function.pfunc import pfunc
 from pytensor.compile.io import In, Out
 from pytensor.compile.mode import Mode, get_default_mode, get_mode
@@ -65,17 +64,13 @@
 from pytensor.configdefaults import config
 from pytensor.gradient import DisconnectedType, NullType, Rop, grad, grad_undefined
 from pytensor.graph.basic import (
     Apply,
-    Constant,
-    NominalVariable,
     Variable,
     clone_replace,
     equal_computations,
     graph_inputs,
     io_connection_pattern,
-    replace_nominals_with_dummies,
 )
 from pytensor.graph.features import NoOutputFromInplace
-from pytensor.graph.fg import FunctionGraph
 from pytensor.graph.op import HasInnerGraph, Op
 from pytensor.graph.utils import InconsistencyError, MissingInputError
 from pytensor.link.c.basic import CLinker
@@ -755,22 +750,12 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
         If ``True``, all the shared variables used in the inner-graph must be provided.

         """
-        inputs, outputs = replace_nominals_with_dummies(inputs, outputs)
-
-        input_replacements = []
-        for n, v in enumerate(inputs):
-            if not isinstance(v, (SharedVariable, Constant)):
-                input_replacements.append((v, NominalVariable(n, v.type)))
-
-            assert not isinstance(v, NominalVariable)
-
-        outputs = clone_replace(outputs, replace=input_replacements)
-
-        if input_replacements:
-            _, inputs_ = zip(*input_replacements)
-            inputs = list(inputs_)
-        else:
-            inputs = []
+        self.fgraph, shared_inputs, _, _ = construct_nominal_fgraph(inputs, outputs)
+
+        # The shared variables should have been removed, so, if there are
+        # any, it's because the user didn't specify an input.
+        if shared_inputs:
+            raise MissingInputError(f"Scan is missing inputs: {shared_inputs}")

         self.info = info
         self.truncate_gradient = truncate_gradient
@@ -782,7 +767,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
         # Clone mode_instance, altering "allow_gc" for the linker,
         # and adding a message if we profile
         if self.name:
-            message = self.name + " sub profile"
+            message = f"{self.name} sub profile"
         else:
             message = "Scan sub profile"
@@ -805,7 +790,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
         while idx < info.n_mit_mot_outs:
             # Note that for mit_mot there are several output slices per
             # output sequence
-            o = outputs[idx]
+            o = self.fgraph.outputs[idx]
             self.output_types.append(
                 # TODO: What can we actually say about the shape of this
                 # added dimension?
@@ -818,7 +803,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
         # mit_sot / sit_sot / nit_sot
         end = idx + info.n_mit_sot + info.n_sit_sot + info.n_nit_sot

-        for o in outputs[idx:end]:
+        for o in self.fgraph.outputs[idx:end]:
             self.output_types.append(
                 # TODO: What can we actually say about the shape of this
                 # added dimension?
@@ -826,7 +811,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
             )

         # shared outputs + possibly the ending condition
-        for o in outputs[end:]:
+        for o in self.fgraph.outputs[end:]:
             self.output_types.append(o.type)

         if info.as_while:
@@ -862,8 +847,6 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
         self.n_outer_inputs = info.n_outer_inputs
         self.n_outer_outputs = info.n_outer_outputs

-        self.fgraph = FunctionGraph(inputs, outputs, clone=False)
-
         _ = self.prepare_fgraph(self.fgraph)

         if any(node.op.destroy_map for node in self.fgraph.apply_nodes):
@@ -871,10 +854,6 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
                 "Inner-graphs must not contain in-place operations."
             )

-        # Do the missing inputs check here to have the error early.
-        for var in graph_inputs(self.inner_outputs, self.inner_inputs):
-            if var not in self.inner_inputs and not isinstance(var, Constant):
-                raise MissingInputError(f"ScanOp is missing an input: {repr(var)}")
-
         self._cmodule_key = CLinker().cmodule_key_variables(
             self.inner_inputs, self.inner_outputs, []
         )
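
With the input collection delegated to `construct_nominal_fgraph`, the explicit early missing-input check deleted above becomes redundant: any outer variable that is neither listed in `inputs` nor a constant is rejected inside the helper itself. A sketch of that error path, using illustrative names:

```python
import pytensor.tensor as at
from pytensor.compile.builders import construct_nominal_fgraph
from pytensor.graph.utils import MissingInputError

x = at.vector("x")
z = at.vector("z")  # used by the output graph but not declared as an input

try:
    construct_nominal_fgraph([x], [x + z])
except MissingInputError as err:
    print(err)  # -> OpFromGraph is missing an input: z
```
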
tests/scan/test_basic.py

@@ -586,10 +586,6 @@ class TestScan:
         assert np.allclose(pytensor_values, v_out)

     def test_oinp_iinp_iout_oout_mappings(self):
-        """
-        Test the mapping produces by
-        ScanOp.get_oinp_iinp_iout_oout_mappings()
-        """
         rng = RandomStream(123)