Unverified 提交 70ed6548 authored 作者: ricardoV94's avatar ricardoV94 提交者: GitHub

Add C code for Erfcx (#317)

* Add Faddeeva functions for erfcx * Fix c_code * call c_header_dirs in Elemwise * Add Faddeeva functions for erfcx * Call CLinkerObject.c_header_dirs for scalar Op in Elemwise * Add c_header_dirs method to aesara.scalar.Composite * Clean up comments * Remove Faddeeva.h Co-authored-by: 's avatarBrandon T. Willard <brandonwillard@users.noreply.github.com>
上级 65c410b0
......@@ -4255,6 +4255,13 @@ class Composite(ScalarOp):
return ()
return tuple(rval)
def c_header_dirs(self, **kwargs):
rval = sum(
(subnode.op.c_header_dirs(**kwargs) for subnode in self.fgraph.toposort()),
[],
)
return rval
def c_support_code(self, **kwargs):
# Remove duplicate code blocks by using a `set`
rval = {
......
......@@ -150,6 +150,31 @@ class Erfcx(UnaryScalarOp):
)
return (gz * (-cst + (2.0 * x) * erfcx(x)),)
def c_header_dirs(self, **kwargs):
# Using the Faddeeva.hh (c++) header for Faddeevva.cc
res = super().c_header_dirs(**kwargs) + [
os.path.join(os.path.dirname(__file__), "c_code")
]
return res
def c_support_code(self, **kwargs):
# Using Faddeeva.cc source file from: http://ab-initio.mit.edu/wiki/index.php/Faddeeva_Package
with open(
os.path.join(os.path.dirname(__file__), "c_code", "Faddeeva.cc")
) as f:
raw = f.read()
return raw
def c_code(self, node, name, inp, out, sub):
(x,) = inp
(z,) = out
if node.inputs[0].type in float_types:
dtype = "npy_" + node.outputs[0].dtype
return f"{z} = ({dtype}) Faddeeva::erfcx({x});"
raise NotImplementedError("type not supported", type)
erfcx = Erfcx(upgrade_to_float_no_complex, name="erfcx")
......
This source diff could not be displayed because it is too large. You can view the blob instead.
/* Copyright (c) 2012 Massachusetts Institute of Technology
*
* Permission is hereby granted, free of charge, to any person obtaining
* a copy of this software and associated documentation files (the
* "Software"), to deal in the Software without restriction, including
* without limitation the rights to use, copy, modify, merge, publish,
* distribute, sublicense, and/or sell copies of the Software, and to
* permit persons to whom the Software is furnished to do so, subject to
* the following conditions:
*
* The above copyright notice and this permission notice shall be
* included in all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
* NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
* LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
* OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
* WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*/
/* Available at: http://ab-initio.mit.edu/Faddeeva
Header file for Faddeeva.cc; see that file for more information. */
#ifndef FADDEEVA_HH
#define FADDEEVA_HH 1
#include <complex>
namespace Faddeeva {
// compute w(z) = exp(-z^2) erfc(-iz) [ Faddeeva / scaled complex error func ]
extern std::complex<double> w(std::complex<double> z,double relerr=0);
extern double w_im(double x); // special-case code for Im[w(x)] of real x
// Various functions that we can compute with the help of w(z)
// compute erfcx(z) = exp(z^2) erfc(z)
extern std::complex<double> erfcx(std::complex<double> z, double relerr=0);
extern double erfcx(double x); // special case for real x
// compute erf(z), the error function of complex arguments
extern std::complex<double> erf(std::complex<double> z, double relerr=0);
extern double erf(double x); // special case for real x
// compute erfi(z) = -i erf(iz), the imaginary error function
extern std::complex<double> erfi(std::complex<double> z, double relerr=0);
extern double erfi(double x); // special case for real x
// compute erfc(z) = 1 - erf(z), the complementary error function
extern std::complex<double> erfc(std::complex<double> z, double relerr=0);
extern double erfc(double x); // special case for real x
// compute Dawson(z) = sqrt(pi)/2 * exp(-z^2) * erfi(z)
extern std::complex<double> Dawson(std::complex<double> z, double relerr=0);
extern double Dawson(double x); // special case for real x
} // namespace Faddeeva
#endif // FADDEEVA_HH
......@@ -1218,6 +1218,9 @@ second dimension
def c_headers(self, **kwargs):
return ["<vector>", "<algorithm>"]
def c_header_dirs(self, **kwargs):
return self.scalar_op.c_header_dirs(**kwargs)
def c_support_code(self, **kwargs):
return self.scalar_op.c_support_code(**kwargs)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论