提交 4d1a5027 authored 作者: Jey Kottalam's avatar Jey Kottalam

add initial implementation of erfinv() and erfcinv()

上级 295ecea7
......@@ -75,6 +75,67 @@ class Erfc(UnaryScalarOp):
erfc = Erfc(upgrade_to_float_no_complex, name='erfc')
class Erfinv(UnaryScalarOp):
def impl(self, x):
if imported_scipy_special:
return scipy.special.erfinv(x)
else:
super(Erfinv, self).impl(x)
def grad(self, inp, grads):
x, = inp
gz, = grads
if x.type in complex_types:
raise NotImplementedError()
elif x.type in float_types:
cst = numpy.asarray(numpy.sqrt(numpy.pi) / 2.,
dtype=upcast(x.type.dtype, gz.type.dtype))
return gz * cst * exp(erfinv(x) ** 2),
else:
return None,
# TODO
#def c_code(self, node, name, inp, out, sub):
# x, = inp
# z, = out
# if node.inputs[0].type in complex_types:
# raise NotImplementedError('type not supported', type)
# return "%(z)s = erfinv(%(x)s);" % locals()
erfinv = Erfinv(upgrade_to_float_no_complex, name='erfinv')
class Erfcinv(UnaryScalarOp):
def impl(self, x):
if imported_scipy_special:
return scipy.special.erfcinv(x)
else:
super(Erfcinv, self).impl(x)
def grad(self, inp, grads):
x, = inp
gz, = grads
if x.type in complex_types:
raise NotImplementedError()
elif x.type in float_types:
cst = numpy.asarray(numpy.sqrt(numpy.pi) / 2.,
dtype=upcast(x.type.dtype, gz.type.dtype))
return - gz * cst * exp(erfcinv(x) ** 2),
else:
return None,
# TODO
#def c_code(self, node, name, inp, out, sub):
# x, = inp
# z, = out
# if node.inputs[0].type in complex_types:
# raise NotImplementedError('type not supported', type)
# return "%(z)s = erfcinv(%(x)s);" % locals()
erfcinv = Erfcinv(upgrade_to_float_no_complex, name='erfcinv')
class Gamma(UnaryScalarOp):
@staticmethod
def st_impl(x):
......
......@@ -2847,6 +2847,16 @@ def erfc(a):
"""complementary error function"""
@_scal_elemwise
def erfinv(a):
"""inverse error function"""
@_scal_elemwise
def erfcinv(a):
"""inverse complementary error function"""
@_scal_elemwise
def gamma(a):
"""gamma function"""
......
......@@ -915,6 +915,14 @@ _grad_broadcast_unary_normal = dict(
#empty = [numpy.asarray([])] # XXX: should this be included?
)
_grad_broadcast_unary_abs1_no_complex = dict(
normal=[numpy.asarray(rand_ranged(-1, 1, (2, 3)), dtype=floatX)],
)
_grad_broadcast_unary_0_2_no_complex = dict(
normal=[numpy.asarray(rand_ranged(0, 2, (2, 3)), dtype=floatX)],
)
AbsTester = makeBroadcastTester(op=tensor.abs_,
expected=lambda x: abs(x),
......@@ -1382,6 +1390,8 @@ del _good_broadcast_unary_normal_no_int['integers']
if imported_scipy_special:
expected_erf = scipy.special.erf
expected_erfc = scipy.special.erfc
expected_erfinv = scipy.special.erfinv
expected_erfcinv = scipy.special.erfcinv
expected_gamma = scipy.special.gamma
expected_gammaln = scipy.special.gammaln
expected_psi = scipy.special.psi
......@@ -1430,6 +1440,24 @@ ErfcInplaceTester = makeBroadcastTester(
inplace=True,
skip=skip_scipy)
ErfinvTester = makeBroadcastTester(
op=tensor.erfinv,
expected=expected_erfinv,
good=_good_broadcast_unary_normal_no_int_no_complex,
grad=_grad_broadcast_unary_abs1_no_complex,
eps=2e-10,
mode=mode_no_scipy,
skip=skip_scipy)
ErfcinvTester = makeBroadcastTester(
op=tensor.erfcinv,
expected=expected_erfcinv,
good=_good_broadcast_unary_normal_no_int_no_complex,
grad=_grad_broadcast_unary_0_2_no_complex,
eps=2e-10,
mode=mode_no_scipy,
skip=skip_scipy)
_good_broadcast_unary_gammaln = dict(
normal=(rand_ranged(-1 + 1e-2, 10, (2, 3)),),
empty=(numpy.asarray([]),),)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论