提交 5626a476 authored 作者: nouiz's avatar nouiz

Merge pull request #674 from bouchnic/new_elemwise

New elemwise
...@@ -1944,6 +1944,25 @@ class Exp(UnaryScalarOp): ...@@ -1944,6 +1944,25 @@ class Exp(UnaryScalarOp):
exp = Exp(upgrade_to_float, name='exp') exp = Exp(upgrade_to_float, name='exp')
class Exp2(UnaryScalarOp):
def impl(self, x):
return numpy.exp2(x)
def grad(self, (x, ), (gz, )):
if x.type in complex_types:
raise NotImplementedError()
elif x.type in float_types:
return gz * exp2(x) * log(numpy.cast[x.type](2)),
else:
return None,
def c_code(self, node, name, (x, ), (z, ), sub):
if node.inputs[0].type in complex_types:
raise NotImplementedError('type not supported', type)
return "%(z)s = exp2(%(x)s);" % locals()
exp2 = Exp2(upgrade_to_float, name='exp2')
class Sqr(UnaryScalarOp): class Sqr(UnaryScalarOp):
def impl(self, x): def impl(self, x):
return x * x return x * x
...@@ -1999,7 +2018,7 @@ class Cos(UnaryScalarOp): ...@@ -1999,7 +2018,7 @@ class Cos(UnaryScalarOp):
cos = Cos(upgrade_to_float, name='cos') cos = Cos(upgrade_to_float, name='cos')
class Arccos(UnaryScalarOp): class ArcCos(UnaryScalarOp):
def impl(self, x): def impl(self, x):
return numpy.arccos(x) return numpy.arccos(x)
...@@ -2015,7 +2034,7 @@ class Arccos(UnaryScalarOp): ...@@ -2015,7 +2034,7 @@ class Arccos(UnaryScalarOp):
if node.inputs[0].type in complex_types: if node.inputs[0].type in complex_types:
raise NotImplementedError('type not supported', type) raise NotImplementedError('type not supported', type)
return "%(z)s = acos(%(x)s);" % locals() return "%(z)s = acos(%(x)s);" % locals()
arccos = Arccos(upgrade_to_float, name='arccos') arccos = ArcCos(upgrade_to_float, name='arccos')
class Sin(UnaryScalarOp): class Sin(UnaryScalarOp):
...@@ -2037,6 +2056,25 @@ class Sin(UnaryScalarOp): ...@@ -2037,6 +2056,25 @@ class Sin(UnaryScalarOp):
sin = Sin(upgrade_to_float, name='sin') sin = Sin(upgrade_to_float, name='sin')
class ArcSin(UnaryScalarOp):
def impl(self, x):
return numpy.arcsin(x)
def grad(self, (x,), (gz,)):
if gz.type in complex_types:
raise NotImplementedError()
if x.type in float_types:
return gz / sqrt(numpy.cast[x.type](1) - sqr(x)),
else:
return None,
def c_code(self, node, name, (x,), (z,), sub):
if node.inputs[0].type in complex_types:
raise NotImplementedError('type not supported', type)
return "%(z)s = asin(%(x)s);" % locals()
arcsin = ArcSin(upgrade_to_float, name='arcsin')
class Tan(UnaryScalarOp): class Tan(UnaryScalarOp):
def impl(self, x): def impl(self, x):
return numpy.tan(x) return numpy.tan(x)
...@@ -2056,6 +2094,46 @@ class Tan(UnaryScalarOp): ...@@ -2056,6 +2094,46 @@ class Tan(UnaryScalarOp):
tan = Tan(upgrade_to_float, name='tan') tan = Tan(upgrade_to_float, name='tan')
class ArcTan(UnaryScalarOp):
def impl(self, x):
return numpy.arctan(x)
def grad(self, (x,), (gz,)):
if gz.type in complex_types:
raise NotImplementedError()
if x.type in float_types:
return gz / (numpy.cast[x.type](1) + sqr(x)),
else:
return None,
def c_code(self, node, name, (x,), (z,), sub):
if node.inputs[0].type in complex_types:
raise NotImplementedError('type not supported', type)
return "%(z)s = atan(%(x)s);" % locals()
arctan = ArcTan(upgrade_to_float, name='arctan')
class ArcTan2(BinaryScalarOp):
def impl(self, y, x):
return numpy.arctan2(y, x)
def grad(self, (y, x), (gz,)):
if gz.type in complex_types:
raise NotImplementedError()
if x.type in float_types and y.type in float_types:
return [gz * x / (sqr(x) + sqr(y)),
gz * neg(y) / (sqr(x) + sqr(y))]
else:
return None,
def c_code(self, node, name, (y, x), (z,), sub):
if (node.inputs[0].type in complex_types or
node.inputs[1].type in complex_types):
raise NotImplementedError('type not supported', type)
return "%(z)s = atan2(%(y)s, %(x)s);" % locals()
arctan2 = ArcTan2(upgrade_to_float, name='arctan2')
class Cosh(UnaryScalarOp): class Cosh(UnaryScalarOp):
""" """
cosh(x) = (exp(x) + exp(-x)) / 2 cosh(x) = (exp(x) + exp(-x)) / 2
...@@ -2078,6 +2156,25 @@ class Cosh(UnaryScalarOp): ...@@ -2078,6 +2156,25 @@ class Cosh(UnaryScalarOp):
cosh = Cosh(upgrade_to_float, name='cosh') cosh = Cosh(upgrade_to_float, name='cosh')
class ArcCosh(UnaryScalarOp):
def impl(self, x):
return numpy.arccosh(x)
def grad(self, (x, ), (gz, )):
if x.type in complex_types:
raise NotImplementedError()
if x.type in float_types:
return gz / sqrt(sqr(x) - numpy.cast[x.type](1)),
else:
return None,
def c_code(self, node, name, (x, ), (z, ), sub):
if node.inputs[0].type in complex_types:
raise NotImplementedError('type not supported', type)
return "%(z)s = acosh(%(x)s);" % locals()
arccosh = ArcCosh(upgrade_to_float, name='arccosh')
class Sinh(UnaryScalarOp): class Sinh(UnaryScalarOp):
""" """
sinh(x) = (exp(x) - exp(-x)) / 2 sinh(x) = (exp(x) - exp(-x)) / 2
...@@ -2100,6 +2197,25 @@ class Sinh(UnaryScalarOp): ...@@ -2100,6 +2197,25 @@ class Sinh(UnaryScalarOp):
sinh = Sinh(upgrade_to_float, name='sinh') sinh = Sinh(upgrade_to_float, name='sinh')
class ArcSinh(UnaryScalarOp):
def impl(self, x):
return numpy.arcsinh(x)
def grad(self, (x, ), (gz, )):
if x.type in complex_types:
raise NotImplementedError()
if x.type in float_types:
return gz / sqrt(sqr(x) + numpy.cast[x.type](1)),
else:
return None,
def c_code(self, node, name, (x, ), (z, ), sub):
if node.inputs[0].type in complex_types:
raise NotImplementedError('type not supported', type)
return "%(z)s = asinh(%(x)s);" % locals()
arcsinh = ArcSinh(upgrade_to_float, name='arcsinh')
class Tanh(UnaryScalarOp): class Tanh(UnaryScalarOp):
""" """
tanh(x) = sinh(x) / cosh(x) tanh(x) = sinh(x) / cosh(x)
...@@ -2123,6 +2239,25 @@ class Tanh(UnaryScalarOp): ...@@ -2123,6 +2239,25 @@ class Tanh(UnaryScalarOp):
tanh = Tanh(upgrade_to_float, name='tanh') tanh = Tanh(upgrade_to_float, name='tanh')
class ArcTanh(UnaryScalarOp):
def impl(self, x):
return numpy.arctanh(x)
def grad(self, (x, ), (gz, )):
if x.type in complex_types:
raise NotImplementedError()
if x.type in float_types:
return gz / (numpy.cast[x.type](1) -sqr(x)),
else:
return None,
def c_code(self, node, name, (x, ), (z, ), sub):
if node.inputs[0].type in complex_types:
raise NotImplementedError('type not supported', type)
return "%(z)s = atanh(%(x)s);" % locals()
arctanh = ArcTanh(upgrade_to_float, name='arctanh')
class Real(UnaryScalarOp): class Real(UnaryScalarOp):
"""Extract the real coordinate of a complex number. """ """Extract the real coordinate of a complex number. """
def impl(self, x): def impl(self, x):
......
...@@ -2476,6 +2476,11 @@ def exp(a): ...@@ -2476,6 +2476,11 @@ def exp(a):
"""e^`a`""" """e^`a`"""
@_scal_elemwise_with_nfunc('exp2', 1, 1)
def exp2(a):
"""2^`a`"""
@_scal_elemwise_with_nfunc('negative', 1, 1) @_scal_elemwise_with_nfunc('negative', 1, 1)
def neg(a): def neg(a):
"""-a""" """-a"""
...@@ -2575,26 +2580,56 @@ def sin(a): ...@@ -2575,26 +2580,56 @@ def sin(a):
"""sine of a""" """sine of a"""
@_scal_elemwise_with_nfunc('arcsin', 1, 1)
def arcsin(a):
"""arcsine of a"""
@_scal_elemwise_with_nfunc('tan', 1, 1) @_scal_elemwise_with_nfunc('tan', 1, 1)
def tan(a): def tan(a):
"""tangent of a""" """tangent of a"""
@_scal_elemwise_with_nfunc('arctan', 1, 1)
def arctan(a):
"""arctangent of a"""
@_scal_elemwise_with_nfunc('arctan2', 1, 1)
def arctan2(a, b):
"""arctangent of a / b"""
@_scal_elemwise_with_nfunc('cosh', 1, 1) @_scal_elemwise_with_nfunc('cosh', 1, 1)
def cosh(a): def cosh(a):
"""hyperbolic cosine of a""" """hyperbolic cosine of a"""
@_scal_elemwise_with_nfunc('arccosh', 1, 1)
def arccosh(a):
"""hyperbolic arc cosine of a"""
@_scal_elemwise_with_nfunc('sinh', 1, 1) @_scal_elemwise_with_nfunc('sinh', 1, 1)
def sinh(a): def sinh(a):
"""hyperbolic sine of a""" """hyperbolic sine of a"""
@_scal_elemwise_with_nfunc('arcsinh', 1, 1)
def arcsinh(a):
"""hyperbolic arc sine of a"""
@_scal_elemwise_with_nfunc('tanh', 1, 1) @_scal_elemwise_with_nfunc('tanh', 1, 1)
def tanh(a): def tanh(a):
"""hyperbolic tangent of a""" """hyperbolic tangent of a"""
@_scal_elemwise_with_nfunc('arctanh', 1, 1)
def arctanh(a):
"""hyperbolic arc tangent of a"""
@_scal_elemwise @_scal_elemwise
def erf(a): def erf(a):
"""error function""" """error function"""
......
from basic import _scal_elemwise #, _transpose_inplace from basic import _scal_elemwise #, _transpose_inplace
from theano import scalar as scal from theano import scalar as scal
import elemwise import elemwise
...@@ -88,6 +87,10 @@ def abs__inplace(a): ...@@ -88,6 +87,10 @@ def abs__inplace(a):
def exp_inplace(a): def exp_inplace(a):
"""e^`a` (inplace on `a`)""" """e^`a` (inplace on `a`)"""
@_scal_inplace
def exp2_inplace(a):
"""2^`a` (inplace on `a`)"""
@_scal_inplace @_scal_inplace
def neg_inplace(a): def neg_inplace(a):
"""-a (inplace on a)""" """-a (inplace on a)"""
...@@ -152,22 +155,46 @@ def arccos_inplace(a): ...@@ -152,22 +155,46 @@ def arccos_inplace(a):
def sin_inplace(a): def sin_inplace(a):
"""sine of `a` (inplace on `a`)""" """sine of `a` (inplace on `a`)"""
@_scal_inplace
def arcsin_inplace(a):
"""arcsine of `a` (inplace on `a`)"""
@_scal_inplace @_scal_inplace
def tan_inplace(a): def tan_inplace(a):
"""tangent of `a` (inplace on `a`)""" """tangent of `a` (inplace on `a`)"""
@_scal_inplace
def arctan_inplace(a):
"""arctangent of `a` (inplace on `a`)"""
@_scal_inplace
def arctan2_inplace(a, b):
"""arctangent of `a` / `b` (inplace on `a`)"""
@_scal_inplace @_scal_inplace
def cosh_inplace(a): def cosh_inplace(a):
"""hyperbolic cosine of `a` (inplace on `a`)""" """hyperbolic cosine of `a` (inplace on `a`)"""
@_scal_inplace
def arccosh_inplace(a):
"""hyperbolic arc cosine of `a` (inplace on `a`)"""
@_scal_inplace @_scal_inplace
def sinh_inplace(a): def sinh_inplace(a):
"""hyperbolic sine of `a` (inplace on `a`)""" """hyperbolic sine of `a` (inplace on `a`)"""
@_scal_inplace
def arcsinh_inplace(a):
"""hyperbolic arc sine of `a` (inplace on `a`)"""
@_scal_inplace @_scal_inplace
def tanh_inplace(a): def tanh_inplace(a):
"""hyperbolic tangent of `a` (inplace on `a`)""" """hyperbolic tangent of `a` (inplace on `a`)"""
@_scal_inplace
def arctanh_inplace(a):
"""hyperbolic arc tangent of `a` (inplace on `a`)"""
@_scal_inplace @_scal_inplace
def erf_inplace(a): def erf_inplace(a):
"""error function""" """error function"""
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论