提交 762c4c5b authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Ricardo Vieira

Remove redundant cloning when swapping nominal variables in OpFromGraph

上级 18cd693a
...@@ -19,7 +19,6 @@ from pytensor.graph.basic import ( ...@@ -19,7 +19,6 @@ from pytensor.graph.basic import (
clone_replace, clone_replace,
graph_inputs, graph_inputs,
io_connection_pattern, io_connection_pattern,
replace_nominals_with_dummies,
) )
from pytensor.graph.fg import FunctionGraph from pytensor.graph.fg import FunctionGraph
from pytensor.graph.null_type import NullType from pytensor.graph.null_type import NullType
...@@ -333,25 +332,33 @@ class OpFromGraph(Op, HasInnerGraph): ...@@ -333,25 +332,33 @@ class OpFromGraph(Op, HasInnerGraph):
if not (isinstance(inputs, list) and isinstance(outputs, list)): if not (isinstance(inputs, list) and isinstance(outputs, list)):
raise TypeError("Inputs and outputs must be lists") raise TypeError("Inputs and outputs must be lists")
for i in inputs + outputs: for out in outputs:
if not isinstance(i, Variable): if not isinstance(out, Variable):
raise TypeError( raise TypeError(
f"Inputs and outputs must be Variable instances; got {i}" f"Inputs and outputs must be Variable instances; got {out}"
) )
if i in inputs:
if isinstance(i, Constant): dummy_inputs = []
raise TypeError(f"Constants not allowed as inputs; {i}") for n, inp in enumerate(inputs):
if isinstance(i, SharedVariable): if (
raise TypeError(f"SharedVariables not allowed as inputs; {i}") not isinstance(inp, Variable)
or isinstance(inp, Constant)
or isinstance(inp, SharedVariable)
):
raise TypeError(
f"Inputs and outputs must be non-Constant/shared Variable instances; got {inp}"
)
dummy_inputs.append(inp.type())
if "updates" in kwargs or "givens" in kwargs: if "updates" in kwargs or "givens" in kwargs:
raise NotImplementedError("Updates and givens are not supported") raise NotImplementedError("Updates and givens are not supported")
self.is_inline = inline self.is_inline = inline
dummy_shared_inputs = []
self.shared_inputs = [] self.shared_inputs = []
inner_graph_inputs = graph_inputs(outputs, inputs) for var in graph_inputs(outputs, inputs):
for var in inner_graph_inputs:
if isinstance(var, SharedVariable): if isinstance(var, SharedVariable):
# To correctly support shared variables the inner-graph should # To correctly support shared variables the inner-graph should
# not see them; otherwise, there will be problems with # not see them; otherwise, there will be problems with
...@@ -359,26 +366,17 @@ class OpFromGraph(Op, HasInnerGraph): ...@@ -359,26 +366,17 @@ class OpFromGraph(Op, HasInnerGraph):
# That's why we collect the shared variables and replace them # That's why we collect the shared variables and replace them
# with dummies. # with dummies.
self.shared_inputs.append(var) self.shared_inputs.append(var)
dummy_shared_inputs.append(var.type())
elif var not in inputs and not isinstance(var, Constant): elif var not in inputs and not isinstance(var, Constant):
raise MissingInputError(f"OpFromGraph is missing an input: {var}") raise MissingInputError(f"OpFromGraph is missing an input: {var}")
inputs, outputs = replace_nominals_with_dummies(inputs, outputs) replacements = dict(
zip(inputs + self.shared_inputs, dummy_inputs + dummy_shared_inputs)
# The inputs should be `NominalVariable`s, so that graphs can be merged )
replacements = {}
for n, v in enumerate(inputs):
replacements[v] = NominalVariable(n, v.type)
shared_vars = [
NominalVariable(n, var.type)
for n, var in enumerate(self.shared_inputs, start=len(inputs) + 1)
]
replacements.update(dict(zip(self.shared_inputs, shared_vars)))
new = rebuild_collect_shared( new = rebuild_collect_shared(
cast(Sequence[Variable], outputs), cast(Sequence[Variable], outputs),
inputs=inputs + shared_vars, inputs=inputs + self.shared_inputs,
replace=replacements, replace=replacements,
copy_inputs_over=False, copy_inputs_over=False,
) )
...@@ -395,6 +393,21 @@ class OpFromGraph(Op, HasInnerGraph): ...@@ -395,6 +393,21 @@ class OpFromGraph(Op, HasInnerGraph):
assert not shared_inputs assert not shared_inputs
self.fgraph = FunctionGraph(local_inputs, local_outputs, clone=False) self.fgraph = FunctionGraph(local_inputs, local_outputs, clone=False)
# The inputs need to be `NominalVariable`s so that we can merge
# inner-graphs
nominal_local_inputs = tuple(
NominalVariable(n, var.type) for n, var in enumerate(local_inputs)
)
self.fgraph.replace_all(zip(local_inputs, nominal_local_inputs))
for i, inp in enumerate(self.fgraph.inputs):
nom_inp = nominal_local_inputs[i]
self.fgraph.inputs[i] = nom_inp
self.fgraph.clients.pop(inp, None)
self.fgraph.add_input(nom_inp)
self.kwargs = kwargs self.kwargs = kwargs
self.input_types = [inp.type for inp in inputs] self.input_types = [inp.type for inp in inputs]
self.output_types = [out.type for out in outputs] self.output_types = [out.type for out in outputs]
...@@ -417,6 +430,7 @@ class OpFromGraph(Op, HasInnerGraph): ...@@ -417,6 +430,7 @@ class OpFromGraph(Op, HasInnerGraph):
else: else:
self.set_lop_overrides("default") self.set_lop_overrides("default")
self._lop_type = "lop" self._lop_type = "lop"
self.set_rop_overrides(rop_overrides) self.set_rop_overrides(rop_overrides)
self._connection_pattern = connection_pattern self._connection_pattern = connection_pattern
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论