提交 0115b94e authored 作者: nouiz's avatar nouiz

Merge pull request #1004 from jey/erfinv

Add erfinv and erfcinv ops
......@@ -75,6 +75,66 @@ class Erfc(UnaryScalarOp):
erfc = Erfc(upgrade_to_float_no_complex, name='erfc')
class Erfinv(UnaryScalarOp):
def impl(self, x):
if imported_scipy_special:
return scipy.special.erfinv(x)
else:
super(Erfinv, self).impl(x)
def grad(self, inp, grads):
x, = inp
gz, = grads
if x.type in complex_types:
raise NotImplementedError()
elif x.type in float_types:
cst = numpy.asarray(numpy.sqrt(numpy.pi) / 2.,
dtype=upcast(x.type.dtype, gz.type.dtype))
return gz * cst * exp(erfinv(x) ** 2),
else:
return None,
# TODO: erfinv() is not provided by the C standard library
#def c_code(self, node, name, inp, out, sub):
# x, = inp
# z, = out
# if node.inputs[0].type in complex_types:
# raise NotImplementedError('type not supported', type)
# return "%(z)s = erfinv(%(x)s);" % locals()
erfinv = Erfinv(upgrade_to_float_no_complex, name='erfinv')
class Erfcinv(UnaryScalarOp):
def impl(self, x):
if imported_scipy_special:
return scipy.special.erfcinv(x)
else:
super(Erfcinv, self).impl(x)
def grad(self, inp, grads):
x, = inp
gz, = grads
if x.type in complex_types:
raise NotImplementedError()
elif x.type in float_types:
cst = numpy.asarray(numpy.sqrt(numpy.pi) / 2.,
dtype=upcast(x.type.dtype, gz.type.dtype))
return - gz * cst * exp(erfcinv(x) ** 2),
else:
return None,
# TODO: erfcinv() is not provided by the C standard library
#def c_code(self, node, name, inp, out, sub):
# x, = inp
# z, = out
# if node.inputs[0].type in complex_types:
# raise NotImplementedError('type not supported', type)
# return "%(z)s = erfcinv(%(x)s);" % locals()
erfcinv = Erfcinv(upgrade_to_float_no_complex, name='erfcinv')
class Gamma(UnaryScalarOp):
@staticmethod
def st_impl(x):
......
......@@ -2900,6 +2900,16 @@ def erfc(a):
"""complementary error function"""
@_scal_elemwise
def erfinv(a):
"""inverse error function"""
@_scal_elemwise
def erfcinv(a):
"""inverse complementary error function"""
@_scal_elemwise
def gamma(a):
"""gamma function"""
......
......@@ -892,7 +892,7 @@ _good_broadcast_unary_normal_float_no_complex = copymod(
without=['complex'])
_good_broadcast_unary_normal = dict(
normal=[numpy.asarray(rand_ranged(-5, 5, (2, 3)), dtype= config.floatX)],
normal=[numpy.asarray(rand_ranged(-5, 5, (2, 3)), dtype=config.floatX)],
integers=[randint_ranged(-5, 5, (2, 3))],
corner_case=[corner_case],
complex=[randcomplex(2, 3)],
......@@ -916,6 +916,14 @@ _grad_broadcast_unary_normal = dict(
#empty = [numpy.asarray([])] # XXX: should this be included?
)
_grad_broadcast_unary_abs1_no_complex = dict(
normal=[numpy.asarray(rand_ranged(-1, 1, (2, 3)), dtype=floatX)],
)
_grad_broadcast_unary_0_2_no_complex = dict(
normal=[numpy.asarray(rand_ranged(0, 2, (2, 3)), dtype=floatX)],
)
AbsTester = makeBroadcastTester(op=tensor.abs_,
expected=lambda x: abs(x),
......@@ -1383,6 +1391,8 @@ del _good_broadcast_unary_normal_no_int['integers']
if imported_scipy_special:
expected_erf = scipy.special.erf
expected_erfc = scipy.special.erfc
expected_erfinv = scipy.special.erfinv
expected_erfcinv = scipy.special.erfcinv
expected_gamma = scipy.special.gamma
expected_gammaln = scipy.special.gammaln
expected_psi = scipy.special.psi
......@@ -1390,6 +1400,8 @@ if imported_scipy_special:
else:
expected_erf = []
expected_erfc = []
expected_erfinv = []
expected_erfcinv = []
expected_gamma = []
expected_gammaln = []
expected_psi = []
......@@ -1431,6 +1443,24 @@ ErfcInplaceTester = makeBroadcastTester(
inplace=True,
skip=skip_scipy)
ErfinvTester = makeBroadcastTester(
op=tensor.erfinv,
expected=expected_erfinv,
good=_good_broadcast_unary_normal_no_int_no_complex,
grad=_grad_broadcast_unary_abs1_no_complex,
eps=2e-10,
mode=mode_no_scipy,
skip=skip_scipy)
ErfcinvTester = makeBroadcastTester(
op=tensor.erfcinv,
expected=expected_erfcinv,
good=_good_broadcast_unary_normal_no_int_no_complex,
grad=_grad_broadcast_unary_0_2_no_complex,
eps=2e-10,
mode=mode_no_scipy,
skip=skip_scipy)
_good_broadcast_unary_gammaln = dict(
normal=(rand_ranged(-1 + 1e-2, 10, (2, 3)),),
empty=(numpy.asarray([]),),)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论