提交 dc1a1ae6 authored 作者: Nicolas Bouchard's avatar Nicolas Bouchard

Wrap arctanh from numpy, elemwise.

上级 41c1330d
...@@ -2199,6 +2199,25 @@ class Tanh(UnaryScalarOp): ...@@ -2199,6 +2199,25 @@ class Tanh(UnaryScalarOp):
tanh = Tanh(upgrade_to_float, name='tanh') tanh = Tanh(upgrade_to_float, name='tanh')
class ArcTanh(UnaryScalarOp):
def impl(self, x):
return numpy.arctanh(x)
def grad(self, (x, ), (gz, )):
if x.type in complex_types:
raise NotImplementedError()
if x.type in float_types:
return gz / (numpy.cast[x.type](1) -sqr(x)),
else:
return None,
def c_code(self, node, name, (x, ), (z, ), sub):
if node.inputs[0].type in complex_types:
raise NotImplementedError('type not supported', type)
return "%(z)s = atanh(%(x)s);" % locals()
arctanh = ArcTanh(upgrade_to_float, name='arctanh')
class Real(UnaryScalarOp): class Real(UnaryScalarOp):
"""Extract the real coordinate of a complex number. """ """Extract the real coordinate of a complex number. """
def impl(self, x): def impl(self, x):
......
...@@ -2616,6 +2616,11 @@ def tanh(a): ...@@ -2616,6 +2616,11 @@ def tanh(a):
"""hyperbolic tangent of a""" """hyperbolic tangent of a"""
@_scal_elemwise_with_nfunc('arctanh', 1, 1)
def arctanh(a):
"""hyperbolic arc tangent of a"""
@_scal_elemwise @_scal_elemwise
def erf(a): def erf(a):
"""error function""" """error function"""
......
...@@ -183,6 +183,10 @@ def arcsinh_inplace(a): ...@@ -183,6 +183,10 @@ def arcsinh_inplace(a):
def tanh_inplace(a): def tanh_inplace(a):
"""hyperbolic tangent of `a` (inplace on `a`)""" """hyperbolic tangent of `a` (inplace on `a`)"""
@_scal_inplace
def arctanh_inplace(a):
"""hyperbolic arc tangent of `a` (inplace on `a`)"""
@_scal_inplace @_scal_inplace
def erf_inplace(a): def erf_inplace(a):
"""error function""" """error function"""
......
...@@ -1178,6 +1178,25 @@ TanhInplaceTester = makeBroadcastTester(op = inplace.tanh_inplace, ...@@ -1178,6 +1178,25 @@ TanhInplaceTester = makeBroadcastTester(op = inplace.tanh_inplace,
grad = _grad_broadcast_unary_normal, grad = _grad_broadcast_unary_normal,
inplace = True) inplace = True)
_good_broadcast_unary_arctanh = dict(normal = (rand_ranged(-1, 1, (2, 3)),),
integers = (randint_ranged(-1, 1, (2, 3)),),
complex = (randc128_ranged(-1, 1, (2, 3)),),
empty = (numpy.asarray([]),),)
_grad_broadcast_unary_arctanh = dict(normal = (rand_ranged(-1, 1, (2, 3)),),)
ArctanhTester = makeBroadcastTester(op = tensor.arctanh,
expected = numpy.arctanh,
good = _good_broadcast_unary_arctanh,
grad = _grad_broadcast_unary_arctanh)
ArctanhInplaceTester = makeBroadcastTester(op = inplace.arctanh_inplace,
expected = numpy.arctanh,
good = _good_broadcast_unary_arctanh,
grad = _grad_broadcast_unary_arctanh,
inplace = True)
#inplace ops when the input is integer and the output is float* #inplace ops when the input is integer and the output is float*
# don't have a well defined behavior. We don't test that case. # don't have a well defined behavior. We don't test that case.
_good_broadcast_unary_normal_no_int_no_complex = _good_broadcast_unary_normal_no_complex.copy() _good_broadcast_unary_normal_no_int_no_complex = _good_broadcast_unary_normal_no_complex.copy()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论