提交 a7e53331 authored 作者: james@mackie's avatar james@mackie

moved tensorOp impl and perform, as well as _constructor to gof.op

上级 a86d558a
...@@ -15,6 +15,16 @@ __all__ = ['Op', ...@@ -15,6 +15,16 @@ __all__ = ['Op',
] ]
def constructor(op_cls):
"""Make an Op look like a Result-valued function."""
def f(*args, **kwargs):
op = op_cls(*args, **kwargs)
if len(op.outputs) > 1:
return op.outputs
else:
return op.outputs[0]
return f
class Op(object): class Op(object):
""" """
Op represents a computation on the storage in its 'inputs' slot, Op represents a computation on the storage in its 'inputs' slot,
...@@ -41,9 +51,8 @@ class Op(object): ...@@ -41,9 +51,8 @@ class Op(object):
doc = "Same as self.outputs[0] if this Op's has_default_output field is True.") doc = "Same as self.outputs[0] if this Op's has_default_output field is True.")
def __init__(self, *inputs): def __init__(self, **kwargs):
# this might be a bit brainless pass
raise AbstractFunctionError("Op is an abstract class. Its constructor does nothing, you must override it.")
def get_input(self, i): def get_input(self, i):
return self._inputs[i] return self._inputs[i]
...@@ -114,13 +123,33 @@ class Op(object): ...@@ -114,13 +123,33 @@ class Op(object):
# #
# perform # perform
# #
def impl(self, *args):
"""Return output data [tuple], given input data
If this Op has a single output (len(self.outputs)==1) then the return
value of this function will be assigned to self.outputs[0].data.
If this Op has multiple otuputs, then this function should return a
tuple with the data for outputs[0], outputs[1], outputs[2], etc.
"""
raise AbstractFunctionError()
def perform(self): def perform(self):
""" """
Performs the computation associated to this Op and places the Performs the computation associated to this Op and places the
result(s) in the output Results. result(s) in the output Results.
TODO: consider moving this function to the python linker.
""" """
raise AbstractFunctionError() res = self.impl(*[input.data for input in self.inputs])
if self.nout == 1:
self.outputs[0].data = res
else:
assert len(res) == len(self.outputs)
for output, value in zip(self.outputs, res):
output.data = value
# #
...@@ -196,7 +225,7 @@ class Op(object): ...@@ -196,7 +225,7 @@ class Op(object):
raise AbstractFunctionError() raise AbstractFunctionError()
#TODO: consider adding a flag to the base class that toggles this behaviour
class GuardedOp(Op): class GuardedOp(Op):
"""An Op that disallows input properties to change after construction""" """An Op that disallows input properties to change after construction"""
......
差异被折叠。
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论