提交 da02145b authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Fix a bug in the operator declarations when both complex64 and complex128 were used.

上级 a21a33fd
......@@ -184,6 +184,8 @@ class Scalar(Type):
def c_support_code(self):
if self.dtype.startswith('complex'):
cplx_types = ['theano_complex64', 'theano_complex128']
real_types = ['npy_int8', 'npy_int16', 'npy_int32', 'npy_int64', 'npy_float32', 'npy_float64']
template = """
struct theano_complex%(nbits)s : public npy_complex%(nbits)s
......@@ -197,6 +199,7 @@ class Scalar(Type):
ret.imag = this->imag + y.imag;
return ret;
}
complex_type operator -() const {
complex_type ret;
ret.real = -this->real;
......@@ -232,76 +235,85 @@ class Scalar(Type):
template <typename TR, typename TI>
theano_complex%(nbits)s(const TR& r, const TI& i) { this->real=r; this->imag=i; }
};
"""
operator_eq = """
template <> %(mytype)s & %(mytype)s::operator=<npy_int8>(const npy_int8 & y)
{ this->real=y; this->imag=0; return *this; }
template <> %(mytype)s & %(mytype)s::operator=<npy_int16>(const npy_int16 & y)
{ this->real=y; this->imag=0; return *this; }
template <> %(mytype)s & %(mytype)s::operator=<npy_int32>(const npy_int32 & y)
{ this->real=y; this->imag=0; return *this; }
template <> %(mytype)s & %(mytype)s::operator=<npy_int64>(const npy_int64 & y)
{ this->real=y; this->imag=0; return *this; }
template <> %(mytype)s & %(mytype)s::operator=<npy_float32>(const npy_float32 & y)
{ this->real=y; this->imag=0; return *this; }
template <> %(mytype)s & %(mytype)s::operator=<npy_float64>(const npy_float64 & y)
{ this->real=y; this->imag=0; return *this; }
template <> %(mytype)s & %(mytype)s::operator=<theano_complex128>(const theano_complex128 & y)
{ this->real=y.real; this->imag=y.imag; return *this; }
template <> %(mytype)s & %(mytype)s::operator=<theano_complex64>(const theano_complex64 & y)
{ this->real=y.real; this->imag=y.imag; return *this; }
template <typename T>
const %(mytype)s
operator+(const %(mytype)s &x, const T& y)
{ return %(mytype)s(x.real+y, x.imag); }
template <typename T>
const %(mytype)s
operator+(const T& y, const %(mytype)s &x)
{ return %(mytype)s(x.real+y, x.imag); }
template <typename T>
const %(mytype)s
operator-(const %(mytype)s &x, const T& y)
{ return %(mytype)s(x.real-y, x.imag); }
template <typename T>
const %(mytype)s
operator-(const T& x, const %(mytype)s &y)
{ return %(mytype)s(x-y.real, -y.imag); }
template <typename T>
const %(mytype)s
operator*(const %(mytype)s &x, const T& y)
{ return %(mytype)s(x.real*y, x.imag*y); }
template <typename T>
const %(mytype)s
operator*(const T& x, const %(mytype)s &y)
{ return %(mytype)s(x*y.real, x*y.imag); }
};
"""
# todo: use C templating
def operator_eq_real(mytype, othertype):
    """Return C++ source specializing ``mytype::operator=`` for a real type.

    The generated specialization takes a ``const othertype &`` argument,
    copies it into the real part, sets the imaginary part to 0 and
    returns ``*this``.

    :param mytype: name of the C++ complex type being assigned to.
    :param othertype: name of the C++ real type being assigned from.
    :returns: the C++ source text as a string.
    """
    names = {'mytype': mytype, 'othertype': othertype}
    specialization = '''
template <> %(mytype)s & %(mytype)s::operator=<%(othertype)s>(const %(othertype)s & y)
{ this->real=y; this->imag=0; return *this; }
'''
    return specialization % names
def operator_eq_cplx(mytype, othertype):
    """Return C++ source specializing ``mytype::operator=`` for a complex type.

    The generated specialization takes a ``const othertype &`` argument,
    copies its ``real`` and ``imag`` members into ``this`` and returns
    ``*this``.

    :param mytype: name of the C++ complex type being assigned to.
    :param othertype: name of the C++ complex type being assigned from.
    :returns: the C++ source text as a string.
    """
    names = {'mytype': mytype, 'othertype': othertype}
    specialization = '''
template <> %(mytype)s & %(mytype)s::operator=<%(othertype)s>(const %(othertype)s & y)
{ this->real=y.real; this->imag=y.imag; return *this; }
'''
    return specialization % names
operator_eq = ''.join(operator_eq_real(ctype, rtype)
for ctype in cplx_types
for rtype in real_types) \
+ ''.join(operator_eq_cplx(ctype1, ctype2)
for ctype1 in cplx_types
for ctype2 in cplx_types)
# We are not using C++ generic templating here, because this would
# generate two different functions for adding a complex64 and a
# complex128, one returning a complex64, the other a complex128,
# and the compiler complains it is ambiguous.
# Instead, we generate code for known and safe types only.
def operator_plus_real(mytype, othertype):
    """Return C++ source for ``operator+`` between a complex and a real type.

    Two overloads are generated, one for each argument order. Both add
    the real operand to the real part and leave the imaginary part
    unchanged.

    :param mytype: name of the C++ complex type (also the result type).
    :param othertype: name of the C++ real type.
    :returns: the C++ source text as a string.
    """
    names = {'mytype': mytype, 'othertype': othertype}
    overloads = '''
const %(mytype)s operator+(const %(mytype)s &x, const %(othertype)s &y)
{ return %(mytype)s(x.real+y, x.imag); }
const %(mytype)s operator+(const %(othertype)s &y, const %(mytype)s &x)
{ return %(mytype)s(x.real+y, x.imag); }
'''
    return overloads % names
operator_plus = ''.join(operator_plus_real(ctype, rtype)
for ctype in cplx_types
for rtype in real_types)
def operator_minus_real(mytype, othertype):
    """Return C++ source for ``operator-`` between a complex and a real type.

    Two overloads are generated. Complex minus real subtracts from the
    real part only. Real minus complex yields ``(y - x.real, -x.imag)``,
    so the imaginary part is negated.

    :param mytype: name of the C++ complex type (also the result type).
    :param othertype: name of the C++ real type.
    :returns: the C++ source text as a string.
    """
    names = {'mytype': mytype, 'othertype': othertype}
    overloads = '''
const %(mytype)s operator-(const %(mytype)s &x, const %(othertype)s &y)
{ return %(mytype)s(x.real-y, x.imag); }
const %(mytype)s operator-(const %(othertype)s &y, const %(mytype)s &x)
{ return %(mytype)s(y-x.real, -x.imag); }
'''
    return overloads % names
operator_minus = ''.join(operator_minus_real(ctype, rtype)
for ctype in cplx_types
for rtype in real_types)
def operator_mul_real(mytype, othertype):
    """Return C++ source for ``operator*`` between a complex and a real type.

    Two overloads are generated, one for each argument order. Both scale
    the real and the imaginary part by the real operand.

    :param mytype: name of the C++ complex type (also the result type).
    :param othertype: name of the C++ real type.
    :returns: the C++ source text as a string.
    """
    names = {'mytype': mytype, 'othertype': othertype}
    overloads = '''
const %(mytype)s operator*(const %(mytype)s &x, const %(othertype)s &y)
{ return %(mytype)s(x.real*y, x.imag*y); }
const %(mytype)s operator*(const %(othertype)s &y, const %(mytype)s &x)
{ return %(mytype)s(x.real*y, x.imag*y); }
'''
    return overloads % names
operator_mul = ''.join(operator_mul_real(ctype, rtype)
for ctype in cplx_types
for rtype in real_types)
return template % dict(nbits = 64, half_nbits = 32) \
+ template % dict(nbits = 128, half_nbits = 64) \
+ operator_eq % dict(mytype='theano_complex128') \
+ operator_eq % dict(mytype='theano_complex64')
+ operator_eq \
+ operator_plus \
+ operator_minus \
+ operator_mul
else:
return ""
def c_code_cache_version(self):
return (9, numpy.__version__) # Make operators work with 64 and 128 arguments at the same time
return (8, numpy.__version__) # put const around operators and added unary '-' operator
# no need to put lib.amdlibm here as c_compile_args() are put in the key.
return (7,) # make complex c code optional
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论