提交 dc1a1ae6 authored 作者: Nicolas Bouchard's avatar Nicolas Bouchard

Wrap arctanh from numpy, elemwise.

上级 41c1330d
......@@ -2199,6 +2199,25 @@ class Tanh(UnaryScalarOp):
tanh = Tanh(upgrade_to_float, name='tanh')
class ArcTanh(UnaryScalarOp):
def impl(self, x):
return numpy.arctanh(x)
def grad(self, (x, ), (gz, )):
if x.type in complex_types:
raise NotImplementedError()
if x.type in float_types:
return gz / (numpy.cast[x.type](1) -sqr(x)),
else:
return None,
def c_code(self, node, name, (x, ), (z, ), sub):
if node.inputs[0].type in complex_types:
raise NotImplementedError('type not supported', type)
return "%(z)s = atanh(%(x)s);" % locals()
arctanh = ArcTanh(upgrade_to_float, name='arctanh')
class Real(UnaryScalarOp):
"""Extract the real coordinate of a complex number. """
def impl(self, x):
......
......@@ -2616,6 +2616,11 @@ def tanh(a):
"""hyperbolic tangent of a"""
@_scal_elemwise_with_nfunc('arctanh', 1, 1)
def arctanh(a):
"""hyperbolic arc tangent of a"""
@_scal_elemwise
def erf(a):
"""error function"""
......
......@@ -183,6 +183,10 @@ def arcsinh_inplace(a):
def tanh_inplace(a):
"""hyperbolic tangent of `a` (inplace on `a`)"""
@_scal_inplace
def arctanh_inplace(a):
"""hyperbolic arc tangent of `a` (inplace on `a`)"""
@_scal_inplace
def erf_inplace(a):
"""error function"""
......
......@@ -1178,6 +1178,25 @@ TanhInplaceTester = makeBroadcastTester(op = inplace.tanh_inplace,
grad = _grad_broadcast_unary_normal,
inplace = True)
_good_broadcast_unary_arctanh = dict(normal = (rand_ranged(-1, 1, (2, 3)),),
integers = (randint_ranged(-1, 1, (2, 3)),),
complex = (randc128_ranged(-1, 1, (2, 3)),),
empty = (numpy.asarray([]),),)
_grad_broadcast_unary_arctanh = dict(normal = (rand_ranged(-1, 1, (2, 3)),),)
ArctanhTester = makeBroadcastTester(op = tensor.arctanh,
expected = numpy.arctanh,
good = _good_broadcast_unary_arctanh,
grad = _grad_broadcast_unary_arctanh)
ArctanhInplaceTester = makeBroadcastTester(op = inplace.arctanh_inplace,
expected = numpy.arctanh,
good = _good_broadcast_unary_arctanh,
grad = _grad_broadcast_unary_arctanh,
inplace = True)
#inplace ops when the input is integer and the output is float*
# don't have a well defined behavior. We don't test that case.
_good_broadcast_unary_normal_no_int_no_complex = _good_broadcast_unary_normal_no_complex.copy()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论