Commit e7d77b00 authored by Brandon T. Willard, committed by Brandon T. Willard

Add a missing inputs check to OpFromGraph

Parent e464e50a
...@@ -22,6 +22,7 @@ from aesara.graph.fg import FunctionGraph
from aesara.graph.null_type import NullType
from aesara.graph.op import HasInnerGraph, Op
from aesara.graph.opt import in2out, local_optimizer
from aesara.graph.utils import MissingInputError
from aesara.tensor.basic_opt import ShapeFeature
...@@ -340,6 +341,10 @@ class OpFromGraph(Op, HasInnerGraph):
if isinstance(i, SharedVariable):
raise TypeError(f"SharedVariables not allowed as inputs; {i}")
for var in graph_inputs(outputs, inputs):
if var not in inputs and not isinstance(var, (Constant, SharedVariable)):
raise MissingInputError(f"OpFromGraph is missing an input: {var}")
if "updates" in kwargs or "givens" in kwargs:
raise NotImplementedError("Updates and givens are not allowed here")
...@@ -362,6 +367,7 @@ class OpFromGraph(Op, HasInnerGraph):
local_outputs,
(clone_d, update_d, update_expr, shared_inputs),
) = new
assert len(local_inputs) == len(inputs) + len(self.shared_inputs)
assert len(local_outputs) == len(outputs)
assert not update_d
...
...@@ -12,6 +12,7 @@ from aesara.gradient import DisconnectedType, Rop, disconnected_type, grad
from aesara.graph.fg import FunctionGraph
from aesara.graph.null_type import NullType
from aesara.graph.opt_utils import optimize_graph
from aesara.graph.utils import MissingInputError
from aesara.printing import debugprint
from aesara.tensor.basic import as_tensor
from aesara.tensor.basic_opt import ShapeOptimizer
...@@ -503,6 +504,12 @@ class TestOpFromGraph(unittest_tools.InferShapeTester):
out_fn = function([], out)
assert np.array_equal(out_fn(), 2.0)
def test_missing_input(self):
x = at.lscalar("x")
with pytest.raises(MissingInputError):
OpFromGraph([], [x])
def test_debugprint():
x, y, z = matrices("xyz")
...
Markdown format
0%
You are adding 0 people to this discussion. Please proceed with caution.
Please finish editing this comment first!
Register or sign in to post a comment