提交 0115b94e authored 作者: nouiz's avatar nouiz

Merge pull request #1004 from jey/erfinv

Add erfinv and erfcinv ops
...@@ -75,6 +75,66 @@ class Erfc(UnaryScalarOp): ...@@ -75,6 +75,66 @@ class Erfc(UnaryScalarOp):
erfc = Erfc(upgrade_to_float_no_complex, name='erfc') erfc = Erfc(upgrade_to_float_no_complex, name='erfc')
class Erfinv(UnaryScalarOp):
def impl(self, x):
if imported_scipy_special:
return scipy.special.erfinv(x)
else:
super(Erfinv, self).impl(x)
def grad(self, inp, grads):
x, = inp
gz, = grads
if x.type in complex_types:
raise NotImplementedError()
elif x.type in float_types:
cst = numpy.asarray(numpy.sqrt(numpy.pi) / 2.,
dtype=upcast(x.type.dtype, gz.type.dtype))
return gz * cst * exp(erfinv(x) ** 2),
else:
return None,
# TODO: erfinv() is not provided by the C standard library
#def c_code(self, node, name, inp, out, sub):
# x, = inp
# z, = out
# if node.inputs[0].type in complex_types:
# raise NotImplementedError('type not supported', type)
# return "%(z)s = erfinv(%(x)s);" % locals()
erfinv = Erfinv(upgrade_to_float_no_complex, name='erfinv')
class Erfcinv(UnaryScalarOp):
def impl(self, x):
if imported_scipy_special:
return scipy.special.erfcinv(x)
else:
super(Erfcinv, self).impl(x)
def grad(self, inp, grads):
x, = inp
gz, = grads
if x.type in complex_types:
raise NotImplementedError()
elif x.type in float_types:
cst = numpy.asarray(numpy.sqrt(numpy.pi) / 2.,
dtype=upcast(x.type.dtype, gz.type.dtype))
return - gz * cst * exp(erfcinv(x) ** 2),
else:
return None,
# TODO: erfcinv() is not provided by the C standard library
#def c_code(self, node, name, inp, out, sub):
# x, = inp
# z, = out
# if node.inputs[0].type in complex_types:
# raise NotImplementedError('type not supported', type)
# return "%(z)s = erfcinv(%(x)s);" % locals()
erfcinv = Erfcinv(upgrade_to_float_no_complex, name='erfcinv')
class Gamma(UnaryScalarOp): class Gamma(UnaryScalarOp):
@staticmethod @staticmethod
def st_impl(x): def st_impl(x):
......
...@@ -2900,6 +2900,16 @@ def erfc(a): ...@@ -2900,6 +2900,16 @@ def erfc(a):
"""complementary error function""" """complementary error function"""
@_scal_elemwise
def erfinv(a):
"""inverse error function"""
@_scal_elemwise
def erfcinv(a):
"""inverse complementary error function"""
@_scal_elemwise @_scal_elemwise
def gamma(a): def gamma(a):
"""gamma function""" """gamma function"""
......
...@@ -892,7 +892,7 @@ _good_broadcast_unary_normal_float_no_complex = copymod( ...@@ -892,7 +892,7 @@ _good_broadcast_unary_normal_float_no_complex = copymod(
without=['complex']) without=['complex'])
_good_broadcast_unary_normal = dict( _good_broadcast_unary_normal = dict(
normal=[numpy.asarray(rand_ranged(-5, 5, (2, 3)), dtype= config.floatX)], normal=[numpy.asarray(rand_ranged(-5, 5, (2, 3)), dtype=config.floatX)],
integers=[randint_ranged(-5, 5, (2, 3))], integers=[randint_ranged(-5, 5, (2, 3))],
corner_case=[corner_case], corner_case=[corner_case],
complex=[randcomplex(2, 3)], complex=[randcomplex(2, 3)],
...@@ -916,6 +916,14 @@ _grad_broadcast_unary_normal = dict( ...@@ -916,6 +916,14 @@ _grad_broadcast_unary_normal = dict(
#empty = [numpy.asarray([])] # XXX: should this be included? #empty = [numpy.asarray([])] # XXX: should this be included?
) )
_grad_broadcast_unary_abs1_no_complex = dict(
normal=[numpy.asarray(rand_ranged(-1, 1, (2, 3)), dtype=floatX)],
)
_grad_broadcast_unary_0_2_no_complex = dict(
normal=[numpy.asarray(rand_ranged(0, 2, (2, 3)), dtype=floatX)],
)
AbsTester = makeBroadcastTester(op=tensor.abs_, AbsTester = makeBroadcastTester(op=tensor.abs_,
expected=lambda x: abs(x), expected=lambda x: abs(x),
...@@ -1383,6 +1391,8 @@ del _good_broadcast_unary_normal_no_int['integers'] ...@@ -1383,6 +1391,8 @@ del _good_broadcast_unary_normal_no_int['integers']
if imported_scipy_special: if imported_scipy_special:
expected_erf = scipy.special.erf expected_erf = scipy.special.erf
expected_erfc = scipy.special.erfc expected_erfc = scipy.special.erfc
expected_erfinv = scipy.special.erfinv
expected_erfcinv = scipy.special.erfcinv
expected_gamma = scipy.special.gamma expected_gamma = scipy.special.gamma
expected_gammaln = scipy.special.gammaln expected_gammaln = scipy.special.gammaln
expected_psi = scipy.special.psi expected_psi = scipy.special.psi
...@@ -1390,6 +1400,8 @@ if imported_scipy_special: ...@@ -1390,6 +1400,8 @@ if imported_scipy_special:
else: else:
expected_erf = [] expected_erf = []
expected_erfc = [] expected_erfc = []
expected_erfinv = []
expected_erfcinv = []
expected_gamma = [] expected_gamma = []
expected_gammaln = [] expected_gammaln = []
expected_psi = [] expected_psi = []
...@@ -1431,6 +1443,24 @@ ErfcInplaceTester = makeBroadcastTester( ...@@ -1431,6 +1443,24 @@ ErfcInplaceTester = makeBroadcastTester(
inplace=True, inplace=True,
skip=skip_scipy) skip=skip_scipy)
ErfinvTester = makeBroadcastTester(
op=tensor.erfinv,
expected=expected_erfinv,
good=_good_broadcast_unary_normal_no_int_no_complex,
grad=_grad_broadcast_unary_abs1_no_complex,
eps=2e-10,
mode=mode_no_scipy,
skip=skip_scipy)
ErfcinvTester = makeBroadcastTester(
op=tensor.erfcinv,
expected=expected_erfcinv,
good=_good_broadcast_unary_normal_no_int_no_complex,
grad=_grad_broadcast_unary_0_2_no_complex,
eps=2e-10,
mode=mode_no_scipy,
skip=skip_scipy)
_good_broadcast_unary_gammaln = dict( _good_broadcast_unary_gammaln = dict(
normal=(rand_ranged(-1 + 1e-2, 10, (2, 3)),), normal=(rand_ranged(-1 + 1e-2, 10, (2, 3)),),
empty=(numpy.asarray([]),),) empty=(numpy.asarray([]),),)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论