added trig functions, untested

上级 e412894e
......@@ -388,6 +388,53 @@ class Sqrt(UnaryScalarOp):
def c_code(self, (x, ), (z, ), sub):
return "%(z)s = sqrt(%(x)s);" % locals()
class Cos(UnaryScalarOp):
def impl(self, x):
return math.cos(x)
def grad(self, (x, ), (gz, )):
return gz * sin(x),
def c_code(self, (x, ), (z, ), sub):
return "%(z)s = cos(%(x)s);" % locals()
class Sin(UnaryScalarOp):
def impl(self, x):
return math.sin(x)
def grad(self, (x, ), (gz, )):
return -gz * cos(x),
def c_code(self, (x, ), (z, ), sub):
return "%(z)s = sin(%(x)s);" % locals()
class Tan(UnaryScalarOp):
def impl(self, x):
return math.tan(x)
def grad(self, (x, ), (gz, )):
raise NotImplementedError('lazy')
def c_code(self, (x, ), (z, ), sub):
return "%(z)s = tan(%(x)s);" % locals()
class Cosh(UnaryScalarOp):
def impl(self, x):
return math.cosh(x)
def grad(self, (x, ), (gz, )):
raise NotImplementedError()
def c_code(self, (x, ), (z, ), sub):
return "%(z)s = cosh(%(x)s);" % locals()
class Sinh(UnaryScalarOp):
def impl(self, x):
return math.sin(x)
def grad(self, (x, ), (gz, )):
return -gz * cos(x),
def c_code(self, (x, ), (z, ), sub):
return "%(z)s = sin(%(x)s);" % locals()
class Tanh(UnaryScalarOp):
def impl(self, x):
return math.tanh(x)
def grad(self, (x, ), (gz, )):
return gz * (1 - tanh(x))**2
def c_code(self, (x, ), (z, ), sub):
return "%(z)s = tanh(%(x)s);" % locals()
#NOTE WELL!!!
......
......@@ -216,6 +216,12 @@ Log2, log2, Log2Inplace, log2_inplace = broadcast(scal.Log2, 'Log2')
Sgn, sgn, SgnInplace, sgn_inplace = broadcast(scal.Sgn, 'Sgn')
Sqr, sqr, SqrInplace, sqr_inplace = broadcast(scal.Sqr, 'Sqr')
Sqrt, sqrt, SqrtInplace, sqrt_inplace = broadcast(scal.Sqrt, 'Sqrt')
Cos, cos, CosInplace, cos_inplace = broadcast(scal.Cos, 'Cos')
Sin, sin, SinInplace, sin_inplace = broadcast(scal.Sin, 'Sin')
Tan, tan, TanInplace, tan_inplace = broadcast(scal.Tan, 'Tan')
Cosh, cosh, CoshInplace, cosh_inplace = broadcast(scal.Cosh, 'Cosh')
Sinh, sinh, SinhInplace, sinh_inplace = broadcast(scal.Sinh, 'Sinh')
Tanh, tanh, TanhInplace, tanh_inplace = broadcast(scal.Tanh, 'Tanh')
Sum = s2t.Sum
sum = gof.op.constructor(Sum)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论