提交 a6b41547 authored 作者: James Bergstra's avatar James Bergstra

Improved the assignment operator in c support code in scalar to handle more

types.
上级 9302cc0c
...@@ -182,23 +182,44 @@ class Scalar(Type): ...@@ -182,23 +182,44 @@ class Scalar(Type):
ret.imag = (this->imag * y.real - this->real * y.imag) / y_norm_square; ret.imag = (this->imag * y.real - this->real * y.imag) / y_norm_square;
return ret; return ret;
} }
complex_type& operator =(const scalar_type& y) { template <typename T>
this->real=y; complex_type& operator =(const T& y);
this->imag=0;
return *this;
}
%(upcast)s
}; };
""" """
operator_eq = """
template <> %(mytype)s & %(mytype)s::operator =(const npy_int8 & y)
{ this->real=y; this->imag=0; return *this; }
template <> %(mytype)s & %(mytype)s::operator =(const npy_int16 & y)
{ this->real=y; this->imag=0; return *this; }
template <> %(mytype)s & %(mytype)s::operator =(const npy_int32 & y)
{ this->real=y; this->imag=0; return *this; }
template <> %(mytype)s & %(mytype)s::operator =(const npy_int64 & y)
{ this->real=y; this->imag=0; return *this; }
template <> %(mytype)s & %(mytype)s::operator =(const npy_float32 & y)
{ this->real=y; this->imag=0; return *this; }
template <> %(mytype)s & %(mytype)s::operator =(const npy_float64 & y)
{ this->real=y; this->imag=0; return *this; }
template <> %(mytype)s & %(mytype)s::operator =(const theano_complex128 & y)
{ this->real=y.real; this->imag=y.imag; return *this; }
template <> %(mytype)s & %(mytype)s::operator =(const theano_complex64 & y)
{ this->real=y.real; this->imag=y.imag; return *this; }
"""
# todo: use C templating # todo: use C templating
return template % dict(nbits = 64, half_nbits = 32, upcast="") + template % dict(nbits = 128, half_nbits = 64, upcast=""" return template % dict(nbits = 64, half_nbits = 32) \
complex_type& operator =(theano_complex64 y) { + template % dict(nbits = 128, half_nbits = 64) \
this->real=y.real; + operator_eq % dict(mytype='theano_complex128') \
this->imag=y.imag; + operator_eq % dict(mytype='theano_complex64')
return *this;
}
""")
def c_code_cache_version(self):
return (2,)
int8 = Scalar('int8') int8 = Scalar('int8')
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论