提交 0e107ac0 authored 作者: Frederic Bastien's avatar Frederic Bastien

Make OpFromGraph use prepare_node

上级 3b923957
......@@ -124,17 +124,14 @@ class OpFromGraph(gof.Op):
list(inputs) + self.shared_inputs,
[type() for type in self.output_types])
def make_thunk(self, node, storage_map, compute_map, no_recycling):
ret = super(OpFromGraph, self).make_thunk(node, storage_map,
compute_map, no_recycling)
if not hasattr(self, "fn"):
self.fn = orig_function(self.new_inputs,
self.new_outputs,
**self.kwargs)
return ret
def prepare_node(self, node, storage_map, compute_map):
if not hasattr(node.tag, "fn"):
node.tag.fn = orig_function(self.new_inputs,
self.new_outputs,
**self.kwargs)
def perform(self, node, inputs, outputs):
variables = self.fn(*inputs)
variables = node.tag.fn(*inputs)
assert len(variables) == len(outputs)
for output, variable in zip(outputs, variables):
# TODO: when function's output-borrowing semantics are correct,
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论