提交 f6c583bf authored 作者: Hengjean's avatar Hengjean

Added C interface for TypedListType, GetItem op, insert op, append op and extend op.

上级 d35a7ad9
...@@ -80,6 +80,14 @@ class GetItem(Op): ...@@ -80,6 +80,14 @@ class GetItem(Op):
def __str__(self): def __str__(self):
return self.__class__.__name__ return self.__class__.__name__
def c_code(self, node, name, inp, out, sub):
x_name, index = inp[0], inp[1]
output_name = out[0]
return """
%(output_name)s = (typeof %(output_name)s) PyList_GetItem( (PyObject*) %(x_name)s, *((double *) PyArray_DATA(%(index)s)));
Py_INCREF(%(output_name)s);
""" % locals()
getitem = GetItem() getitem = GetItem()
...@@ -114,6 +122,22 @@ class Append(Op): ...@@ -114,6 +122,22 @@ class Append(Op):
def __str__(self): def __str__(self):
return self.__class__.__name__ return self.__class__.__name__
def c_code(self, node, name, inp, out, sub):
x_name, toAppend = inp[0], inp[1]
output_name = out[0]
if not self.inplace:
init = """
%(output_name)s = (PyListObject*) PyList_GetSlice((PyObject*) %(x_name)s, 0, PyList_GET_SIZE((PyObject*) %(x_name)s)) ;
""" % locals()
else:
init = """
%(output_name)s = %(x_name)s;
""" % locals()
return init + """
PyList_Append( (PyObject*) %(output_name)s,(PyObject*) %(toAppend)s);
Py_INCREF(%(output_name)s);
""" % locals()
append = Append() append = Append()
...@@ -148,6 +172,26 @@ class Extend(Op): ...@@ -148,6 +172,26 @@ class Extend(Op):
def __str__(self): def __str__(self):
return self.__class__.__name__ return self.__class__.__name__
def c_code(self, node, name, inp, out, sub):
x_name, toAppend = inp[0], inp[1]
output_name = out[0]
if not self.inplace:
init = """
%(output_name)s = (PyListObject*) PyList_GetSlice((PyObject*) %(x_name)s, 0, PyList_GET_SIZE((PyObject*) %(x_name)s)) ;
""" % locals()
else:
init = """
%(output_name)s = %(x_name)s;
""" % locals()
return init + """
int i =0;
int length = PyList_GET_SIZE((PyObject*) %(x_name)s);
for(i; i < length; i++){
PyList_Append( (PyObject*) %(output_name)s,(PyObject*) PyList_GetItem((PyObject*) %(toAppend)s,i));
}
Py_INCREF(%(output_name)s);
""" % locals()
extend = Extend() extend = Extend()
...@@ -183,6 +227,22 @@ class Insert(Op): ...@@ -183,6 +227,22 @@ class Insert(Op):
def __str__(self): def __str__(self):
return self.__class__.__name__ return self.__class__.__name__
def c_code(self, node, name, inp, out, sub):
x_name, index, toInsert = inp[0], inp[1], inp[2]
output_name = out[0]
if not self.inplace:
init = """
%(output_name)s = (PyListObject*) PyList_GetSlice((PyObject*) %(x_name)s, 0, PyList_GET_SIZE((PyObject*) %(x_name)s)) ;
""" % locals()
else:
init = """
%(output_name)s = %(x_name)s;
""" % locals()
return init + """
PyList_Insert((PyObject*) %(output_name)s, *((double *) PyArray_DATA(%(index)s)), (PyObject*) %(toInsert)s);
Py_INCREF(%(output_name)s);
""" % locals()
insert = Insert() insert = Insert()
......
...@@ -43,6 +43,7 @@ def random_lil(shape, dtype, nnz): ...@@ -43,6 +43,7 @@ def random_lil(shape, dtype, nnz):
class test_get_item(unittest.TestCase): class test_get_item(unittest.TestCase):
def setUp(self): def setUp(self):
theano.config.nocleanup = True
utt.seed_rng() utt.seed_rng()
def test_sanity_check_slice(self): def test_sanity_check_slice(self):
...@@ -76,6 +77,7 @@ class test_get_item(unittest.TestCase): ...@@ -76,6 +77,7 @@ class test_get_item(unittest.TestCase):
z) z)
x = rand_ranged_matrix(-1000, 1000, [100, 101]) x = rand_ranged_matrix(-1000, 1000, [100, 101])
y = rand_ranged_matrix(-1000, 1000, [100, 101])
self.assertTrue(numpy.array_equal(f([x], numpy.asarray(0, self.assertTrue(numpy.array_equal(f([x], numpy.asarray(0,
dtype=theano.config.floatX)), x)) dtype=theano.config.floatX)), x))
......
...@@ -66,6 +66,7 @@ class TypedListType(gof.Type): ...@@ -66,6 +66,7 @@ class TypedListType(gof.Type):
else: else:
return 0 return 0
<<<<<<< HEAD
def values_eq(self, a, b): def values_eq(self, a, b):
if not len(a) == len(b): if not len(a) == len(b):
return False return False
...@@ -75,3 +76,34 @@ class TypedListType(gof.Type): ...@@ -75,3 +76,34 @@ class TypedListType(gof.Type):
return False return False
return True return True
=======
def c_declare(self, name, sub):
return """
PyListObject* %(name)s;
""" % dict(name=name)
def c_init(self, name, sub):
return """
%(name)s = NULL;
""" % dict(name=name)
def c_extract(self, name, sub):
return """
if (!PyList_Check(py_%(name)s)) {
PyErr_SetString(PyExc_TypeError, "expected a list");
%(fail)s
}
%(name)s = (PyListObject*) (py_%(name)s);
""" % dict(name=name, fail=sub['fail'])
def c_sync(self, name, sub):
return """
Py_XDECREF(py_%(name)s);
py_%(name)s = (PyObject*)(%(name)s);
Py_INCREF(py_%(name)s);
""" % dict(name=name)
def c_cleanup(self, name, sub):
return ""
>>>>>>> Added C interface for TypedListType, GetItem op, insert op, append op and extend op.
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论