提交 762c4c5b authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Ricardo Vieira

Remove redundant cloning when swapping nominal variables in OpFromGraph

上级 18cd693a
......@@ -19,7 +19,6 @@ from pytensor.graph.basic import (
clone_replace,
graph_inputs,
io_connection_pattern,
replace_nominals_with_dummies,
)
from pytensor.graph.fg import FunctionGraph
from pytensor.graph.null_type import NullType
......@@ -333,25 +332,33 @@ class OpFromGraph(Op, HasInnerGraph):
if not (isinstance(inputs, list) and isinstance(outputs, list)):
raise TypeError("Inputs and outputs must be lists")
for i in inputs + outputs:
if not isinstance(i, Variable):
for out in outputs:
if not isinstance(out, Variable):
raise TypeError(
f"Inputs and outputs must be Variable instances; got {i}"
f"Inputs and outputs must be Variable instances; got {out}"
)
if i in inputs:
if isinstance(i, Constant):
raise TypeError(f"Constants not allowed as inputs; {i}")
if isinstance(i, SharedVariable):
raise TypeError(f"SharedVariables not allowed as inputs; {i}")
dummy_inputs = []
for n, inp in enumerate(inputs):
if (
not isinstance(inp, Variable)
or isinstance(inp, Constant)
or isinstance(inp, SharedVariable)
):
raise TypeError(
f"Inputs and outputs must be non-Constant/shared Variable instances; got {inp}"
)
dummy_inputs.append(inp.type())
if "updates" in kwargs or "givens" in kwargs:
raise NotImplementedError("Updates and givens are not supported")
self.is_inline = inline
dummy_shared_inputs = []
self.shared_inputs = []
inner_graph_inputs = graph_inputs(outputs, inputs)
for var in inner_graph_inputs:
for var in graph_inputs(outputs, inputs):
if isinstance(var, SharedVariable):
# To correctly support shared variables the inner-graph should
# not see them; otherwise, there will be problems with
......@@ -359,26 +366,17 @@ class OpFromGraph(Op, HasInnerGraph):
# That's why we collect the shared variables and replace them
# with dummies.
self.shared_inputs.append(var)
dummy_shared_inputs.append(var.type())
elif var not in inputs and not isinstance(var, Constant):
raise MissingInputError(f"OpFromGraph is missing an input: {var}")
inputs, outputs = replace_nominals_with_dummies(inputs, outputs)
# The inputs should be `NominalVariable`s, so that graphs can be merged
replacements = {}
for n, v in enumerate(inputs):
replacements[v] = NominalVariable(n, v.type)
shared_vars = [
NominalVariable(n, var.type)
for n, var in enumerate(self.shared_inputs, start=len(inputs) + 1)
]
replacements.update(dict(zip(self.shared_inputs, shared_vars)))
replacements = dict(
zip(inputs + self.shared_inputs, dummy_inputs + dummy_shared_inputs)
)
new = rebuild_collect_shared(
cast(Sequence[Variable], outputs),
inputs=inputs + shared_vars,
inputs=inputs + self.shared_inputs,
replace=replacements,
copy_inputs_over=False,
)
......@@ -395,6 +393,21 @@ class OpFromGraph(Op, HasInnerGraph):
assert not shared_inputs
self.fgraph = FunctionGraph(local_inputs, local_outputs, clone=False)
# The inputs need to be `NominalVariable`s so that we can merge
# inner-graphs
nominal_local_inputs = tuple(
NominalVariable(n, var.type) for n, var in enumerate(local_inputs)
)
self.fgraph.replace_all(zip(local_inputs, nominal_local_inputs))
for i, inp in enumerate(self.fgraph.inputs):
nom_inp = nominal_local_inputs[i]
self.fgraph.inputs[i] = nom_inp
self.fgraph.clients.pop(inp, None)
self.fgraph.add_input(nom_inp)
self.kwargs = kwargs
self.input_types = [inp.type for inp in inputs]
self.output_types = [out.type for out in outputs]
......@@ -417,6 +430,7 @@ class OpFromGraph(Op, HasInnerGraph):
else:
self.set_lop_overrides("default")
self._lop_type = "lop"
self.set_rop_overrides(rop_overrides)
self._connection_pattern = connection_pattern
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论