提交 7dcce643 authored 作者: Olivier Breuleux's avatar Olivier Breuleux

fixed and tested complex operations for +-*/

上级 fb352f6b
......@@ -205,19 +205,18 @@ class T_subtensor(unittest.TestCase):
class T_add(unittest.TestCase):
def test_complex128(self):
a = tinit(numpy.ones(3, dtype='complex128'))
b = tinit(numpy.ones(3, dtype='complex128'))
f = Function([a,b], [a+b], linker_cls = gof.CLinker)
self.failUnless(numpy.all((a.data + b.data) ==
f(a.data, b.data)))
def test_complex128b(self):
a = tinit(numpy.ones(3, dtype='complex128')+0.5j)
b = tinit(numpy.ones(3, dtype='complex128'))
f = Function([a,b], [a+b], linker_cls = gof.CLinker)
self.failUnless(numpy.all((a.data + b.data) ==
f(a.data, b.data)))
def test_complex_all_ops(self):
for nbits in (64, 128):
a = tinit(numpy.ones(3, dtype='complex%i' % nbits)+0.5j)
b = tinit(numpy.ones(3, dtype='complex%i' % nbits)+1.5j)
tests = (("+", lambda x,y: x+y),
("-", lambda x,y: x-y),
("*", lambda x,y: x*y),
("/", lambda x,y: x/y))
for s, fn in tests:
f = Function([a,b], [fn(a, b)], linker_cls = gof.CLinker)
self.failUnless(numpy.all(fn(a.data, b.data) == f(a.data, b.data)))
class T_abs(unittest.TestCase):
......
......@@ -169,25 +169,40 @@ class BaseTensor(ResultBase):
return []
def c_support_code(cls):
operator_template = """
me operator %(op)s(me y) {
me ret;
ret.real = this->real %(op)s y.real;
ret.imag = this->imag %(op)s y.imag;
return ret;
}
"""
template = """
struct theano_complex%(nbits)s : public npy_complex%(nbits)s
{
typedef theano_complex%(nbits)s me;
typedef npy_complex%(nbits)s base;
typedef theano_complex%(nbits)s complex_type;
typedef npy_float%(half_nbits)s scalar_type;
%(operators)s
complex_type operator +(complex_type y) {
complex_type ret;
ret.real = this->real + y.real;
ret.imag = this->imag + y.imag;
return ret;
}
complex_type operator -(complex_type y) {
complex_type ret;
ret.real = this->real - y.real;
ret.imag = this->imag - y.imag;
return ret;
}
complex_type operator *(complex_type y) {
complex_type ret;
ret.real = this->real * y.real - this->imag * y.imag;
ret.imag = this->real * y.imag + this->imag * y.real;
return ret;
}
complex_type operator /(complex_type y) {
complex_type ret;
scalar_type y_norm_square = y.real * y.real + y.imag * y.imag;
ret.real = (this->real * y.real + this->imag * y.imag) / y_norm_square;
ret.imag = (this->imag * y.real - this->real * y.imag) / y_norm_square;
return ret;
}
};
"""
d = dict(operators = "\n".join([operator_template % dict(op=op) for op in ["+", "-", "*", "/"]]))
return template % dict(d, nbits = 64) + template % dict(d, nbits = 128)
return template % dict(nbits = 64, half_nbits = 32) + template % dict(nbits = 128, half_nbits = 64)
############################
......
......@@ -471,7 +471,7 @@ class SubElemwise(_Elemwise):
def grad(self, (x, y), gz):
return gz, -gz
def c_foreach(self, (x_i, y_i), (z_i, )):
return "z_i = x_i - y_i;"
return "%(z)s_i = %(x)s_i - %(y)s_i;"
sub_elemwise = _constructor(SubElemwise)
class SubElemwiseInplace(SubElemwise.inplace_version()):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论