提交 8a6a44a5 authored 作者: Sigurd Spieckermann's avatar Sigurd Spieckermann

fixed segfault + minor cosmetic changes

上级 df3bc37e
...@@ -4012,6 +4012,8 @@ class Flatten(Op): ...@@ -4012,6 +4012,8 @@ class Flatten(Op):
""" """
view_map = {0: [0]} view_map = {0: [0]}
check_input = False
def __init__(self, outdim=1): def __init__(self, outdim=1):
self.outdim = int(outdim) self.outdim = int(outdim)
...@@ -4079,7 +4081,7 @@ class Flatten(Op): ...@@ -4079,7 +4081,7 @@ class Flatten(Op):
return self.make_node(*eval_points).outputs return self.make_node(*eval_points).outputs
def c_code_cache_version(self): def c_code_cache_version(self):
return (1,) return (1, 1)
def c_code(self, node, name, inputs, outputs, sub): def c_code(self, node, name, inputs, outputs, sub):
x, = inputs x, = inputs
...@@ -4089,18 +4091,16 @@ class Flatten(Op): ...@@ -4089,18 +4091,16 @@ class Flatten(Op):
return """ return """
if (%(outdim)s == PyArray_NDIM(%(x)s)) if (%(outdim)s == PyArray_NDIM(%(x)s))
{ {
if (NULL != %(out)s && %(x)s != %(out)s) Py_XDECREF(%(out)s);
Py_XDECREF(%(out)s); Py_XINCREF(%(x)s);
%(out)s = %(x)s; %(out)s = %(x)s;
} }
else else
{ {
Py_XDECREF(%(out)s);
if (%(outdim)s == 1) if (%(outdim)s == 1)
{ {
if (NULL != %(out)s)
Py_XDECREF(%(out)s);
npy_intp size = PyArray_SIZE(%(x)s); npy_intp size = PyArray_SIZE(%(x)s);
PyArray_Dims newshape; PyArray_Dims newshape;
newshape.ptr = &size; newshape.ptr = &size;
...@@ -4111,9 +4111,6 @@ class Flatten(Op): ...@@ -4111,9 +4111,6 @@ class Flatten(Op):
} }
else else
{ {
if (NULL != %(out)s)
Py_XDECREF(%(out)s);
npy_intp *oldshape = PyArray_DIMS(%(x)s); npy_intp *oldshape = PyArray_DIMS(%(x)s);
npy_intp newshape_dims[%(outdim)s]; npy_intp newshape_dims[%(outdim)s];
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论