提交 5f8f1fa9 authored 作者: Frederic Bastien's avatar Frederic Bastien

merge, backport patch about complex.

......@@ -180,9 +180,23 @@ class Scalar(Type):
ret.imag = (this->imag * y.real - this->real * y.imag) / y_norm_square;
return ret;
}
};
"""
return template % dict(nbits = 64, half_nbits = 32) + template % dict(nbits = 128, half_nbits = 64)
complex_type& operator =(const scalar_type& y) {
this->real=y;
this->imag=0;
return *this;
}
%(upcast)s
};
"""
# todo: use C templating
return template % dict(nbits = 64, half_nbits = 32, upcast="") + template % dict(nbits = 128, half_nbits = 64, upcast="""
complex_type& operator =(theano_complex64 y) {
this->real=y.real;
this->imag=y.imag;
return *this;
}
""")
int8 = Scalar('int8')
......@@ -264,6 +278,9 @@ def _multi(*fns):
ints = _multi(int64)
floats = _multi(float64)
complexs = _multi(complex128)
complexs64 = _multi(complex64)
complexs128 = _multi(complex128)
......
......@@ -446,42 +446,7 @@ class TensorType(Type):
def c_support_code(cls):
"""Override `CLinkerOp.c_support_code` """
template = """
struct theano_complex%(nbits)s : public npy_complex%(nbits)s
{
typedef theano_complex%(nbits)s complex_type;
typedef npy_float%(half_nbits)s scalar_type;
complex_type operator +(complex_type y) {
complex_type ret;
ret.real = this->real + y.real;
ret.imag = this->imag + y.imag;
return ret;
}
complex_type operator -(complex_type y) {
complex_type ret;
ret.real = this->real - y.real;
ret.imag = this->imag - y.imag;
return ret;
}
complex_type operator *(complex_type y) {
complex_type ret;
ret.real = this->real * y.real - this->imag * y.imag;
ret.imag = this->real * y.imag + this->imag * y.real;
return ret;
}
complex_type operator /(complex_type y) {
complex_type ret;
scalar_type y_norm_square = y.real * y.real + y.imag * y.imag;
ret.real = (this->real * y.real + this->imag * y.imag) / y_norm_square;
ret.imag = (this->imag * y.real - this->real * y.imag) / y_norm_square;
return ret;
}
};
"""
return template % dict(nbits = 64, half_nbits = 32) + template % dict(nbits = 128, half_nbits = 64)
# todo: use C templating
return scal.Scalar("int8").c_support_code()
# Easy constructors
......
......@@ -1920,6 +1920,41 @@ def test_sum_overflow():
f = function([a], sum(a))
assert f([1]*300) == 300
def test_convert_to_complex():
a = value(numpy.ones(3, dtype='complex64')+0.5j)
b = value(numpy.ones(3, dtype='complex128')+0.5j)
f = function([a],basic.convert_to_complex128(a))
#we need to compare with the same type.
assert a.type.values_eq_approx(b.data, f(a.data))
f = function([b],basic.convert_to_complex128(b))
assert b.type.values_eq_approx(b.data, f(b.data))
f = function([a],basic.convert_to_complex64(a))
assert a.type.values_eq_approx(a.data, f(a.data))
#down cast don,t work for now
#f = function([b],basic.convert_to_complex64(b))
#assert b.type.values_eq_approx(b.data, f(b.data))
for nbits in (64, 128):
for t in ['int8','int16','int32','int64','float32','float64']:
a = value(numpy.ones(3, dtype=t))
b = value(numpy.ones(3, dtype='complex128'))
f = function([a],basic.convert_to_complex128(a))
assert a.type.values_eq_approx(b.data, f(a.data))
def test_bug_complext_10_august_09():
v0 = dmatrix()
v1 = basic.convert_to_complex128(v0)
inputs = [v0]
outputs = [v1]
f = function(inputs, outputs)
i = numpy.zeros((2,2))
assert (f(i)==numpy.zeros((2,2))).all()
if __name__ == '__main__':
if len(sys.argv) >= 2 and sys.argv[1] == 'OPT':
default_mode = compile.Mode(linker = 'c&py',
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论