提交 97317a50 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Allow defining an OpFromGraph from constant and shared inputs.

Also adds a strict flag
上级 339aab4d
...@@ -92,38 +92,29 @@ def construct_nominal_fgraph( ...@@ -92,38 +92,29 @@ def construct_nominal_fgraph(
dict[Variable, Variable], dict[Variable, Variable],
]: ]:
"""Construct an inner-`FunctionGraph` with ordered nominal inputs.""" """Construct an inner-`FunctionGraph` with ordered nominal inputs."""
dummy_inputs = [] implicit_shared_inputs = []
for n, inp in enumerate(inputs):
if (
not isinstance(inp, Variable)
or isinstance(inp, Constant)
or isinstance(inp, SharedVariable)
):
raise TypeError(
f"Inputs and outputs must be non-Constant/shared Variable instances; got {inp}"
)
dummy_inputs.append(inp.type())
dummy_shared_inputs = [] dummy_inputs = [inp.type() for inp in inputs]
shared_inputs = [] dummy_implicit_shared_inputs = []
for var in graph_inputs(outputs, inputs): for var in graph_inputs(outputs, inputs):
if var in inputs:
continue
if isinstance(var, SharedVariable): if isinstance(var, SharedVariable):
# To correctly support shared variables the inner-graph should # We allow shared inputs to be added automatically to the graph
# not see them; otherwise, there will be problems with implicit_shared_inputs.append(var)
# gradients. dummy_implicit_shared_inputs.append(var.type())
# That's why we collect the shared variables and replace them elif not isinstance(var, Constant):
# with dummies. raise MissingInputError(f"NominalGraph is missing an input: {var}")
shared_inputs.append(var)
dummy_shared_inputs.append(var.type()) replacements = dict(
elif var not in inputs and not isinstance(var, Constant): zip(
raise MissingInputError(f"OpFromGraph is missing an input: {var}") inputs + implicit_shared_inputs, dummy_inputs + dummy_implicit_shared_inputs
)
replacements = dict(zip(inputs + shared_inputs, dummy_inputs + dummy_shared_inputs)) )
new = rebuild_collect_shared( new = rebuild_collect_shared(
cast(Sequence[Variable], outputs), cast(Sequence[Variable], outputs),
inputs=inputs + shared_inputs, inputs=inputs + implicit_shared_inputs,
replace=replacements, replace=replacements,
copy_inputs_over=False, copy_inputs_over=False,
) )
...@@ -133,7 +124,7 @@ def construct_nominal_fgraph( ...@@ -133,7 +124,7 @@ def construct_nominal_fgraph(
(clone_d, update_d, update_expr, new_shared_inputs), (clone_d, update_d, update_expr, new_shared_inputs),
) = new ) = new
assert len(local_inputs) == len(inputs) + len(shared_inputs) assert len(local_inputs) == len(inputs) + len(implicit_shared_inputs)
assert len(local_outputs) == len(outputs) assert len(local_outputs) == len(outputs)
assert not update_d assert not update_d
assert not update_expr assert not update_expr
...@@ -155,7 +146,7 @@ def construct_nominal_fgraph( ...@@ -155,7 +146,7 @@ def construct_nominal_fgraph(
fgraph.clients.pop(inp, None) fgraph.clients.pop(inp, None)
fgraph.add_input(nom_inp) fgraph.add_input(nom_inp)
return fgraph, shared_inputs, update_d, update_expr return fgraph, implicit_shared_inputs, update_d, update_expr
class OpFromGraph(Op, HasInnerGraph): class OpFromGraph(Op, HasInnerGraph):
...@@ -177,8 +168,6 @@ class OpFromGraph(Op, HasInnerGraph): ...@@ -177,8 +168,6 @@ class OpFromGraph(Op, HasInnerGraph):
- grad() make it support DisconnectedType and the new interface - grad() make it support DisconnectedType and the new interface
- add support for NullType and DisconnectedType when R_op supports them - add support for NullType and DisconnectedType when R_op supports them
- check how it works with updates. - check how it works with updates.
- add test with constant as input or inside the inner graph.
- Add support for the GPU? Probably just need an opt to remove transfer
- Add support to pickle this Op. - Add support to pickle this Op.
- Add support/test with random generator - Add support/test with random generator
- Add optimization to removing unused inputs/outputs - Add optimization to removing unused inputs/outputs
...@@ -310,11 +299,13 @@ class OpFromGraph(Op, HasInnerGraph): ...@@ -310,11 +299,13 @@ class OpFromGraph(Op, HasInnerGraph):
self, self,
inputs: list[Variable], inputs: list[Variable],
outputs: list[Variable], outputs: list[Variable],
*,
inline: bool = False, inline: bool = False,
lop_overrides: str = "default", lop_overrides: str = "default",
grad_overrides: str = "default", grad_overrides: str = "default",
rop_overrides: str = "default", rop_overrides: str = "default",
connection_pattern: Optional[list[list[bool]]] = None, connection_pattern: Optional[list[list[bool]]] = None,
strict: bool = False,
name: Optional[str] = None, name: Optional[str] = None,
**kwargs, **kwargs,
): ):
...@@ -399,6 +390,10 @@ class OpFromGraph(Op, HasInnerGraph): ...@@ -399,6 +390,10 @@ class OpFromGraph(Op, HasInnerGraph):
must be equal to number of outputs. connection_pattern If not must be equal to number of outputs. connection_pattern If not
``None``, this will be used as the connection_pattern for this ``None``, this will be used as the connection_pattern for this
:class:`Op`. :class:`Op`.
strict: bool, default False
If true, it raises when any variables needed to compute the inner graph
are not provided as explici inputs. This can only happen for graphs with
shared variables.
name name
A name for debugging purposes. A name for debugging purposes.
kwargs kwargs
...@@ -424,6 +419,12 @@ class OpFromGraph(Op, HasInnerGraph): ...@@ -424,6 +419,12 @@ class OpFromGraph(Op, HasInnerGraph):
inputs, outputs inputs, outputs
) )
if strict and self.shared_inputs:
raise ValueError(
"All variables needed to compute inner-graph must be provided as inputs under strict=True. "
f"The inner-graph implicitly depends on the following shared variables {self.shared_inputs}"
)
self.kwargs = kwargs self.kwargs = kwargs
self.input_types = [inp.type for inp in inputs] self.input_types = [inp.type for inp in inputs]
self.output_types = [out.type for out in outputs] self.output_types = [out.type for out in outputs]
......
...@@ -15,7 +15,7 @@ from pytensor.graph.null_type import NullType ...@@ -15,7 +15,7 @@ from pytensor.graph.null_type import NullType
from pytensor.graph.rewriting.utils import rewrite_graph from pytensor.graph.rewriting.utils import rewrite_graph
from pytensor.graph.utils import MissingInputError from pytensor.graph.utils import MissingInputError
from pytensor.printing import debugprint from pytensor.printing import debugprint
from pytensor.tensor.basic import as_tensor from pytensor.tensor.basic import constant
from pytensor.tensor.math import dot, exp, sigmoid from pytensor.tensor.math import dot, exp, sigmoid
from pytensor.tensor.math import round as pt_round from pytensor.tensor.math import round as pt_round
from pytensor.tensor.math import sum as pt_sum from pytensor.tensor.math import sum as pt_sum
...@@ -43,12 +43,6 @@ class TestOpFromGraph(unittest_tools.InferShapeTester): ...@@ -43,12 +43,6 @@ class TestOpFromGraph(unittest_tools.InferShapeTester):
with pytest.raises(TypeError): with pytest.raises(TypeError):
OpFromGraph([1], [1]) OpFromGraph([1], [1])
with pytest.raises(TypeError):
OpFromGraph([x, as_tensor(1)], [x])
with pytest.raises(TypeError):
OpFromGraph([shared(1)], [1])
with pytest.raises(NotImplementedError): with pytest.raises(NotImplementedError):
OpFromGraph([x], [x], updates={}) OpFromGraph([x], [x], updates={})
...@@ -559,6 +553,31 @@ class TestOpFromGraph(unittest_tools.InferShapeTester): ...@@ -559,6 +553,31 @@ class TestOpFromGraph(unittest_tools.InferShapeTester):
# The original `op.fgraph` outputs should stay the same, though # The original `op.fgraph` outputs should stay the same, though
assert equal_computations(op.inner_outputs, [x**2 / x], op.inner_inputs, [x]) assert equal_computations(op.inner_outputs, [x**2 / x], op.inner_inputs, [x])
def test_explicit_input_from_constant(self):
x = pt.dscalar("x")
y = constant(1.0, name="y")
test_ofg = OpFromGraph([x, y], [x + y])
out = test_ofg(x, y)
assert out.eval({x: 5}) == 6
def test_explicit_input_from_shared(self):
x = pt.dscalar("x")
y = shared(1.0, name="y")
with pytest.raises(
ValueError,
match=r"The inner-graph implicitly depends on the following shared variables \[y\]",
):
OpFromGraph([x], [x + y], strict=True)
test_ofg = OpFromGraph([x, y], [x + y], strict=True)
out = test_ofg(x, y)
assert out.eval({x: 5}) == 6
y.set_value(2.0)
assert out.eval({x: 6})
@config.change_flags(floatX="float64") @config.change_flags(floatX="float64")
def test_debugprint(): def test_debugprint():
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论