提交 e7d77b00 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Add a missing inputs check to OpFromGraph

上级 e464e50a
......@@ -22,6 +22,7 @@ from aesara.graph.fg import FunctionGraph
from aesara.graph.null_type import NullType
from aesara.graph.op import HasInnerGraph, Op
from aesara.graph.opt import in2out, local_optimizer
from aesara.graph.utils import MissingInputError
from aesara.tensor.basic_opt import ShapeFeature
......@@ -340,6 +341,10 @@ class OpFromGraph(Op, HasInnerGraph):
if isinstance(i, SharedVariable):
raise TypeError(f"SharedVariables not allowed as inputs; {i}")
for var in graph_inputs(outputs, inputs):
if var not in inputs and not isinstance(var, (Constant, SharedVariable)):
raise MissingInputError(f"OpFromGraph is missing an input: {var}")
if "updates" in kwargs or "givens" in kwargs:
raise NotImplementedError("Updates and givens are not allowed here")
......@@ -362,6 +367,7 @@ class OpFromGraph(Op, HasInnerGraph):
local_outputs,
(clone_d, update_d, update_expr, shared_inputs),
) = new
assert len(local_inputs) == len(inputs) + len(self.shared_inputs)
assert len(local_outputs) == len(outputs)
assert not update_d
......
......@@ -12,6 +12,7 @@ from aesara.gradient import DisconnectedType, Rop, disconnected_type, grad
from aesara.graph.fg import FunctionGraph
from aesara.graph.null_type import NullType
from aesara.graph.opt_utils import optimize_graph
from aesara.graph.utils import MissingInputError
from aesara.printing import debugprint
from aesara.tensor.basic import as_tensor
from aesara.tensor.basic_opt import ShapeOptimizer
......@@ -503,6 +504,12 @@ class TestOpFromGraph(unittest_tools.InferShapeTester):
out_fn = function([], out)
assert np.array_equal(out_fn(), 2.0)
def test_missing_input(self):
x = at.lscalar("x")
with pytest.raises(MissingInputError):
OpFromGraph([], [x])
def test_debugprint():
x, y, z = matrices("xyz")
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论