提交 18cd693a authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Ricardo Vieira

Remove unnecessary graph_inputs usage in OpFromGraph

上级 556816f3
...@@ -344,21 +344,23 @@ class OpFromGraph(Op, HasInnerGraph): ...@@ -344,21 +344,23 @@ class OpFromGraph(Op, HasInnerGraph):
if isinstance(i, SharedVariable): if isinstance(i, SharedVariable):
raise TypeError(f"SharedVariables not allowed as inputs; {i}") raise TypeError(f"SharedVariables not allowed as inputs; {i}")
for var in graph_inputs(outputs, inputs):
if var not in inputs and not isinstance(var, (Constant, SharedVariable)):
raise MissingInputError(f"OpFromGraph is missing an input: {var}")
if "updates" in kwargs or "givens" in kwargs: if "updates" in kwargs or "givens" in kwargs:
raise NotImplementedError("Updates and givens are not allowed here") raise NotImplementedError("Updates and givens are not supported")
self.is_inline = inline self.is_inline = inline
# To correctly support shared variables the inner fct should
# not see them. Otherwise there is a problem with the gradient.
self.shared_inputs = [] self.shared_inputs = []
for var in graph_inputs(outputs): inner_graph_inputs = graph_inputs(outputs, inputs)
for var in inner_graph_inputs:
if isinstance(var, SharedVariable): if isinstance(var, SharedVariable):
# To correctly support shared variables the inner-graph should
# not see them; otherwise, there will be problems with
# gradients.
# That's why we collect the shared variables and replace them
# with dummies.
self.shared_inputs.append(var) self.shared_inputs.append(var)
elif var not in inputs and not isinstance(var, Constant):
raise MissingInputError(f"OpFromGraph is missing an input: {var}")
inputs, outputs = replace_nominals_with_dummies(inputs, outputs) inputs, outputs = replace_nominals_with_dummies(inputs, outputs)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论