提交 97317a50 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Allow defining an OpFromGraph from constant and shared inputs.

Also adds a strict flag
上级 339aab4d
......@@ -92,38 +92,29 @@ def construct_nominal_fgraph(
dict[Variable, Variable],
]:
"""Construct an inner-`FunctionGraph` with ordered nominal inputs."""
dummy_inputs = []
for n, inp in enumerate(inputs):
if (
not isinstance(inp, Variable)
or isinstance(inp, Constant)
or isinstance(inp, SharedVariable)
):
raise TypeError(
f"Inputs and outputs must be non-Constant/shared Variable instances; got {inp}"
)
implicit_shared_inputs = []
dummy_inputs.append(inp.type())
dummy_shared_inputs = []
shared_inputs = []
dummy_inputs = [inp.type() for inp in inputs]
dummy_implicit_shared_inputs = []
for var in graph_inputs(outputs, inputs):
if var in inputs:
continue
if isinstance(var, SharedVariable):
# To correctly support shared variables the inner-graph should
# not see them; otherwise, there will be problems with
# gradients.
# That's why we collect the shared variables and replace them
# with dummies.
shared_inputs.append(var)
dummy_shared_inputs.append(var.type())
elif var not in inputs and not isinstance(var, Constant):
raise MissingInputError(f"OpFromGraph is missing an input: {var}")
replacements = dict(zip(inputs + shared_inputs, dummy_inputs + dummy_shared_inputs))
# We allow shared inputs to be added automatically to the graph
implicit_shared_inputs.append(var)
dummy_implicit_shared_inputs.append(var.type())
elif not isinstance(var, Constant):
raise MissingInputError(f"NominalGraph is missing an input: {var}")
replacements = dict(
zip(
inputs + implicit_shared_inputs, dummy_inputs + dummy_implicit_shared_inputs
)
)
new = rebuild_collect_shared(
cast(Sequence[Variable], outputs),
inputs=inputs + shared_inputs,
inputs=inputs + implicit_shared_inputs,
replace=replacements,
copy_inputs_over=False,
)
......@@ -133,7 +124,7 @@ def construct_nominal_fgraph(
(clone_d, update_d, update_expr, new_shared_inputs),
) = new
assert len(local_inputs) == len(inputs) + len(shared_inputs)
assert len(local_inputs) == len(inputs) + len(implicit_shared_inputs)
assert len(local_outputs) == len(outputs)
assert not update_d
assert not update_expr
......@@ -155,7 +146,7 @@ def construct_nominal_fgraph(
fgraph.clients.pop(inp, None)
fgraph.add_input(nom_inp)
return fgraph, shared_inputs, update_d, update_expr
return fgraph, implicit_shared_inputs, update_d, update_expr
class OpFromGraph(Op, HasInnerGraph):
......@@ -177,8 +168,6 @@ class OpFromGraph(Op, HasInnerGraph):
- grad() make it support DisconnectedType and the new interface
- add support for NullType and DisconnectedType when R_op supports them
- check how it works with updates.
- add test with constant as input or inside the inner graph.
- Add support for the GPU? Probably just need an opt to remove transfer
- Add support to pickle this Op.
- Add support/test with random generator
- Add optimization to removing unused inputs/outputs
......@@ -310,11 +299,13 @@ class OpFromGraph(Op, HasInnerGraph):
self,
inputs: list[Variable],
outputs: list[Variable],
*,
inline: bool = False,
lop_overrides: str = "default",
grad_overrides: str = "default",
rop_overrides: str = "default",
connection_pattern: Optional[list[list[bool]]] = None,
strict: bool = False,
name: Optional[str] = None,
**kwargs,
):
......@@ -399,6 +390,10 @@ class OpFromGraph(Op, HasInnerGraph):
must be equal to number of outputs. connection_pattern If not
``None``, this will be used as the connection_pattern for this
:class:`Op`.
strict: bool, default False
If true, it raises when any variables needed to compute the inner graph
are not provided as explici inputs. This can only happen for graphs with
shared variables.
name
A name for debugging purposes.
kwargs
......@@ -424,6 +419,12 @@ class OpFromGraph(Op, HasInnerGraph):
inputs, outputs
)
if strict and self.shared_inputs:
raise ValueError(
"All variables needed to compute inner-graph must be provided as inputs under strict=True. "
f"The inner-graph implicitly depends on the following shared variables {self.shared_inputs}"
)
self.kwargs = kwargs
self.input_types = [inp.type for inp in inputs]
self.output_types = [out.type for out in outputs]
......
......@@ -15,7 +15,7 @@ from pytensor.graph.null_type import NullType
from pytensor.graph.rewriting.utils import rewrite_graph
from pytensor.graph.utils import MissingInputError
from pytensor.printing import debugprint
from pytensor.tensor.basic import as_tensor
from pytensor.tensor.basic import constant
from pytensor.tensor.math import dot, exp, sigmoid
from pytensor.tensor.math import round as pt_round
from pytensor.tensor.math import sum as pt_sum
......@@ -43,12 +43,6 @@ class TestOpFromGraph(unittest_tools.InferShapeTester):
with pytest.raises(TypeError):
OpFromGraph([1], [1])
with pytest.raises(TypeError):
OpFromGraph([x, as_tensor(1)], [x])
with pytest.raises(TypeError):
OpFromGraph([shared(1)], [1])
with pytest.raises(NotImplementedError):
OpFromGraph([x], [x], updates={})
......@@ -559,6 +553,31 @@ class TestOpFromGraph(unittest_tools.InferShapeTester):
# The original `op.fgraph` outputs should stay the same, though
assert equal_computations(op.inner_outputs, [x**2 / x], op.inner_inputs, [x])
def test_explicit_input_from_constant(self):
x = pt.dscalar("x")
y = constant(1.0, name="y")
test_ofg = OpFromGraph([x, y], [x + y])
out = test_ofg(x, y)
assert out.eval({x: 5}) == 6
def test_explicit_input_from_shared(self):
x = pt.dscalar("x")
y = shared(1.0, name="y")
with pytest.raises(
ValueError,
match=r"The inner-graph implicitly depends on the following shared variables \[y\]",
):
OpFromGraph([x], [x + y], strict=True)
test_ofg = OpFromGraph([x, y], [x + y], strict=True)
out = test_ofg(x, y)
assert out.eval({x: 5}) == 6
y.set_value(2.0)
assert out.eval({x: 6})
@config.change_flags(floatX="float64")
def test_debugprint():
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论