提交 da02145b 作者: Pascal Lamblin

Fix bug in operator declarations when both complex64 and complex128 were used.

上级 a21a33fd
...@@ -184,6 +184,8 @@ class Scalar(Type): ...@@ -184,6 +184,8 @@ class Scalar(Type):
def c_support_code(self): def c_support_code(self):
if self.dtype.startswith('complex'): if self.dtype.startswith('complex'):
cplx_types = ['theano_complex64', 'theano_complex128']
real_types = ['npy_int8', 'npy_int16', 'npy_int32', 'npy_int64', 'npy_float32', 'npy_float64']
template = """ template = """
struct theano_complex%(nbits)s : public npy_complex%(nbits)s struct theano_complex%(nbits)s : public npy_complex%(nbits)s
...@@ -197,6 +199,7 @@ class Scalar(Type): ...@@ -197,6 +199,7 @@ class Scalar(Type):
ret.imag = this->imag + y.imag; ret.imag = this->imag + y.imag;
return ret; return ret;
} }
complex_type operator -() const { complex_type operator -() const {
complex_type ret; complex_type ret;
ret.real = -this->real; ret.real = -this->real;
...@@ -233,75 +236,84 @@ class Scalar(Type): ...@@ -233,75 +236,84 @@ class Scalar(Type):
template <typename TR, typename TI> template <typename TR, typename TI>
theano_complex%(nbits)s(const TR& r, const TI& i) { this->real=r; this->imag=i; } theano_complex%(nbits)s(const TR& r, const TI& i) { this->real=r; this->imag=i; }
}; };
""" """
operator_eq = """
template <> %(mytype)s & %(mytype)s::operator=<npy_int8>(const npy_int8 & y)
{ this->real=y; this->imag=0; return *this; }
template <> %(mytype)s & %(mytype)s::operator=<npy_int16>(const npy_int16 & y)
{ this->real=y; this->imag=0; return *this; }
template <> %(mytype)s & %(mytype)s::operator=<npy_int32>(const npy_int32 & y)
{ this->real=y; this->imag=0; return *this; }
template <> %(mytype)s & %(mytype)s::operator=<npy_int64>(const npy_int64 & y) def operator_eq_real(mytype, othertype):
return '''
template <> %(mytype)s & %(mytype)s::operator=<%(othertype)s>(const %(othertype)s & y)
{ this->real=y; this->imag=0; return *this; } { this->real=y; this->imag=0; return *this; }
''' % dict(mytype = mytype, othertype = othertype)
template <> %(mytype)s & %(mytype)s::operator=<npy_float32>(const npy_float32 & y) def operator_eq_cplx(mytype, othertype):
{ this->real=y; this->imag=0; return *this; } return '''
template <> %(mytype)s & %(mytype)s::operator=<%(othertype)s>(const %(othertype)s & y)
template <> %(mytype)s & %(mytype)s::operator=<npy_float64>(const npy_float64 & y)
{ this->real=y; this->imag=0; return *this; }
template <> %(mytype)s & %(mytype)s::operator=<theano_complex128>(const theano_complex128 & y)
{ this->real=y.real; this->imag=y.imag; return *this; }
template <> %(mytype)s & %(mytype)s::operator=<theano_complex64>(const theano_complex64 & y)
{ this->real=y.real; this->imag=y.imag; return *this; } { this->real=y.real; this->imag=y.imag; return *this; }
''' % dict(mytype = mytype, othertype = othertype)
template <typename T>
const %(mytype)s operator_eq = ''.join(operator_eq_real(ctype, rtype)
operator+(const %(mytype)s &x, const T& y) for ctype in cplx_types
for rtype in real_types) \
+ ''.join(operator_eq_cplx(ctype1, ctype2)
for ctype1 in cplx_types
for ctype2 in cplx_types)
# We are not using C++ generic templating here, because this would
# generate two different functions for adding a complex64 and a
# complex128, one returning a complex64, the other a complex128,
# and the compiler complains it is ambiguous.
# Instead, we generate code for known and safe types only.
def operator_plus_real(mytype, othertype):
return '''
const %(mytype)s operator+(const %(mytype)s &x, const %(othertype)s &y)
{ return %(mytype)s(x.real+y, x.imag); } { return %(mytype)s(x.real+y, x.imag); }
template <typename T> const %(mytype)s operator+(const %(othertype)s &y, const %(mytype)s &x)
const %(mytype)s
operator+(const T& y, const %(mytype)s &x)
{ return %(mytype)s(x.real+y, x.imag); } { return %(mytype)s(x.real+y, x.imag); }
''' % dict(mytype = mytype, othertype = othertype)
template <typename T> operator_plus = ''.join(operator_plus_real(ctype, rtype)
const %(mytype)s for ctype in cplx_types
operator-(const %(mytype)s &x, const T& y) for rtype in real_types)
def operator_minus_real(mytype, othertype):
return '''
const %(mytype)s operator-(const %(mytype)s &x, const %(othertype)s &y)
{ return %(mytype)s(x.real-y, x.imag); } { return %(mytype)s(x.real-y, x.imag); }
template <typename T> const %(mytype)s operator-(const %(othertype)s &y, const %(mytype)s &x)
const %(mytype)s { return %(mytype)s(y-x.real, -x.imag); }
operator-(const T& x, const %(mytype)s &y) ''' % dict(mytype = mytype, othertype = othertype)
{ return %(mytype)s(x-y.real, -y.imag); }
template <typename T> operator_minus = ''.join(operator_minus_real(ctype, rtype)
const %(mytype)s for ctype in cplx_types
operator*(const %(mytype)s &x, const T& y) for rtype in real_types)
def operator_mul_real(mytype, othertype):
return '''
const %(mytype)s operator*(const %(mytype)s &x, const %(othertype)s &y)
{ return %(mytype)s(x.real*y, x.imag*y); } { return %(mytype)s(x.real*y, x.imag*y); }
template <typename T> const %(mytype)s operator*(const %(othertype)s &y, const %(mytype)s &x)
const %(mytype)s { return %(mytype)s(x.real*y, x.imag*y); }
operator*(const T& x, const %(mytype)s &y) ''' % dict(mytype = mytype, othertype = othertype)
{ return %(mytype)s(x*y.real, x*y.imag); }
""" operator_mul = ''.join(operator_mul_real(ctype, rtype)
for ctype in cplx_types
for rtype in real_types)
# todo: use C templating
return template % dict(nbits = 64, half_nbits = 32) \ return template % dict(nbits = 64, half_nbits = 32) \
+ template % dict(nbits = 128, half_nbits = 64) \ + template % dict(nbits = 128, half_nbits = 64) \
+ operator_eq % dict(mytype='theano_complex128') \ + operator_eq \
+ operator_eq % dict(mytype='theano_complex64') + operator_plus \
+ operator_minus \
+ operator_mul
else: else:
return "" return ""
def c_code_cache_version(self): def c_code_cache_version(self):
return (9, numpy.__version__) # Make operators work with 64 and 128 arguments at the same time
return (8, numpy.__version__) # put const around operators and added unary '-' operator return (8, numpy.__version__) # put const around operators and added unary '-' operator
# no need to put lib.amdlibm here as c_compile_args() are put in the key. # no need to put lib.amdlibm here as c_compile_args() are put in the key.
return (7,) # make complex c code optional return (7,) # make complex c code optional
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论