cc.py is now in working order

上级 5147d5fc
...@@ -18,21 +18,46 @@ class Double(ResultBase): ...@@ -18,21 +18,46 @@ class Double(ResultBase):
def __repr__(self): def __repr__(self):
return self.name return self.name
def c_type(self): # def c_is_simple(self): return True
return "double"
def c_declare(self):
return "double %(name)s; void* %(name)s_bad_thing;"
def c_init(self):
return """
%(name)s = 0;
%(name)s_bad_thing = malloc(100000);
printf("Initializing %(name)s\\n");
"""
def c_literal(self):
return str(self.data)
def c_data_extract(self): def c_extract(self):
return """ return """
%(type)s %(name)s = PyFloat_AsDouble(py_%(name)s); if (!PyFloat_Check(py_%(name)s)) {
%(fail)s PyErr_SetString(PyExc_TypeError, "not a double!");
%(fail)s
}
%(name)s = PyFloat_AsDouble(py_%(name)s);
%(name)s_bad_thing = NULL;
printf("Extracting %(name)s\\n");
""" """
def c_data_sync(self): def c_sync(self):
return """ return """
Py_XDECREF(py_%(name)s); Py_XDECREF(py_%(name)s);
py_%(name)s = PyFloat_FromDouble(%(name)s); py_%(name)s = PyFloat_FromDouble(%(name)s);
if (!py_%(name)s) if (!py_%(name)s)
py_%(name)s = Py_None; py_%(name)s = Py_None;
printf("Syncing %(name)s\\n");
"""
def c_cleanup(self):
return """
printf("Cleaning up %(name)s\\n");
if (%(name)s_bad_thing)
free(%(name)s_bad_thing);
""" """
...@@ -95,10 +120,22 @@ class _test_CLinker(unittest.TestCase): ...@@ -95,10 +120,22 @@ class _test_CLinker(unittest.TestCase):
def test_0(self): def test_0(self):
x, y, z = inputs() x, y, z = inputs()
e = mul(add(x, y), div(x, y)) e = add(mul(add(x, y), div(x, y)), sub(sub(x, y), z))
lnk = CLinker(env([x, y, z], [e])) lnk = CLinker(env([x, y, z], [e]), [x.r, y.r, z.r], [e.r])
print lnk.code_gen() cgen = lnk.code_gen()
fn = lnk.make_function([x.r, y.r, z.r], [e.r])
print fn(2.0, 2.0, 2.0)
# fn = 0
def test_1(self):
x, y, z = inputs()
z.r.constant = True
e = add(mul(add(x, y), div(x, y)), sub(sub(x, y), z))
lnk = CLinker(env([x, y], [e]), [x.r, y.r], [e.r])
cgen = lnk.code_gen()
fn = lnk.make_function([x.r, y.r], [e.r])
print fn(2.0, 2.0)
# fn = 0
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
差异被折叠。
try:
from cutils_ext import *
except ImportError:
from scipy import weave
single_runner = """
if (!PyCObject_Check(py_cthunk)) {
PyErr_SetString(PyExc_ValueError,
"Argument to run_cthunk must be a PyCObject.");
return NULL;
}
void * ptr_addr = PyCObject_AsVoidPtr(py_cthunk);
int (*fn)(void*) = reinterpret_cast<int (*)(void*)>(ptr_addr);
void* it = PyCObject_GetDesc(py_cthunk);
int failure = fn(it);
return_val = failure;
"""
cthunk = object()
mod = weave.ext_tools.ext_module('cutils_ext')
fun =weave.ext_tools.ext_function('run_cthunk', single_runner, ['cthunk'])
fun.customize.add_extra_compile_arg('--permissive')
mod.add_function(fun)
mod.compile()
from cutils_ext import *
...@@ -161,7 +161,6 @@ class Env(graph.Graph): ...@@ -161,7 +161,6 @@ class Env(graph.Graph):
if do_import: if do_import:
for op in self.io_toposort(): for op in self.io_toposort():
try: try:
# print op
feature.on_import(op) feature.on_import(op)
except AbstractFunctionError: except AbstractFunctionError:
pass pass
......
...@@ -126,7 +126,7 @@ class Op(object): ...@@ -126,7 +126,7 @@ class Op(object):
return [["i%i" % i for i in xrange(len(self.inputs))], return [["i%i" % i for i in xrange(len(self.inputs))],
["o%i" % i for i in xrange(len(self.outputs))]] ["o%i" % i for i in xrange(len(self.outputs))]]
def c_validate(self): def c_validate_update(self):
""" """
Returns C code that checks that the inputs to this function Returns C code that checks that the inputs to this function
can be worked on. If a failure occurs, set an Exception can be worked on. If a failure occurs, set an Exception
...@@ -136,27 +136,27 @@ class Op(object): ...@@ -136,27 +136,27 @@ class Op(object):
""" """
raise AbstractFunctionError() raise AbstractFunctionError()
def c_validate_cleanup(self): def c_validate_update_cleanup(self):
""" """
Clean up things allocated by c_validate(). Clean up things allocated by c_validate().
""" """
raise AbstractFunctionError() raise AbstractFunctionError()
def c_update(self): # def c_update(self):
""" # """
Returns C code that allocates and/or updates the outputs # Returns C code that allocates and/or updates the outputs
(eg resizing, etc.) so they can be manipulated safely # (eg resizing, etc.) so they can be manipulated safely
by c_code. # by c_code.
You may use the variable names defined by c_var_names() # You may use the variable names defined by c_var_names()
""" # """
raise AbstractFunctionError() # raise AbstractFunctionError()
def c_update_cleanup(self): # def c_update_cleanup(self):
""" # """
Clean up things allocated by c_update(). # Clean up things allocated by c_update().
""" # """
raise AbstractFunctionError() # raise AbstractFunctionError()
def c_code(self): def c_code(self):
""" """
...@@ -174,6 +174,31 @@ class Op(object): ...@@ -174,6 +174,31 @@ class Op(object):
""" """
raise AbstractFunctionError() raise AbstractFunctionError()
def c_compile_args(self):
"""
Return a list of compile args recommended to manipulate this Op.
"""
raise AbstractFunctionError()
def c_headers(self):
"""
Return a list of header files that must be included from C to manipulate
this Op.
"""
raise AbstractFunctionError()
def c_libraries(self):
"""
Return a list of libraries to link against to manipulate this Op.
"""
raise AbstractFunctionError()
def c_support_code(self):
"""
Return utility code for use by this Op.
"""
raise AbstractFunctionError()
class GuardedOp(Op): class GuardedOp(Op):
......
...@@ -149,20 +149,23 @@ class ResultBase(object): ...@@ -149,20 +149,23 @@ class ResultBase(object):
# C code generators # C code generators
# #
def c_is_simple(self):
return False
def c_declare(self): def c_declare(self):
""" """
Declares variables that will be instantiated by c_data_extract. Declares variables that will be instantiated by c_data_extract.
""" """
raise AbstractFunctionError() raise AbstractFunctionError()
def c_extract(self): # def c_extract(self):
get_from_list = """ # get_from_list = """
PyObject* py_%(name)s = PyList_GET_ITEM(%(name)s_storage, 0); # //PyObject* py_%(name)s = PyList_GET_ITEM(%(name)s_storage, 0);
Py_XINCREF(py_%(name)s); # //Py_XINCREF(py_%(name)s);
""" # """
return get_from_list + self.c_data_extract() # return get_from_list + self.c_data_extract()
def c_data_extract(self): def c_extract(self):
""" """
# The code returned from this function must be templated using # The code returned from this function must be templated using
# "%(name)s", representing the name that the caller wants to # "%(name)s", representing the name that the caller wants to
...@@ -176,13 +179,13 @@ class ResultBase(object): ...@@ -176,13 +179,13 @@ class ResultBase(object):
""" """
raise AbstractFunctionError() raise AbstractFunctionError()
def c_cleanup(self): # def c_cleanup(self):
decref = """ # decref = """
Py_XDECREF(py_%(name)s); # //Py_XDECREF(py_%(name)s);
""" # """
return self.c_data_cleanup() + decref # return self.c_data_cleanup() + decref
def c_data_cleanup(self): def c_cleanup(self):
""" """
This returns C code that should deallocate whatever This returns C code that should deallocate whatever
c_data_extract allocated or decrease the reference counts. Do c_data_extract allocated or decrease the reference counts. Do
...@@ -192,14 +195,14 @@ class ResultBase(object): ...@@ -192,14 +195,14 @@ class ResultBase(object):
""" """
raise AbstractFunctionError() raise AbstractFunctionError()
def c_sync(self): # def c_sync(self):
set_in_list = """ # set_in_list = """
PyList_SET_ITEM(%(name)s_storage, 0, py_%(name)s); # //PyList_SET_ITEM(%(name)s_storage, 0, py_%(name)s);
Py_XDECREF(py_%(name)s); # //Py_XDECREF(py_%(name)s);
""" # """
return self.c_data_sync() + set_in_list # return self.c_data_sync() + set_in_list
def c_data_sync(self): def c_sync(self):
""" """
The code returned from this function must be templated using "%(name)s", The code returned from this function must be templated using "%(name)s",
representing the name that the caller wants to call this Result. representing the name that the caller wants to call this Result.
...@@ -209,20 +212,26 @@ class ResultBase(object): ...@@ -209,20 +212,26 @@ class ResultBase(object):
""" """
raise AbstractFunctionError() raise AbstractFunctionError()
def c_compile_args(self):
"""
Return a list of compile args recommended to manipulate this Result.
"""
raise AbstractFunctionError()
def c_headers(self): def c_headers(self):
""" """
Return a list of header files that must be included from C to manipulate Return a list of header files that must be included from C to manipulate
this Result. this Result.
""" """
return [] raise AbstractFunctionError()
def c_libraries(self): def c_libraries(self):
""" """
Return a list of libraries to link against to manipulate this Result. Return a list of libraries to link against to manipulate this Result.
""" """
return [] raise AbstractFunctionError()
def c_support(self): def c_support_code(self):
""" """
Return utility code for use by this Result or Ops manipulating this Return utility code for use by this Result or Ops manipulating this
Result. Result.
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论