提交 bb703bd8 authored 作者: Frederic's avatar Frederic

OpFromGraph now compile in make_thunk().

上级 f6537797
...@@ -25,8 +25,8 @@ class OpFromGraph(gof.Op): ...@@ -25,8 +25,8 @@ class OpFromGraph(gof.Op):
- support shared var - support shared var
- __hash__, __eq__ otherwise won't merge - __hash__, __eq__ otherwise won't merge
- c_code() to remove the double overhead? - c_code() to remove the double overhead?
- move call to function to make_thunk().
- opt to unfold it, work inplace on inputs - opt to unfold it, work inplace on inputs
- grad() make it support DisconnectedType and the new interface
""" """
def __init__(self, inputs, outputs, **kwargs): def __init__(self, inputs, outputs, **kwargs):
...@@ -42,13 +42,12 @@ class OpFromGraph(gof.Op): ...@@ -42,13 +42,12 @@ class OpFromGraph(gof.Op):
shared_inputs = [var for var in gof.graph.inputs(outputs) shared_inputs = [var for var in gof.graph.inputs(outputs)
if isinstance(var, SharedVariable)] if isinstance(var, SharedVariable)]
if shared_inputs: if shared_inputs:
raise NotImplementedError("OpFromGraph do not support SharedVariable in the inner graph") raise NotImplementedError(
# TODO: the graph may have implicit inputs like "OpFromGraph do not support SharedVariable in the inner graph")
# SharedVariable instances.
# what impact to they have on the validity of this Op?
self.fn = orig_function(inputs, outputs, **kwargs)
self.inputs = inputs self.inputs = inputs
self.outputs = outputs self.outputs = outputs
self.kwargs = kwargs
self.input_types = [input.type for input in inputs] self.input_types = [input.type for input in inputs]
self.output_types = [output.type for output in outputs] self.output_types = [output.type for output in outputs]
...@@ -70,6 +69,13 @@ class OpFromGraph(gof.Op): ...@@ -70,6 +69,13 @@ class OpFromGraph(gof.Op):
inputs, inputs,
[type() for type in self.output_types]) [type() for type in self.output_types])
def make_thunk(self, node, storage_map, compute_map, no_recycling):
ret = super(OpFromGraph, self).make_thunk(node, storage_map,
compute_map, no_recycling)
if not hasattr(self, "fn"):
self.fn = orig_function(self.inputs, self.outputs, **self.kwargs)
return ret
def perform(self, node, inputs, outputs): def perform(self, node, inputs, outputs):
variables = self.fn(*inputs) variables = self.fn(*inputs)
assert len(variables) == len(outputs) assert len(variables) == len(outputs)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论