bla

上级 de75d2c1
...@@ -177,12 +177,41 @@ class omega_op(gof.PythonOp): ...@@ -177,12 +177,41 @@ class omega_op(gof.PythonOp):
def grad(*args): def grad(*args):
return UNDEFINED return UNDEFINED
def __create_c_code(self):
behavior = self.c_impl(self.inputs, self.outputs)
(inames, onames), _1, _2, _3 = inspect.getargspec(self.c_impl)
struct = """
struct _omega_%(name)s {
_omega_%(name)s() {}
void extract(void) {
}
void execute(void) {
%(code)s
}
void sync(void) {
}
};
""" % self.__class__.__name__, behavior
def c_alloc(self):
raise Exception("Cannot allocate output arrays for this Op.")
def c_impl(inputs, outputs): def c_impl(inputs, outputs):
raise NotImplementedError() raise NotImplementedError()
def c_thunk(self):
self.c_alloc()
if self.c_module:
a
else:
def c_perform(self): def c_perform(self):
pass self.c_thunk()()
def scalar_switch(normal_f, scalar_f, scalar_f_reverse = None): def scalar_switch(normal_f, scalar_f, scalar_f_reverse = None):
def f(x, y): def f(x, y):
...@@ -272,8 +301,43 @@ def tensor_scalar_op(impl): ...@@ -272,8 +301,43 @@ def tensor_scalar_op(impl):
return ret return ret
# @omega_op
# def add((x, y), (z, )):
# def grad(gz):
# return gz
# def c_alloc():
# return numpy.ndarray(x.shape, dtype = x.dtype)
# c_impl = """
# for (int i = 0; i < z.ncols; i++) {
# for (int j = 0; j < z.nrows; j++) {
# z(i, j) = x(i, j) + y(i, j);
# }
# }
# """
## Addition ## ## Addition ##
class add(omega_op):
impl = assert_same_shapes(numpy.ndarray.__add__)
def grad(x, y, gz):
return gz
def alloc(x, y):
return numpy.ndarray(x.shape, dtype = x.dtype)
def c_impl(x, y, z):
return """
for (int i = 0; i < z.ncols; i++) {
for (int j = 0; j < z.nrows; j++) {
z(i, j) = x(i, j) + y(i, j);
}
}
"""
class proto_add_elemwise(omega_op): class proto_add_elemwise(omega_op):
def grad(x, y, gz): def grad(x, y, gz):
return gz return gz
......
#ifndef _OMEGA_H
#define _OMEGA_H
//#include whatever defines PyArrayObject
template<typename T>
struct TMat_t
{
T * __restrict__ d;/**< pointer to element (0,0) */
size_t M; /**< number of rows */
size_t N; /**< number of columns */
size_t m; /**< row stride */
size_t n; /**< column stride */
bool invalid;
/** null */
TMat_t(const PyArrayObject *o) :
d((double*) o->data),
M((o->nd==2) ? o->dimensions[0] : 0),
N((o->nd==2) ? o->dimensions[1] : 0),
m((o->nd==2) ? o->strides[0] / sizeof(double) : 0),
n((o->nd==2) ? o->strides[1] / sizeof(double) : 0),
invalid((o->nd !=2) || (o->descr->elsize != sizeof(T)))
{
}
/** unsafe element access */
const T & operator()(size_t i, size_t j) const
{
return d[ i * m + j*n];
}
/** unsafe element access */
T & operator()(size_t i, size_t j)
{
return d[ i * m + j*n];
}
/** safe element access */
const T & at(size_t i, size_t j) const
{
return d[ assert((i < M) && (j < N)), i * m + j*n];
}
/** safe element access */
T & at(size_t i, size_t j)
{
return d[ assert((i < M) && (j < N)), i * m + j*n];
}
};
#endif
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论