提交 cb59fa4e authored 作者: Frederic's avatar Frederic

Add c code to rebroadcast.

上级 d01ddc02
......@@ -3335,6 +3335,36 @@ class Rebroadcast(Op):
return [None]
return self(*eval_points, **dict(return_list=True))
def c_code(self, node, nodename, inp, out, sub):
iname, = inp
oname, = out
if isinstance(node.inputs[0].type, TensorType):
code = ""
for axis, value in self.axis.iteritems():
if value:
code += """
if(PyArray_DIMS(%(iname)s)[%(axis)s] != 1){
PyErr_Format(PyExc_ValueError,
"Dimension %(axis)s in Rebroadcast's input was"
" supposed to be 1 (got %s instead)",
PyArray_DIMS(%(iname)s)[%(axis)s]);
%(fail)s
}
""" % locals()
return code + """
Py_XDECREF(%(oname)s);
%(oname)s = %(iname)s;
Py_XINCREF(%(oname)s);
""" % locals()
else:
#TODO: if your type is not listed here, make a damn registry of
# shape_i ops for various types of variables.
# Do not continue this madness.
return super(Shape, self).c_code(node, nodename, (x,), (out,), sub)
def c_code_cache_version(self):
return (1,)
def addbroadcast(x, *axes):
"""
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论