提交 81c4cc22 authored 作者: Olivier Breuleux's avatar Olivier Breuleux

fixed support code in cc, added support code for complex numbers in base_tensor

上级 623b4175
...@@ -166,6 +166,27 @@ class BaseTensor(ResultBase): ...@@ -166,6 +166,27 @@ class BaseTensor(ResultBase):
def c_libraries(self): def c_libraries(self):
return [] return []
def c_support_code(cls):
operator_template = """
me operator %(op)s(me y) {
me ret;
ret.real = this->real %(op)s y.real;
ret.imag = this->imag %(op)s y.imag;
return ret;
}
"""
template = """
struct theano_complex%(nbits)s : public npy_complex%(nbits)s
{
typedef theano_complex%(nbits)s me;
typedef npy_complex%(nbits)s base;
%(operators)s
};
"""
d = dict(operators = "\n".join([operator_template % dict(op=op) for op in ["+", "-", "*", "/"]]))
return template % dict(d, nbits = 64) + template % dict(d, nbits = 128)
############################ ############################
# Tensor specific attributes # Tensor specific attributes
......
...@@ -395,9 +395,9 @@ class CLinker(Linker): ...@@ -395,9 +395,9 @@ class CLinker(Linker):
return self.tasks[failure_code - n] return self.tasks[failure_code - n]
def support_code(self): def support_code(self):
ret = "" ret = set()
for x in self.results + self.op_order: for x in self.results + self.op_order:
try: ret += x.c_support_code() try: ret.add(x.c_support_code())
except AbstractFunctionError: pass except AbstractFunctionError: pass
return ret return ret
...@@ -501,7 +501,10 @@ class CLinker(Linker): ...@@ -501,7 +501,10 @@ class CLinker(Linker):
} }
""" % dict(struct_name = self.struct_name) """ % dict(struct_name = self.struct_name)
instantiate.customize.add_support_code(self.support_code() + self.struct_code + static) instantiate.customize.add_support_code(self.struct_code)
instantiate.customize.add_support_code(static)
for support_code in self.support_code():
instantiate.customize.add_support_code(support_code)
instantiate.customize.add_extra_compile_arg("-w") instantiate.customize.add_extra_compile_arg("-w")
for arg in self.compile_args(): for arg in self.compile_args():
instantiate.customize.add_extra_compile_arg(arg) instantiate.customize.add_extra_compile_arg(arg)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论