提交 5626a476 authored 作者: nouiz's avatar nouiz

Merge pull request #674 from bouchnic/new_elemwise

New elemwise
......@@ -1944,6 +1944,25 @@ class Exp(UnaryScalarOp):
exp = Exp(upgrade_to_float, name='exp')
class Exp2(UnaryScalarOp):
def impl(self, x):
return numpy.exp2(x)
def grad(self, (x, ), (gz, )):
if x.type in complex_types:
raise NotImplementedError()
elif x.type in float_types:
return gz * exp2(x) * log(numpy.cast[x.type](2)),
else:
return None,
def c_code(self, node, name, (x, ), (z, ), sub):
if node.inputs[0].type in complex_types:
raise NotImplementedError('type not supported', type)
return "%(z)s = exp2(%(x)s);" % locals()
exp2 = Exp2(upgrade_to_float, name='exp2')
class Sqr(UnaryScalarOp):
def impl(self, x):
return x * x
......@@ -1999,7 +2018,7 @@ class Cos(UnaryScalarOp):
cos = Cos(upgrade_to_float, name='cos')
class Arccos(UnaryScalarOp):
class ArcCos(UnaryScalarOp):
def impl(self, x):
return numpy.arccos(x)
......@@ -2015,7 +2034,7 @@ class Arccos(UnaryScalarOp):
if node.inputs[0].type in complex_types:
raise NotImplementedError('type not supported', type)
return "%(z)s = acos(%(x)s);" % locals()
arccos = Arccos(upgrade_to_float, name='arccos')
arccos = ArcCos(upgrade_to_float, name='arccos')
class Sin(UnaryScalarOp):
......@@ -2037,6 +2056,25 @@ class Sin(UnaryScalarOp):
sin = Sin(upgrade_to_float, name='sin')
class ArcSin(UnaryScalarOp):
def impl(self, x):
return numpy.arcsin(x)
def grad(self, (x,), (gz,)):
if gz.type in complex_types:
raise NotImplementedError()
if x.type in float_types:
return gz / sqrt(numpy.cast[x.type](1) - sqr(x)),
else:
return None,
def c_code(self, node, name, (x,), (z,), sub):
if node.inputs[0].type in complex_types:
raise NotImplementedError('type not supported', type)
return "%(z)s = asin(%(x)s);" % locals()
arcsin = ArcSin(upgrade_to_float, name='arcsin')
class Tan(UnaryScalarOp):
def impl(self, x):
return numpy.tan(x)
......@@ -2056,6 +2094,46 @@ class Tan(UnaryScalarOp):
tan = Tan(upgrade_to_float, name='tan')
class ArcTan(UnaryScalarOp):
def impl(self, x):
return numpy.arctan(x)
def grad(self, (x,), (gz,)):
if gz.type in complex_types:
raise NotImplementedError()
if x.type in float_types:
return gz / (numpy.cast[x.type](1) + sqr(x)),
else:
return None,
def c_code(self, node, name, (x,), (z,), sub):
if node.inputs[0].type in complex_types:
raise NotImplementedError('type not supported', type)
return "%(z)s = atan(%(x)s);" % locals()
arctan = ArcTan(upgrade_to_float, name='arctan')
class ArcTan2(BinaryScalarOp):
def impl(self, y, x):
return numpy.arctan2(y, x)
def grad(self, (y, x), (gz,)):
if gz.type in complex_types:
raise NotImplementedError()
if x.type in float_types and y.type in float_types:
return [gz * x / (sqr(x) + sqr(y)),
gz * neg(y) / (sqr(x) + sqr(y))]
else:
return None,
def c_code(self, node, name, (y, x), (z,), sub):
if (node.inputs[0].type in complex_types or
node.inputs[1].type in complex_types):
raise NotImplementedError('type not supported', type)
return "%(z)s = atan2(%(y)s, %(x)s);" % locals()
arctan2 = ArcTan2(upgrade_to_float, name='arctan2')
class Cosh(UnaryScalarOp):
"""
cosh(x) = (exp(x) + exp(-x)) / 2
......@@ -2078,6 +2156,25 @@ class Cosh(UnaryScalarOp):
cosh = Cosh(upgrade_to_float, name='cosh')
class ArcCosh(UnaryScalarOp):
def impl(self, x):
return numpy.arccosh(x)
def grad(self, (x, ), (gz, )):
if x.type in complex_types:
raise NotImplementedError()
if x.type in float_types:
return gz / sqrt(sqr(x) - numpy.cast[x.type](1)),
else:
return None,
def c_code(self, node, name, (x, ), (z, ), sub):
if node.inputs[0].type in complex_types:
raise NotImplementedError('type not supported', type)
return "%(z)s = acosh(%(x)s);" % locals()
arccosh = ArcCosh(upgrade_to_float, name='arccosh')
class Sinh(UnaryScalarOp):
"""
sinh(x) = (exp(x) - exp(-x)) / 2
......@@ -2100,6 +2197,25 @@ class Sinh(UnaryScalarOp):
sinh = Sinh(upgrade_to_float, name='sinh')
class ArcSinh(UnaryScalarOp):
def impl(self, x):
return numpy.arcsinh(x)
def grad(self, (x, ), (gz, )):
if x.type in complex_types:
raise NotImplementedError()
if x.type in float_types:
return gz / sqrt(sqr(x) + numpy.cast[x.type](1)),
else:
return None,
def c_code(self, node, name, (x, ), (z, ), sub):
if node.inputs[0].type in complex_types:
raise NotImplementedError('type not supported', type)
return "%(z)s = asinh(%(x)s);" % locals()
arcsinh = ArcSinh(upgrade_to_float, name='arcsinh')
class Tanh(UnaryScalarOp):
"""
tanh(x) = sinh(x) / cosh(x)
......@@ -2123,6 +2239,25 @@ class Tanh(UnaryScalarOp):
tanh = Tanh(upgrade_to_float, name='tanh')
class ArcTanh(UnaryScalarOp):
def impl(self, x):
return numpy.arctanh(x)
def grad(self, (x, ), (gz, )):
if x.type in complex_types:
raise NotImplementedError()
if x.type in float_types:
return gz / (numpy.cast[x.type](1) -sqr(x)),
else:
return None,
def c_code(self, node, name, (x, ), (z, ), sub):
if node.inputs[0].type in complex_types:
raise NotImplementedError('type not supported', type)
return "%(z)s = atanh(%(x)s);" % locals()
arctanh = ArcTanh(upgrade_to_float, name='arctanh')
class Real(UnaryScalarOp):
"""Extract the real coordinate of a complex number. """
def impl(self, x):
......
......@@ -2476,6 +2476,11 @@ def exp(a):
"""e^`a`"""
@_scal_elemwise_with_nfunc('exp2', 1, 1)
def exp2(a):
"""2^`a`"""
@_scal_elemwise_with_nfunc('negative', 1, 1)
def neg(a):
"""-a"""
......@@ -2575,26 +2580,56 @@ def sin(a):
"""sine of a"""
@_scal_elemwise_with_nfunc('arcsin', 1, 1)
def arcsin(a):
"""arcsine of a"""
@_scal_elemwise_with_nfunc('tan', 1, 1)
def tan(a):
"""tangent of a"""
@_scal_elemwise_with_nfunc('arctan', 1, 1)
def arctan(a):
"""arctangent of a"""
@_scal_elemwise_with_nfunc('arctan2', 1, 1)
def arctan2(a, b):
"""arctangent of a / b"""
@_scal_elemwise_with_nfunc('cosh', 1, 1)
def cosh(a):
"""hyperbolic cosine of a"""
@_scal_elemwise_with_nfunc('arccosh', 1, 1)
def arccosh(a):
"""hyperbolic arc cosine of a"""
@_scal_elemwise_with_nfunc('sinh', 1, 1)
def sinh(a):
"""hyperbolic sine of a"""
@_scal_elemwise_with_nfunc('arcsinh', 1, 1)
def arcsinh(a):
"""hyperbolic arc sine of a"""
@_scal_elemwise_with_nfunc('tanh', 1, 1)
def tanh(a):
"""hyperbolic tangent of a"""
@_scal_elemwise_with_nfunc('arctanh', 1, 1)
def arctanh(a):
"""hyperbolic arc tangent of a"""
@_scal_elemwise
def erf(a):
"""error function"""
......
from basic import _scal_elemwise #, _transpose_inplace
from theano import scalar as scal
import elemwise
......@@ -88,6 +87,10 @@ def abs__inplace(a):
def exp_inplace(a):
"""e^`a` (inplace on `a`)"""
@_scal_inplace
def exp2_inplace(a):
"""2^`a` (inplace on `a`)"""
@_scal_inplace
def neg_inplace(a):
"""-a (inplace on a)"""
......@@ -152,22 +155,46 @@ def arccos_inplace(a):
def sin_inplace(a):
"""sine of `a` (inplace on `a`)"""
@_scal_inplace
def arcsin_inplace(a):
"""arcsine of `a` (inplace on `a`)"""
@_scal_inplace
def tan_inplace(a):
"""tangent of `a` (inplace on `a`)"""
@_scal_inplace
def arctan_inplace(a):
"""arctangent of `a` (inplace on `a`)"""
@_scal_inplace
def arctan2_inplace(a, b):
"""arctangent of `a` / `b` (inplace on `a`)"""
@_scal_inplace
def cosh_inplace(a):
"""hyperbolic cosine of `a` (inplace on `a`)"""
@_scal_inplace
def arccosh_inplace(a):
"""hyperbolic arc cosine of `a` (inplace on `a`)"""
@_scal_inplace
def sinh_inplace(a):
"""hyperbolic sine of `a` (inplace on `a`)"""
@_scal_inplace
def arcsinh_inplace(a):
"""hyperbolic arc sine of `a` (inplace on `a`)"""
@_scal_inplace
def tanh_inplace(a):
"""hyperbolic tangent of `a` (inplace on `a`)"""
@_scal_inplace
def arctanh_inplace(a):
"""hyperbolic arc tangent of `a` (inplace on `a`)"""
@_scal_inplace
def erf_inplace(a):
"""error function"""
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论