提交 3d50b2f0 authored 作者: Frederic's avatar Frederic

Add c_code to the shape op.

上级 0845ddc3
......@@ -2459,6 +2459,30 @@ class Shape(Op):
def R_op(self, inputs, eval_points):
return [None]
def c_code(self, node, nodename, inp, out, sub):
x, = inp
z, = out
if isinstance(node.inputs[0].type, TensorType):
return """
npy_intp shape[] = {PyArray_NDIM(%(x)s)};
if(%(z)s == NULL || (PyArray_DIMS(%(z)s)[0] != shape[0]))
{
Py_XDECREF(%(z)s);
%(z)s = (PyArrayObject*) PyArray_SimpleNew(1, shape, NPY_INT64);
}
for(int i=0;i<shape[0];i++)
{
((npy_int64*)PyArray_GETPTR1(%(z)s, i))[0] = PyArray_DIMS(%(x)s)[i];
}
""" % locals()
else:
#TODO: if your type is not listed here, make a damn registry of
# shape_i ops for various types of variables.
# Do not continue this madness.
return super(Shape_i, self).c_code(node, name, (x,), (out,), sub)
def c_code_cache_version(self):
return (1,)
@constructor
def old_shape(a):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论