提交 7dcce643 authored 作者: Olivier Breuleux's avatar Olivier Breuleux

fixed and tested complex operations for +-*/

上级 fb352f6b
...@@ -205,19 +205,18 @@ class T_subtensor(unittest.TestCase): ...@@ -205,19 +205,18 @@ class T_subtensor(unittest.TestCase):
class T_add(unittest.TestCase): class T_add(unittest.TestCase):
def test_complex128(self):
a = tinit(numpy.ones(3, dtype='complex128')) def test_complex_all_ops(self):
b = tinit(numpy.ones(3, dtype='complex128')) for nbits in (64, 128):
f = Function([a,b], [a+b], linker_cls = gof.CLinker) a = tinit(numpy.ones(3, dtype='complex%i' % nbits)+0.5j)
self.failUnless(numpy.all((a.data + b.data) == b = tinit(numpy.ones(3, dtype='complex%i' % nbits)+1.5j)
f(a.data, b.data))) tests = (("+", lambda x,y: x+y),
("-", lambda x,y: x-y),
def test_complex128b(self): ("*", lambda x,y: x*y),
a = tinit(numpy.ones(3, dtype='complex128')+0.5j) ("/", lambda x,y: x/y))
b = tinit(numpy.ones(3, dtype='complex128')) for s, fn in tests:
f = Function([a,b], [a+b], linker_cls = gof.CLinker) f = Function([a,b], [fn(a, b)], linker_cls = gof.CLinker)
self.failUnless(numpy.all((a.data + b.data) == self.failUnless(numpy.all(fn(a.data, b.data) == f(a.data, b.data)))
f(a.data, b.data)))
class T_abs(unittest.TestCase): class T_abs(unittest.TestCase):
......
...@@ -169,25 +169,40 @@ class BaseTensor(ResultBase): ...@@ -169,25 +169,40 @@ class BaseTensor(ResultBase):
return [] return []
def c_support_code(cls): def c_support_code(cls):
operator_template = """
me operator %(op)s(me y) {
me ret;
ret.real = this->real %(op)s y.real;
ret.imag = this->imag %(op)s y.imag;
return ret;
}
"""
template = """ template = """
struct theano_complex%(nbits)s : public npy_complex%(nbits)s struct theano_complex%(nbits)s : public npy_complex%(nbits)s
{ {
typedef theano_complex%(nbits)s me; typedef theano_complex%(nbits)s complex_type;
typedef npy_complex%(nbits)s base; typedef npy_float%(half_nbits)s scalar_type;
%(operators)s complex_type operator +(complex_type y) {
complex_type ret;
ret.real = this->real + y.real;
ret.imag = this->imag + y.imag;
return ret;
}
complex_type operator -(complex_type y) {
complex_type ret;
ret.real = this->real - y.real;
ret.imag = this->imag - y.imag;
return ret;
}
complex_type operator *(complex_type y) {
complex_type ret;
ret.real = this->real * y.real - this->imag * y.imag;
ret.imag = this->real * y.imag + this->imag * y.real;
return ret;
}
complex_type operator /(complex_type y) {
complex_type ret;
scalar_type y_norm_square = y.real * y.real + y.imag * y.imag;
ret.real = (this->real * y.real + this->imag * y.imag) / y_norm_square;
ret.imag = (this->imag * y.real - this->real * y.imag) / y_norm_square;
return ret;
}
}; };
""" """
d = dict(operators = "\n".join([operator_template % dict(op=op) for op in ["+", "-", "*", "/"]])) return template % dict(nbits = 64, half_nbits = 32) + template % dict(nbits = 128, half_nbits = 64)
return template % dict(d, nbits = 64) + template % dict(d, nbits = 128)
############################ ############################
......
...@@ -471,7 +471,7 @@ class SubElemwise(_Elemwise): ...@@ -471,7 +471,7 @@ class SubElemwise(_Elemwise):
def grad(self, (x, y), gz): def grad(self, (x, y), gz):
return gz, -gz return gz, -gz
def c_foreach(self, (x_i, y_i), (z_i, )): def c_foreach(self, (x_i, y_i), (z_i, )):
return "z_i = x_i - y_i;" return "%(z)s_i = %(x)s_i - %(y)s_i;"
sub_elemwise = _constructor(SubElemwise) sub_elemwise = _constructor(SubElemwise)
class SubElemwiseInplace(SubElemwise.inplace_version()): class SubElemwiseInplace(SubElemwise.inplace_version()):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论