提交 5a21406f，作者：James Bergstra

Make Scalar.c_support_code emit the complex-number support code only when the scalar's dtype is complex; return an empty string otherwise.

上级 5aa75e84
...@@ -175,111 +175,118 @@ class Scalar(Type): ...@@ -175,111 +175,118 @@ class Scalar(Type):
return "" return ""
def c_support_code(self):
    """Return the C++ support code required by this scalar type.

    For dtypes whose name starts with ``'complex'`` the result is the
    concatenation of:

    * the ``theano_complex64`` and ``theano_complex128`` struct definitions
      (each derives from the matching ``npy_complex*`` struct and adds
      complex/complex arithmetic operators plus converting constructors);
    * for each of the two structs, the ``operator=`` specializations for the
      NumPy integer/float types and both complex structs, followed by the
      free ``+``, ``-`` and ``*`` operator templates taking one complex
      operand and one operand of arbitrary type ``T``.

    Both structs are always emitted together because the ``operator=``
    specializations of each one mention the other.

    For every other dtype an empty string is returned, so that no
    complex-related code is emitted when no complex type is in use.

    :returns: C++ source text (``str``), possibly empty.
    """
    # Guard clause: non-complex scalars need no helper code at all.
    if not self.dtype.startswith('complex'):
        return ""

    # Struct definition, instantiated once per complex width.
    # ``nbits`` is the width of the complex type, ``half_nbits`` the width
    # of each of its real/imaginary float components.
    template = """
    struct theano_complex%(nbits)s : public npy_complex%(nbits)s
    {
        typedef theano_complex%(nbits)s complex_type;
        typedef npy_float%(half_nbits)s scalar_type;

        complex_type operator +(complex_type y) {
            complex_type ret;
            ret.real = this->real + y.real;
            ret.imag = this->imag + y.imag;
            return ret;
        }
        complex_type operator -(complex_type y) {
            complex_type ret;
            ret.real = this->real - y.real;
            ret.imag = this->imag - y.imag;
            return ret;
        }
        complex_type operator *(complex_type y) {
            complex_type ret;
            ret.real = this->real * y.real - this->imag * y.imag;
            ret.imag = this->real * y.imag + this->imag * y.real;
            return ret;
        }
        complex_type operator /(complex_type y) {
            complex_type ret;
            scalar_type y_norm_square = y.real * y.real + y.imag * y.imag;
            ret.real = (this->real * y.real + this->imag * y.imag) / y_norm_square;
            ret.imag = (this->imag * y.real - this->real * y.imag) / y_norm_square;
            return ret;
        }
        template <typename T>
        complex_type& operator =(const T& y);

        theano_complex%(nbits)s() {}

        template <typename T>
        theano_complex%(nbits)s(const T& y) { *this = y; }

        template <typename TR, typename TI>
        theano_complex%(nbits)s(const TR& r, const TI& i) { this->real=r; this->imag=i; }
    };
    """

    # Out-of-class definitions, instantiated once per complex struct
    # (``mytype``): assignment specializations and mixed-type operators.
    operator_eq = """
    template <> %(mytype)s & %(mytype)s::operator=<npy_int8>(const npy_int8 & y)
    { this->real=y; this->imag=0; return *this; }

    template <> %(mytype)s & %(mytype)s::operator=<npy_int16>(const npy_int16 & y)
    { this->real=y; this->imag=0; return *this; }

    template <> %(mytype)s & %(mytype)s::operator=<npy_int32>(const npy_int32 & y)
    { this->real=y; this->imag=0; return *this; }

    template <> %(mytype)s & %(mytype)s::operator=<npy_int64>(const npy_int64 & y)
    { this->real=y; this->imag=0; return *this; }

    template <> %(mytype)s & %(mytype)s::operator=<npy_float32>(const npy_float32 & y)
    { this->real=y; this->imag=0; return *this; }

    template <> %(mytype)s & %(mytype)s::operator=<npy_float64>(const npy_float64 & y)
    { this->real=y; this->imag=0; return *this; }

    template <> %(mytype)s & %(mytype)s::operator=<theano_complex128>(const theano_complex128 & y)
    { this->real=y.real; this->imag=y.imag; return *this; }

    template <> %(mytype)s & %(mytype)s::operator=<theano_complex64>(const theano_complex64 & y)
    { this->real=y.real; this->imag=y.imag; return *this; }

    template <typename T>
    const %(mytype)s
    operator+(const %(mytype)s &x, const T& y)
    { return %(mytype)s(x.real+y, x.imag); }

    template <typename T>
    const %(mytype)s
    operator+(const T& y, const %(mytype)s &x)
    { return %(mytype)s(x.real+y, x.imag); }

    template <typename T>
    const %(mytype)s
    operator-(const %(mytype)s &x, const T& y)
    { return %(mytype)s(x.real-y, x.imag); }

    template <typename T>
    const %(mytype)s
    operator-(const T& x, const %(mytype)s &y)
    { return %(mytype)s(x-y.real, -y.imag); }

    template <typename T>
    const %(mytype)s
    operator*(const %(mytype)s &x, const T& y)
    { return %(mytype)s(x.real*y, x.imag*y); }

    template <typename T>
    const %(mytype)s
    operator*(const T& x, const %(mytype)s &y)
    { return %(mytype)s(x*y.real, x*y.imag); }
    """

    # todo: use C templating
    # Order matters: both structs must be declared before any of the
    # out-of-class definitions, which refer to both of them.
    return template % dict(nbits = 64, half_nbits = 32) \
        + template % dict(nbits = 128, half_nbits = 64) \
        + operator_eq % dict(mytype='theano_complex128') \
        + operator_eq % dict(mytype='theano_complex64')
def c_code_cache_version(self): def c_code_cache_version(self):
# no need to put lib.amdlibm here as c_compile_args() are put in the key. # no need to put lib.amdlibm here as c_compile_args() are put in the key.
......
Markdown 格式
0%
您将添加 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论