提交 756be77e authored 作者: Brendan Murphy's avatar Brendan Murphy 提交者: Ricardo Vieira

Make complex scalars work with numpy 2.0

This is done using C++ generic functions to get/set the real/imag parts of complex numbers. This gives us an easy way to support Numpy v < 2.0, and allows the type underlying the bit width types, like pytensor_complex128, to be correctly inferred from the numpy complex types they inherit from. Updated pytensor_complex struct to use get/set real/imag aliases defined above. Also updated operators such as `Abs` to use get_real, get_imag. Macros have been added to ensure compatibility with numpy < 2.0 Note: redefining the complex arithmetic here means that we aren't treating NaNs and infinities as carefully as the C99 standard suggets (see Appendix G of the standard). The code has been like this since it was added to Theano, so we're keeping the existing behavior.
上级 bfc07777
...@@ -349,6 +349,8 @@ class ScalarType(CType, HasDataType, HasShape): ...@@ -349,6 +349,8 @@ class ScalarType(CType, HasDataType, HasShape):
# we declare them here and they will be re-used by TensorType # we declare them here and they will be re-used by TensorType
l.append("<numpy/arrayobject.h>") l.append("<numpy/arrayobject.h>")
l.append("<numpy/arrayscalars.h>") l.append("<numpy/arrayscalars.h>")
l.append("<numpy/npy_math.h>")
if config.lib__amdlibm and c_compiler.supports_amdlibm: if config.lib__amdlibm and c_compiler.supports_amdlibm:
l += ["<amdlibm.h>"] l += ["<amdlibm.h>"]
return l return l
...@@ -517,73 +519,167 @@ class ScalarType(CType, HasDataType, HasShape): ...@@ -517,73 +519,167 @@ class ScalarType(CType, HasDataType, HasShape):
# In that case we add the 'int' type to the real types. # In that case we add the 'int' type to the real types.
real_types.append("int") real_types.append("int")
# Macros for backwards compatibility with numpy < 2.0
#
# In numpy 2.0+, these are defined in npy_math.h, but
# for early versions, they must be vendored by users (e.g. PyTensor)
backwards_compat_macros = """
#ifndef NUMPY_CORE_INCLUDE_NUMPY_NPY_2_COMPLEXCOMPAT_H_
#define NUMPY_CORE_INCLUDE_NUMPY_NPY_2_COMPLEXCOMPAT_H_
#include <numpy/npy_math.h>
#ifndef NPY_CSETREALF
#define NPY_CSETREALF(c, r) (c)->real = (r)
#endif
#ifndef NPY_CSETIMAGF
#define NPY_CSETIMAGF(c, i) (c)->imag = (i)
#endif
#ifndef NPY_CSETREAL
#define NPY_CSETREAL(c, r) (c)->real = (r)
#endif
#ifndef NPY_CSETIMAG
#define NPY_CSETIMAG(c, i) (c)->imag = (i)
#endif
#ifndef NPY_CSETREALL
#define NPY_CSETREALL(c, r) (c)->real = (r)
#endif
#ifndef NPY_CSETIMAGL
#define NPY_CSETIMAGL(c, i) (c)->imag = (i)
#endif
#endif
"""
def _make_get_set_real_imag(scalar_type: str) -> str:
"""Make overloaded getter/setter functions for real/imag parts of numpy complex types.
The functions called by these getter/setter functions are defining in npy_math.h, or
in the `backward_compat_macros` defined above.
Args:
scalar_type: float, double, or longdouble
Returns:
C++ code for defining set_real, set_imag, get_real, and get_imag, overloaded for the
given type.
"""
complex_type = "npy_c" + scalar_type
suffix = "" if scalar_type == "double" else scalar_type[0]
if scalar_type == "longdouble":
scalar_type = "npy_" + scalar_type
return_type = scalar_type
template = f"""
static inline {return_type} get_real(const {complex_type} z)
{{
return npy_creal{suffix}(z);
}}
static inline void set_real({complex_type} *z, const {scalar_type} r)
{{
NPY_CSETREAL{suffix.upper()}(z, r);
}}
static inline {return_type} get_imag(const {complex_type} z)
{{
return npy_cimag{suffix}(z);
}}
static inline void set_imag({complex_type} *z, const {scalar_type} i)
{{
NPY_CSETIMAG{suffix.upper()}(z, i);
}}
"""
return template
get_set_aliases = "\n".join(
_make_get_set_real_imag(stype)
for stype in ["float", "double", "longdouble"]
)
get_set_aliases = backwards_compat_macros + "\n" + get_set_aliases
# Template for defining pytensor_complex64 and pytensor_complex128 structs/classes
#
# The npy_complex64, npy_complex128 types are aliases defined at run time based on
# the size of floats and doubles on the machine. This means that both types are
# not necessarily defined on every machine, but a machine with 32-bit floats and
# 64-bit doubles will have npy_complex64 as an alias of npy_cfloat and npy_complex128
# as an alias of npy_complex128.
#
# In any case, the get/set real/imag functions defined above will always work for
# npy_complex64 and npy_complex128.
template = """ template = """
struct pytensor_complex%(nbits)s : public npy_complex%(nbits)s struct pytensor_complex%(nbits)s : public npy_complex%(nbits)s {
{ typedef pytensor_complex%(nbits)s complex_type;
typedef pytensor_complex%(nbits)s complex_type; typedef npy_float%(half_nbits)s scalar_type;
typedef npy_float%(half_nbits)s scalar_type;
complex_type operator+(const complex_type &y) const {
complex_type operator +(const complex_type &y) const { complex_type ret;
complex_type ret; set_real(&ret, get_real(*this) + get_real(y));
ret.real = this->real + y.real; set_imag(&ret, get_imag(*this) + get_imag(y));
ret.imag = this->imag + y.imag; return ret;
return ret; }
}
complex_type operator-() const {
complex_type operator -() const { complex_type ret;
complex_type ret; set_real(&ret, -get_real(*this));
ret.real = -this->real; set_imag(&ret, -get_imag(*this));
ret.imag = -this->imag; return ret;
return ret; }
} bool operator==(const complex_type &y) const {
bool operator ==(const complex_type &y) const { return (get_real(*this) == get_real(y)) && (get_imag(*this) == get_imag(y));
return (this->real == y.real) && (this->imag == y.imag); }
} bool operator==(const scalar_type &y) const {
bool operator ==(const scalar_type &y) const { return (get_real(*this) == y) && (get_real(*this) == 0);
return (this->real == y) && (this->imag == 0); }
} complex_type operator-(const complex_type &y) const {
complex_type operator -(const complex_type &y) const { complex_type ret;
complex_type ret; set_real(&ret, get_real(*this) - get_real(y));
ret.real = this->real - y.real; set_imag(&ret, get_imag(*this) - get_imag(y));
ret.imag = this->imag - y.imag; return ret;
return ret; }
} complex_type operator*(const complex_type &y) const {
complex_type operator *(const complex_type &y) const { complex_type ret;
complex_type ret; set_real(&ret, get_real(*this) * get_real(y) - get_imag(*this) * get_imag(y));
ret.real = this->real * y.real - this->imag * y.imag; set_imag(&ret, get_imag(*this) * get_real(y) + get_real(*this) * get_imag(y));
ret.imag = this->real * y.imag + this->imag * y.real; return ret;
return ret; }
} complex_type operator/(const complex_type &y) const {
complex_type operator /(const complex_type &y) const { complex_type ret;
complex_type ret; scalar_type y_norm_square = get_real(y) * get_real(y) + get_imag(y) * get_imag(y);
scalar_type y_norm_square = y.real * y.real + y.imag * y.imag; set_real(&ret, (get_real(*this) * get_real(y) + get_imag(*this) * get_imag(y)) / y_norm_square);
ret.real = (this->real * y.real + this->imag * y.imag) / y_norm_square; set_imag(&ret, (get_imag(*this) * get_real(y) - get_real(*this) * get_imag(y)) / y_norm_square);
ret.imag = (this->imag * y.real - this->real * y.imag) / y_norm_square; return ret;
return ret; }
} template <typename T> complex_type &operator=(const T &y);
template <typename T>
complex_type& operator =(const T& y);
pytensor_complex%(nbits)s() {}
pytensor_complex%(nbits)s() {}
template <typename T> pytensor_complex%(nbits)s(const T &y) { *this = y; }
template <typename T>
pytensor_complex%(nbits)s(const T& y) { *this = y; } template <typename TR, typename TI>
pytensor_complex%(nbits)s(const TR &r, const TI &i) {
template <typename TR, typename TI> set_real(this, r);
pytensor_complex%(nbits)s(const TR& r, const TI& i) { this->real=r; this->imag=i; } set_imag(this, i);
}
}; };
""" """
def operator_eq_real(mytype, othertype): def operator_eq_real(mytype, othertype):
return f""" return f"""
template <> {mytype} & {mytype}::operator=<{othertype}>(const {othertype} & y) template <> {mytype} & {mytype}::operator=<{othertype}>(const {othertype} & y)
{{ this->real=y; this->imag=0; return *this; }} {{ set_real(this, y); set_imag(this, 0); return *this; }}
""" """
def operator_eq_cplx(mytype, othertype): def operator_eq_cplx(mytype, othertype):
return f""" return f"""
template <> {mytype} & {mytype}::operator=<{othertype}>(const {othertype} & y) template <> {mytype} & {mytype}::operator=<{othertype}>(const {othertype} & y)
{{ this->real=y.real; this->imag=y.imag; return *this; }} {{ set_real(this, get_real(y)); set_imag(this, get_imag(y)); return *this; }}
""" """
operator_eq = "".join( operator_eq = "".join(
...@@ -605,10 +701,10 @@ class ScalarType(CType, HasDataType, HasShape): ...@@ -605,10 +701,10 @@ class ScalarType(CType, HasDataType, HasShape):
def operator_plus_real(mytype, othertype): def operator_plus_real(mytype, othertype):
return f""" return f"""
const {mytype} operator+(const {mytype} &x, const {othertype} &y) const {mytype} operator+(const {mytype} &x, const {othertype} &y)
{{ return {mytype}(x.real+y, x.imag); }} {{ return {mytype}(get_real(x) + y, get_imag(x)); }}
const {mytype} operator+(const {othertype} &y, const {mytype} &x) const {mytype} operator+(const {othertype} &y, const {mytype} &x)
{{ return {mytype}(x.real+y, x.imag); }} {{ return {mytype}(get_real(x) + y, get_imag(x)); }}
""" """
operator_plus = "".join( operator_plus = "".join(
...@@ -620,10 +716,10 @@ class ScalarType(CType, HasDataType, HasShape): ...@@ -620,10 +716,10 @@ class ScalarType(CType, HasDataType, HasShape):
def operator_minus_real(mytype, othertype): def operator_minus_real(mytype, othertype):
return f""" return f"""
const {mytype} operator-(const {mytype} &x, const {othertype} &y) const {mytype} operator-(const {mytype} &x, const {othertype} &y)
{{ return {mytype}(x.real-y, x.imag); }} {{ return {mytype}(get_real(x) - y, get_imag(x)); }}
const {mytype} operator-(const {othertype} &y, const {mytype} &x) const {mytype} operator-(const {othertype} &y, const {mytype} &x)
{{ return {mytype}(y-x.real, -x.imag); }} {{ return {mytype}(y - get_real(x), -get_imag(x)); }}
""" """
operator_minus = "".join( operator_minus = "".join(
...@@ -635,10 +731,10 @@ class ScalarType(CType, HasDataType, HasShape): ...@@ -635,10 +731,10 @@ class ScalarType(CType, HasDataType, HasShape):
def operator_mul_real(mytype, othertype): def operator_mul_real(mytype, othertype):
return f""" return f"""
const {mytype} operator*(const {mytype} &x, const {othertype} &y) const {mytype} operator*(const {mytype} &x, const {othertype} &y)
{{ return {mytype}(x.real*y, x.imag*y); }} {{ return {mytype}(get_real(x) * y, get_imag(x) * y); }}
const {mytype} operator*(const {othertype} &y, const {mytype} &x) const {mytype} operator*(const {othertype} &y, const {mytype} &x)
{{ return {mytype}(x.real*y, x.imag*y); }} {{ return {mytype}(get_real(x) * y, get_imag(x) * y); }}
""" """
operator_mul = "".join( operator_mul = "".join(
...@@ -648,7 +744,8 @@ class ScalarType(CType, HasDataType, HasShape): ...@@ -648,7 +744,8 @@ class ScalarType(CType, HasDataType, HasShape):
) )
return ( return (
template % dict(nbits=64, half_nbits=32) get_set_aliases
+ template % dict(nbits=64, half_nbits=32)
+ template % dict(nbits=128, half_nbits=64) + template % dict(nbits=128, half_nbits=64)
+ operator_eq + operator_eq
+ operator_plus + operator_plus
...@@ -663,7 +760,7 @@ class ScalarType(CType, HasDataType, HasShape): ...@@ -663,7 +760,7 @@ class ScalarType(CType, HasDataType, HasShape):
return ["import_array();"] return ["import_array();"]
def c_code_cache_version(self): def c_code_cache_version(self):
return (13, np.__version__) return (14, np.__version__)
def get_shape_info(self, obj): def get_shape_info(self, obj):
return obj.itemsize return obj.itemsize
...@@ -2567,7 +2664,7 @@ class Abs(UnaryScalarOp): ...@@ -2567,7 +2664,7 @@ class Abs(UnaryScalarOp):
if type in float_types: if type in float_types:
return f"{z} = fabs({x});" return f"{z} = fabs({x});"
if type in complex_types: if type in complex_types:
return f"{z} = sqrt({x}.real*{x}.real + {x}.imag*{x}.imag);" return f"{z} = sqrt(get_real({x}) * get_real({x}) + get_imag({x}) * get_imag({x}));"
if node.outputs[0].type == bool: if node.outputs[0].type == bool:
return f"{z} = ({x}) ? 1 : 0;" return f"{z} = ({x}) ? 1 : 0;"
if type in uint_types: if type in uint_types:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论