提交 19231d79 authored 作者: Pierre Luc Carrier's avatar Pierre Luc Carrier

Added new Op for C ops to inherit from

上级 c7d72dcc
...@@ -55,7 +55,7 @@ from theano.gof.link import \ ...@@ -55,7 +55,7 @@ from theano.gof.link import \
Container, Linker, LocalLinker, PerformLinker, WrapLinker, WrapLinkerMany Container, Linker, LocalLinker, PerformLinker, WrapLinker, WrapLinkerMany
from theano.gof.op import \ from theano.gof.op import \
Op, OpenMPOp, PureOp, ops_with_inner_function Op, OpenMPOp, PureOp, COp, ops_with_inner_function
from theano.gof.opt import ( from theano.gof.opt import (
Optimizer, Optimizer,
......
...@@ -974,3 +974,106 @@ int main( int argc, const char* argv[] ) ...@@ -974,3 +974,106 @@ int main( int argc, const char* argv[] )
self.update_self_openmp() self.update_self_openmp()
return super(OpenMPOp, self).make_thunk(node, storage_map, return super(OpenMPOp, self).make_thunk(node, storage_map,
compute_map, no_recycling) compute_map, no_recycling)
class COp(Op):
""" Class to allow an op to have an external C implementation.
An op can use this class by inheriting from it and calling its
__init__() method, providing it with a path to an external file containing
the C implementation and the name of the function, in that file, to call
to perform the computations for the op.
"""
def __init__(self, func_file, func_name):
self.func_file = func_file
self.func_name = func_name
# Load the func
f = open(self.func_file, "r")
self.func_code = f.read()
f.close()
def c_code_cache_version(self):
return hash(self.func_code)
def c_support_code_apply(self, node, name):
func_code = self.func_code.replace("<<<<NODE_NAME_PLACEHOLDER>>>>",
name)
if hasattr(self, 'check_inputs') and self.check_inputs == False:
return func_code
else:
define_macros, undef_macros = self.get_c_macros(node, name)
return os.linesep.join([define_macros, func_code, undef_macros])
def format_c_function_args(self, inp, out):
# Generate an string containing the arguments sent to the external C
# function. The argstring will be of format :
# "input0, input1, input2, (void**)&output0, (void**)&output1"
input_arg_str = ", ".join(inp)
output_arg_str = ", ".join(["(void**)&%s"] * len(out)) % tuple(out)
return input_arg_str + ", " + output_arg_str
def get_c_macros(self, node, name):
define_template = "#define %s %s" + os.linesep
undef_template = "#undef %s" + os.linesep
define_macros = ""
undef_macros = ""
# Extract the various properties of the input and output variables
variables = node.inputs + node.outputs
variable_names = (["INPUT_%i" % i for i in range(len(node.inputs))] +
["OUTPUT_%i" % i for i in range(len(node.inputs))])
variable_dtypes_names = [v.dtype for v in variables]
variable_dtypes = [numpy.dtype(d) for d in variable_dtypes_names]
variable_typenums = [d.num for d in variable_dtypes]
variable_itemsizes = [d.itemsize for d in variable_dtypes]
# Generate dtype macros
for i in range(len(variables)):
macro_name = "DTYPE_" + variable_names[i]
macro_value = "npy_" + variable_dtypes_names[i]
define_macros += define_template % (macro_name, macro_value)
undef_macros += undef_template % macro_name
# Generate typenum macros
for i in range(len(variables)):
macro_name = "TYPENUM_" + variable_names[i]
macro_value = variable_typenums[i]
define_macros += define_template % (macro_name, macro_value)
undef_macros += undef_template % macro_name
# Generate itemsize macros
for i in range(len(variables)):
macro_name = "ITEMSIZE_" + variable_names[i]
macro_value = variable_itemsizes[i]
define_macros += define_template % (macro_name, macro_value)
undef_macros += undef_template % macro_name
return define_macros, undef_macros
def c_code(self, node, name, inp, out, sub):
func_name = self.func_name.replace("<<<<NODE_NAME_PLACEHOLDER>>>>",
name)
func_args = self.format_c_function_args(inp, out)
fail = sub['fail']
# Generate the C code
c_code = """
{
int result = %(func_name)s(%(func_args)s);
if (result != 0)
{
%(fail)s;
}
}
""" % locals()
return c_code
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论