提交 a7e53331 authored 作者: james@mackie's avatar james@mackie

moved tensorOp impl and perform, as well as _constructor to gof.op

上级 a86d558a
...@@ -15,6 +15,16 @@ __all__ = ['Op', ...@@ -15,6 +15,16 @@ __all__ = ['Op',
] ]
def constructor(op_cls):
"""Make an Op look like a Result-valued function."""
def f(*args, **kwargs):
op = op_cls(*args, **kwargs)
if len(op.outputs) > 1:
return op.outputs
else:
return op.outputs[0]
return f
class Op(object): class Op(object):
""" """
Op represents a computation on the storage in its 'inputs' slot, Op represents a computation on the storage in its 'inputs' slot,
...@@ -41,9 +51,8 @@ class Op(object): ...@@ -41,9 +51,8 @@ class Op(object):
doc = "Same as self.outputs[0] if this Op's has_default_output field is True.") doc = "Same as self.outputs[0] if this Op's has_default_output field is True.")
def __init__(self, *inputs): def __init__(self, **kwargs):
# this might be a bit brainless pass
raise AbstractFunctionError("Op is an abstract class. Its constructor does nothing, you must override it.")
def get_input(self, i): def get_input(self, i):
return self._inputs[i] return self._inputs[i]
...@@ -114,13 +123,33 @@ class Op(object): ...@@ -114,13 +123,33 @@ class Op(object):
# #
# perform # perform
# #
def impl(self, *args):
"""Return output data [tuple], given input data
If this Op has a single output (len(self.outputs)==1) then the return
value of this function will be assigned to self.outputs[0].data.
If this Op has multiple otuputs, then this function should return a
tuple with the data for outputs[0], outputs[1], outputs[2], etc.
"""
raise AbstractFunctionError()
def perform(self): def perform(self):
""" """
Performs the computation associated to this Op and places the Performs the computation associated to this Op and places the
result(s) in the output Results. result(s) in the output Results.
TODO: consider moving this function to the python linker.
""" """
raise AbstractFunctionError() res = self.impl(*[input.data for input in self.inputs])
if self.nout == 1:
self.outputs[0].data = res
else:
assert len(res) == len(self.outputs)
for output, value in zip(self.outputs, res):
output.data = value
# #
...@@ -196,7 +225,7 @@ class Op(object): ...@@ -196,7 +225,7 @@ class Op(object):
raise AbstractFunctionError() raise AbstractFunctionError()
#TODO: consider adding a flag to the base class that toggles this behaviour
class GuardedOp(Op): class GuardedOp(Op):
"""An Op that disallows input properties to change after construction""" """An Op that disallows input properties to change after construction"""
......
...@@ -6,6 +6,7 @@ import numpy ...@@ -6,6 +6,7 @@ import numpy
from gof import ResultBase, Op, utils, Destroyer, Viewer, AbstractFunctionError from gof import ResultBase, Op, utils, Destroyer, Viewer, AbstractFunctionError
import gof.result import gof.result
import gof.op
from base_tensor import BaseTensor, BaseTensorOp from base_tensor import BaseTensor, BaseTensorOp
from elemwise import Elemwise from elemwise import Elemwise
...@@ -130,17 +131,6 @@ class _Op(BaseTensorOp): ...@@ -130,17 +131,6 @@ class _Op(BaseTensorOp):
def input_wrapper(cls, obj): def input_wrapper(cls, obj):
return _as_tensor(obj) return _as_tensor(obj)
def impl(self, *inputs):
raise AbstractFunctionError()
def perform(self):
res = self.impl(*[input.data for input in self.inputs])
if self.nout == 1:
self.outputs[0].data = res
else:
for output, value in zip(self.outputs, res):
output.data = value
def c_var_names(self): def c_var_names(self):
(self, inames, onames), _1, _2, _3 = inspect.getargspec(self.c_impl) (self, inames, onames), _1, _2, _3 = inspect.getargspec(self.c_impl)
inames = utils.from_return_values(inames) inames = utils.from_return_values(inames)
...@@ -231,18 +221,6 @@ class TensorScalarOp(_Elemwise): ...@@ -231,18 +221,6 @@ class TensorScalarOp(_Elemwise):
def c_code_foreach(self): def c_code_foreach(self):
return "%%(z)s_i = %s;" % self.c_expr return "%%(z)s_i = %s;" % self.c_expr
def _constructor(op_cls):
"""Return a function that calls op_cls(*input)
and returns the outputs of the op (with single outputs unpacked)
"""
def f(*args, **kwargs):
op = op_cls(*args, **kwargs)
if len(op.outputs) > 1:
return op.outputs
else:
return op.outputs[0]
return f
########################## ##########################
# Unary Operations # Unary Operations
...@@ -276,7 +254,7 @@ class Argmax(Op): ...@@ -276,7 +254,7 @@ class Argmax(Op):
x = self.inputs[0].data x = self.inputs[0].data
self.outputs[0].data = numpy.max(x, axis) self.outputs[0].data = numpy.max(x, axis)
self.outputs[1].data = numpy.argmax(x,axis) self.outputs[1].data = numpy.argmax(x,axis)
argmax = _constructor(Argmax) argmax = gof.op.constructor(Argmax)
def max(x, axis=None): def max(x, axis=None):
"""Return maximum elements obtained by iterating over given axis """Return maximum elements obtained by iterating over given axis
...@@ -292,7 +270,7 @@ class Exp(_Elemwise): ...@@ -292,7 +270,7 @@ class Exp(_Elemwise):
def impl(self, x): return numpy.exp(x) def impl(self, x): return numpy.exp(x)
def grad(self, x, gz): return gz * exp(x) def grad(self, x, gz): return gz * exp(x)
def c_foreach(self, (x_i, ), (z_i, )): return "z_i = exp(x_i);" def c_foreach(self, (x_i, ), (z_i, )): return "z_i = exp(x_i);"
exp = _constructor(Exp) exp = gof.op.constructor(Exp)
class Neg(_Elemwise): class Neg(_Elemwise):
...@@ -308,13 +286,13 @@ class Log(_Elemwise): ...@@ -308,13 +286,13 @@ class Log(_Elemwise):
def impl(self, x): return numpy.log(x) def impl(self, x): return numpy.log(x)
def grad(self, x, gz): return gz / x def grad(self, x, gz): return gz / x
def c_foreach(self, (x_i, ), (z_i, )): return "z_i = log(x_i);" def c_foreach(self, (x_i, ), (z_i, )): return "z_i = log(x_i);"
log = _constructor(Log) log = gof.op.constructor(Log)
class Log2(_Elemwise): class Log2(_Elemwise):
def impl(self, x): return numpy.log2(x) def impl(self, x): return numpy.log2(x)
def grad(self, x, gz): return gz / (x * numpy.log(2.0)) def grad(self, x, gz): return gz / (x * numpy.log(2.0))
def c_foreach(self, (x_i, ), (z_i, )): return "%(z)s_i = log2(%(x)s_i);" def c_foreach(self, (x_i, ), (z_i, )): return "%(z)s_i = log2(%(x)s_i);"
log2 = _constructor(Log2) log2 = gof.op.constructor(Log2)
class Sgn(_Elemwise): class Sgn(_Elemwise):
def impl(self, x): def impl(self, x):
...@@ -323,19 +301,19 @@ class Sgn(_Elemwise): ...@@ -323,19 +301,19 @@ class Sgn(_Elemwise):
return [None] return [None]
def c_foreach(self, (x_i, ), (z_i, )): def c_foreach(self, (x_i, ), (z_i, )):
return "%(z)s_i = %(x)s_i/abs(%(x)s_i);" # TODO: C use copysign return "%(z)s_i = %(x)s_i/abs(%(x)s_i);" # TODO: C use copysign
sgn = _constructor(Sgn) sgn = gof.op.constructor(Sgn)
class Sqr(_Elemwise): class Sqr(_Elemwise):
def impl(self, x): return x * x def impl(self, x): return x * x
def grad(self, x, gz): return 2.0 * x * gz def grad(self, x, gz): return 2.0 * x * gz
def c_foreach(self, (x_i, ), (z_i, )): return "%(z)s_i = %(x)s_i * %(x)s_i;" def c_foreach(self, (x_i, ), (z_i, )): return "%(z)s_i = %(x)s_i * %(x)s_i;"
sqr = _constructor(Sqr) sqr = gof.op.constructor(Sqr)
class Sqrt(_Elemwise): class Sqrt(_Elemwise):
def impl(self, x): return numpy.sqrt(x) def impl(self, x): return numpy.sqrt(x)
def grad(self, x, gz): return 0.5 * gz / sqrt(x) def grad(self, x, gz): return 0.5 * gz / sqrt(x)
def c_foreach(self, (x_i, ), (z_i, )): return "%(z)s_i = sqrt(%(x)s_i);" def c_foreach(self, (x_i, ), (z_i, )): return "%(z)s_i = sqrt(%(x)s_i);"
sqrt = _constructor(Sqrt) sqrt = gof.op.constructor(Sqrt)
class Sum(_Elemwise): class Sum(_Elemwise):
def impl(self, x): def impl(self, x):
...@@ -348,7 +326,7 @@ class Sum(_Elemwise): ...@@ -348,7 +326,7 @@ class Sum(_Elemwise):
return "dtype_%(sum)s* %(sum)sp = ((dtype_%(sum)s*)PyArray_DATA(%(sum)s)); %(sum)sp[0] = 0;" return "dtype_%(sum)s* %(sum)sp = ((dtype_%(sum)s*)PyArray_DATA(%(sum)s)); %(sum)sp[0] = 0;"
def c_foreach(self, (x_i, ), (sum, )): def c_foreach(self, (x_i, ), (sum, )):
return "%(sum)sp[0] += %(x)s_i;" return "%(sum)sp[0] += %(x)s_i;"
sum = _constructor(Sum) sum = gof.op.constructor(Sum)
class Fill(_Elemwise): class Fill(_Elemwise):
def impl(self, model, value): def impl(self, model, value):
...@@ -359,7 +337,7 @@ class Fill(_Elemwise): ...@@ -359,7 +337,7 @@ class Fill(_Elemwise):
return "dtype_%(value)s %(value)s0 = ((dtype_%(value)s*)PyArray_DATA(%(value)s))[0];" return "dtype_%(value)s %(value)s0 = ((dtype_%(value)s*)PyArray_DATA(%(value)s))[0];"
def c_foreach(self, (model_i, value), (z_i, )): def c_foreach(self, (model_i, value), (z_i, )):
return "%(z)s_i = %(value)s0;" return "%(z)s_i = %(value)s0;"
fill = _constructor(Fill) fill = gof.op.constructor(Fill)
def ones_like(model): def ones_like(model):
return fill(model, 1.0) return fill(model, 1.0)
def zeros_like(model): def zeros_like(model):
...@@ -373,7 +351,7 @@ class TensorCopy(_Elemwise): ...@@ -373,7 +351,7 @@ class TensorCopy(_Elemwise):
return gz return gz
def c_foreach(self, (x_i, ), (z_i, )): def c_foreach(self, (x_i, ), (z_i, )):
return "%(z)s_i = %(x)s_i;" return "%(z)s_i = %(x)s_i;"
tensor_copy = _constructor(TensorCopy) tensor_copy = gof.op.constructor(TensorCopy)
########################## ##########################
# View Operations # View Operations
...@@ -399,7 +377,7 @@ class Transpose(_Op, Viewer): ...@@ -399,7 +377,7 @@ class Transpose(_Op, Viewer):
} }
%(z)s = transposed; %(z)s = transposed;
""" """
transpose = _constructor(Transpose) transpose = gof.op.constructor(Transpose)
class Subtensor(Op, Viewer): class Subtensor(Op, Viewer):
nin = 2 nin = 2
...@@ -461,7 +439,7 @@ class Subtensor(Op, Viewer): ...@@ -461,7 +439,7 @@ class Subtensor(Op, Viewer):
# - option: return gz, but think about how to include a special addition # - option: return gz, but think about how to include a special addition
# function that works on a corresponding view of the original data # function that works on a corresponding view of the original data
raise NotImplementedError() raise NotImplementedError()
subtensor = _constructor(Subtensor) subtensor = gof.op.constructor(Subtensor)
########################## ##########################
...@@ -481,14 +459,14 @@ class AddElemwise(_Elemwise): ...@@ -481,14 +459,14 @@ class AddElemwise(_Elemwise):
return gz, gz return gz, gz
def c_foreach(self, (x_i, y_i), (z_i, )): def c_foreach(self, (x_i, y_i), (z_i, )):
return "%(z)s_i = %(x)s_i + %(y)s_i;" return "%(z)s_i = %(x)s_i + %(y)s_i;"
add_elemwise = _constructor(AddElemwise) add_elemwise = gof.op.constructor(AddElemwise)
class AddElemwiseInplace(AddElemwise.inplace_version()): class AddElemwiseInplace(AddElemwise.inplace_version()):
def impl(self, x, y): def impl(self, x, y):
_assert_same_shapes(x, y) _assert_same_shapes(x, y)
x += y x += y
return x return x
add_elemwise_inplace = _constructor(AddElemwiseInplace) add_elemwise_inplace = gof.op.constructor(AddElemwiseInplace)
# Scalar # # Scalar #
class AddScalar(TensorScalarOp): class AddScalar(TensorScalarOp):
...@@ -498,14 +476,14 @@ class AddScalar(TensorScalarOp): ...@@ -498,14 +476,14 @@ class AddScalar(TensorScalarOp):
def grad(self, (x, a), gz): def grad(self, (x, a), gz):
return gz, sum(gz) return gz, sum(gz)
c_expr = "x_i + a" c_expr = "x_i + a"
add_scalar = _constructor(AddScalar) add_scalar = gof.op.constructor(AddScalar)
class AddScalarInplace(AddScalar.inplace_version()): class AddScalarInplace(AddScalar.inplace_version()):
def impl(self, x, a): def impl(self, x, a):
_assert_tensor_scalar(x, a) _assert_tensor_scalar(x, a)
x += a x += a
return x return x
add_scalar_inplace = _constructor(AddScalarInplace) add_scalar_inplace = gof.op.constructor(AddScalarInplace)
add = _scalar_switch(add_elemwise, add_scalar, add_scalar) add = _scalar_switch(add_elemwise, add_scalar, add_scalar)
add_inplace = _scalar_switch(add_elemwise_inplace, add_scalar_inplace) add_inplace = _scalar_switch(add_elemwise_inplace, add_scalar_inplace)
...@@ -524,14 +502,14 @@ class SubElemwise(_Elemwise): ...@@ -524,14 +502,14 @@ class SubElemwise(_Elemwise):
return gz, -gz return gz, -gz
def c_foreach(self, (x_i, y_i), (z_i, )): def c_foreach(self, (x_i, y_i), (z_i, )):
return "%(z)s_i = %(x)s_i - %(y)s_i;" return "%(z)s_i = %(x)s_i - %(y)s_i;"
sub_elemwise = _constructor(SubElemwise) sub_elemwise = gof.op.constructor(SubElemwise)
class SubElemwiseInplace(SubElemwise.inplace_version()): class SubElemwiseInplace(SubElemwise.inplace_version()):
def impl(self, x, y): def impl(self, x, y):
_assert_same_shapes(x, y) _assert_same_shapes(x, y)
x -= y x -= y
return x return x
sub_elemwise_inplace = _constructor(SubElemwiseInplace) sub_elemwise_inplace = gof.op.constructor(SubElemwiseInplace)
# Scalar # # Scalar #
def sub_scalar_r(x, a): def sub_scalar_r(x, a):
...@@ -559,14 +537,14 @@ class MulElemwise(_Elemwise): ...@@ -559,14 +537,14 @@ class MulElemwise(_Elemwise):
return mul(y, gz), mul(x, gz) return mul(y, gz), mul(x, gz)
def c_foreach(self, (x_i, y_i), (z_i, )): def c_foreach(self, (x_i, y_i), (z_i, )):
return "%(z)s_i = %(x)s_i * %(y)s_i;" return "%(z)s_i = %(x)s_i * %(y)s_i;"
mul_elemwise = _constructor(MulElemwise) mul_elemwise = gof.op.constructor(MulElemwise)
class MulElemwiseInplace(MulElemwise.inplace_version()): class MulElemwiseInplace(MulElemwise.inplace_version()):
def impl(self, x, y): def impl(self, x, y):
_assert_same_shapes(x, y) _assert_same_shapes(x, y)
x *= y x *= y
return x return x
mul_elemwise_inplace = _constructor(MulElemwiseInplace) mul_elemwise_inplace = gof.op.constructor(MulElemwiseInplace)
# Scalar # # Scalar #
class Scale(TensorScalarOp): class Scale(TensorScalarOp):
...@@ -576,14 +554,14 @@ class Scale(TensorScalarOp): ...@@ -576,14 +554,14 @@ class Scale(TensorScalarOp):
def grad(self, (x, a), gz): def grad(self, (x, a), gz):
return scale(a, gz), sum(mul_elemwise(x, gz)) return scale(a, gz), sum(mul_elemwise(x, gz))
c_expr = "%(x)s_i * _%(a)s" c_expr = "%(x)s_i * _%(a)s"
scale = _constructor(Scale) scale = gof.op.constructor(Scale)
class ScaleInplace(Scale.inplace_version()): class ScaleInplace(Scale.inplace_version()):
def impl(self, x, a): def impl(self, x, a):
_assert_tensor_scalar(x, a) _assert_tensor_scalar(x, a)
x *= a x *= a
return x return x
scale_inplace = _constructor(ScaleInplace) scale_inplace = gof.op.constructor(ScaleInplace)
mul = _scalar_switch(mul_elemwise, scale, scale) mul = _scalar_switch(mul_elemwise, scale, scale)
mul_inplace = _scalar_switch(mul_elemwise_inplace, scale_inplace) mul_inplace = _scalar_switch(mul_elemwise_inplace, scale_inplace)
...@@ -602,14 +580,14 @@ class DivElemwise(_Elemwise): ...@@ -602,14 +580,14 @@ class DivElemwise(_Elemwise):
return div(gz, y), -div(mul(x, gz), (y*y)) return div(gz, y), -div(mul(x, gz), (y*y))
def c_foreach(self, (x_i, y_i), (z_i, )): def c_foreach(self, (x_i, y_i), (z_i, )):
return "%(z)s_i = %(x)s_i / %(y)s_i;" return "%(z)s_i = %(x)s_i / %(y)s_i;"
div_elemwise = _constructor(DivElemwise) div_elemwise = gof.op.constructor(DivElemwise)
class DivElemwiseInplace(DivElemwise.inplace_version()): class DivElemwiseInplace(DivElemwise.inplace_version()):
def impl(self, x, y): def impl(self, x, y):
_assert_same_shapes(x, y) _assert_same_shapes(x, y)
x /= y x /= y
return x return x
div_elemwise_inplace = _constructor(DivElemwiseInplace) div_elemwise_inplace = gof.op.constructor(DivElemwiseInplace)
class InvElemwise(_Elemwise): class InvElemwise(_Elemwise):
def impl(self, x): def impl(self, x):
...@@ -619,7 +597,7 @@ class InvElemwise(_Elemwise): ...@@ -619,7 +597,7 @@ class InvElemwise(_Elemwise):
return -gz * (ix * ix) return -gz * (ix * ix)
def c_foreach(self, (x_i, ), (z_i, )): def c_foreach(self, (x_i, ), (z_i, )):
return "%(z)s_i = 1.0 / %(x)s_i;" #TODO: cast 1.0 to the dtype of x return "%(z)s_i = 1.0 / %(x)s_i;" #TODO: cast 1.0 to the dtype of x
inv_elemwise = _constructor(InvElemwise) inv_elemwise = gof.op.constructor(InvElemwise)
# Scalar # # Scalar #
def div_scalar_r(x, a): def div_scalar_r(x, a):
...@@ -653,14 +631,14 @@ class PowElemwise(_Elemwise): ...@@ -653,14 +631,14 @@ class PowElemwise(_Elemwise):
return gx, gy return gx, gy
def c_foreach(self, (x_i, y_i), (z_i, )): def c_foreach(self, (x_i, y_i), (z_i, )):
return "%(z)s_i = pow(%(x)s_i, %(y)s_i);" return "%(z)s_i = pow(%(x)s_i, %(y)s_i);"
pow_elemwise = _constructor(PowElemwise) pow_elemwise = gof.op.constructor(PowElemwise)
class PowElemwiseInplace(PowElemwise.inplace_version()): class PowElemwiseInplace(PowElemwise.inplace_version()):
def impl(self, x, y): def impl(self, x, y):
_assert_same_shapes(x, y) _assert_same_shapes(x, y)
x **= y x **= y
return x return x
pow_elemwise_inplace = _constructor(PowElemwiseInplace) pow_elemwise_inplace = gof.op.constructor(PowElemwiseInplace)
# Scalar # # Scalar #
class PowScalarL(TensorScalarOp): class PowScalarL(TensorScalarOp):
...@@ -672,7 +650,7 @@ class PowScalarL(TensorScalarOp): ...@@ -672,7 +650,7 @@ class PowScalarL(TensorScalarOp):
gy = gz * log(x) * x ** y gy = gz * log(x) * x ** y
return gy, gx return gy, gx
c_expr = "pow(%(a)s, %(x)s_i)" c_expr = "pow(%(a)s, %(x)s_i)"
pow_scalar_l = _constructor(PowScalarL) pow_scalar_l = gof.op.constructor(PowScalarL)
class PowScalarR(TensorScalarOp): class PowScalarR(TensorScalarOp):
def impl(self, x, a): def impl(self, x, a):
...@@ -683,14 +661,14 @@ class PowScalarR(TensorScalarOp): ...@@ -683,14 +661,14 @@ class PowScalarR(TensorScalarOp):
gs = sum(mul_elemwise(mul_elemwise(gz, pow_scalar_r(x,s)), log(x))) gs = sum(mul_elemwise(mul_elemwise(gz, pow_scalar_r(x,s)), log(x)))
return gx, gs return gx, gs
c_expr = "pow(%(x)s_i, _%(a)s)" c_expr = "pow(%(x)s_i, _%(a)s)"
pow_scalar_r = _constructor(PowScalarR) pow_scalar_r = gof.op.constructor(PowScalarR)
class PowScalarRInplace(PowScalarR.inplace_version()): class PowScalarRInplace(PowScalarR.inplace_version()):
def impl(self, x, a): def impl(self, x, a):
_assert_tensor_scalar(x, a) _assert_tensor_scalar(x, a)
x **= a x **= a
return x return x
pow_scalar_r_inplace = _constructor(PowScalarRInplace) pow_scalar_r_inplace = gof.op.constructor(PowScalarRInplace)
pow = _scalar_switch(pow_elemwise, pow_scalar_r, pow_scalar_l) pow = _scalar_switch(pow_elemwise, pow_scalar_r, pow_scalar_l)
pow_inplace = _scalar_switch(pow_elemwise_inplace, pow_scalar_r_inplace) pow_inplace = _scalar_switch(pow_elemwise_inplace, pow_scalar_r_inplace)
...@@ -729,7 +707,7 @@ class Dot(_Op): ...@@ -729,7 +707,7 @@ class Dot(_Op):
return blas.ldflags() return blas.ldflags()
def c_impl(self, (_x, _y), (_z, )): def c_impl(self, (_x, _y), (_z, )):
return blas.gemm_code('', '1.0', '0.0') return blas.gemm_code('', '1.0', '0.0')
dot = _constructor(Dot) dot = gof.op.constructor(Dot)
class Gemm(_Op): class Gemm(_Op):
nin=5 nin=5
...@@ -788,7 +766,7 @@ class Gemm(_Op): ...@@ -788,7 +766,7 @@ class Gemm(_Op):
return blas.gemm_code( check_ab, return blas.gemm_code( check_ab,
'(_a->descr->type_num == PyArray_FLOAT) ? (REAL)(((float*)_a->data)[0]) : (REAL)(((double*)_a->data)[0])', '(_a->descr->type_num == PyArray_FLOAT) ? (REAL)(((float*)_a->data)[0]) : (REAL)(((double*)_a->data)[0])',
'(_b->descr->type_num == PyArray_FLOAT) ? (REAL)(((float*)_b->data)[0]) : (REAL)(((double*)_b->data)[0])') '(_b->descr->type_num == PyArray_FLOAT) ? (REAL)(((float*)_b->data)[0]) : (REAL)(((double*)_b->data)[0])')
gemm = _constructor(Gemm) gemm = gof.op.constructor(Gemm)
if 0: if 0:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论