提交 0e107ac0 authored 作者: Frederic Bastien's avatar Frederic Bastien

Make OpFromGraph use prepare_node

上级 3b923957
...@@ -124,17 +124,14 @@ class OpFromGraph(gof.Op): ...@@ -124,17 +124,14 @@ class OpFromGraph(gof.Op):
list(inputs) + self.shared_inputs, list(inputs) + self.shared_inputs,
[type() for type in self.output_types]) [type() for type in self.output_types])
def make_thunk(self, node, storage_map, compute_map, no_recycling): def prepare_node(self, node, storage_map, compute_map):
ret = super(OpFromGraph, self).make_thunk(node, storage_map, if not hasattr(node.tag, "fn"):
compute_map, no_recycling) node.tag.fn = orig_function(self.new_inputs,
if not hasattr(self, "fn"):
self.fn = orig_function(self.new_inputs,
self.new_outputs, self.new_outputs,
**self.kwargs) **self.kwargs)
return ret
def perform(self, node, inputs, outputs): def perform(self, node, inputs, outputs):
variables = self.fn(*inputs) variables = node.tag.fn(*inputs)
assert len(variables) == len(outputs) assert len(variables) == len(outputs)
for output, variable in zip(outputs, variables): for output, variable in zip(outputs, variables):
# TODO: when function's output-borrowing semantics are correct, # TODO: when function's output-borrowing semantics are correct,
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论