提交 7a7e27a6 authored 作者: Olivier Breuleux's avatar Olivier Breuleux

OpFromGraph

上级 1059cf8f
...@@ -178,6 +178,45 @@ def fast_compute(*outputs): ...@@ -178,6 +178,45 @@ def fast_compute(*outputs):
class OpFromGraph(gof.Op):
"""
This create an Op from a list of input results and a list of output
results.
The signature is the same as the signature of FunctionFactory and/or
function and the resulting Op's perform will do the same operation as
function(inputs, outputs, **kwargs)
Take note that the following arguments will be forcefully set to
a particular value:
unpack_single = False
borrow_outputs = False
"""
def __init__(self, inputs, outputs, **kwargs):
kwargs['unpack_single'] = False
kwargs['borrow_outputs'] = False
self.fn = function(inputs, outputs, **kwargs)
self.input_types = [input.type for input in inputs]
self.output_types = [output.type for output in outputs]
def make_node(self, *inputs):
for input, type in zip(inputs, self.input_types):
if not type == input.type:
raise TypeError("Wrong type, expected %s but got %s" % type, input.type)
return gof.Apply(self,
inputs,
[type() for type in self.output_types])
def perform(self, node, inputs, outputs):
results = self.fn(*inputs)
for output, result in zip(outputs, results):
output[0] = result
# class State: # class State:
# def __init__(self, init, next = None): # def __init__(self, init, next = None):
# self.init = init # self.init = init
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论