提交 ae72860a authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Merge pull request #1659 from nouiz/rebroadcast

Add c code to rebroadcast.
......@@ -3335,6 +3335,37 @@ class Rebroadcast(Op):
return [None]
return self(*eval_points, **dict(return_list=True))
def c_code(self, node, nodename, inp, out, sub):
iname, = inp
oname, = out
fail = sub['fail']
if isinstance(node.inputs[0].type, TensorType):
code = ""
for axis, value in self.axis.iteritems():
if value:
code += """
if(PyArray_DIMS(%(iname)s)[%(axis)s] != 1){
PyErr_Format(PyExc_ValueError,
"Dimension %(axis)s in Rebroadcast's input was"
" supposed to be 1 (got %%d instead)",
PyArray_DIMS(%(iname)s)[%(axis)s]);
%(fail)s
}
""" % locals()
return code + """
Py_XDECREF(%(oname)s);
%(oname)s = %(iname)s;
Py_XINCREF(%(oname)s);
""" % locals()
else:
#TODO: if your type is not listed here, make a damn registry of
# shape_i ops for various types of variables.
# Do not continue this madness.
return super(Shape, self).c_code(node, nodename, (x,), (out,), sub)
def c_code_cache_version(self):
return (1,)
def addbroadcast(x, *axes):
"""
......@@ -3507,6 +3538,36 @@ class Join(Op):
out[0] = theano._asarray(numpy.concatenate(tensors, axis=axis),
dtype=node.outputs[0].type.dtype)
def c_code_cache_version(self):
return (1,)
def c_code(self, node, name, inputs, outputs, sub):
axis, tensors = inputs[0], inputs[1:]
l = len(tensors)
out, = outputs
fail = sub['fail']
code = """
PyObject* list = PyList_New(%(l)s);
""" % locals()
for i, inp in enumerate(tensors):
code += """
Py_INCREF(%(inp)s);
PyList_SetItem(list, %(i)s, (PyObject*)%(inp)s);
""" % locals()
code += """
//PyObject* PyArray_Concatenate(PyObject* obj, int axis)
Py_XDECREF(%(out)s);
%(out)s = (PyArrayObject *)PyArray_Concatenate(list,
((dtype_%(axis)s *)PyArray_DATA(%(axis)s))[0]);
Py_DECREF(list);
if(!%(out)s){
%(fail)s
}
""" % locals()
return code
def R_op(self, inputs, eval_points):
if None in eval_points[1:]:
return [None]
......
......@@ -594,7 +594,10 @@ class MakeVector(T.Op):
ret = """
npy_intp dims[1];
dims[0] = %(out_shape)s;
%(out)s = (PyArrayObject*)PyArray_EMPTY(1, dims, %(out_dtype)s, 0);
if(!%(out)s || PyArray_DIMS(%(out)s)[0] == %(out_shape)s){
Py_XDECREF(%(out)s);
%(out)s = (PyArrayObject*)PyArray_EMPTY(1, dims, %(out_dtype)s, 0);
}
""" % locals()
for idx, i in enumerate(inp):
ret += """
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论