提交 fd8df6c1 authored 作者: James Bergstra's avatar James Bergstra

added better complex support in scalar/basic

上级 8cf42d56
......@@ -381,13 +381,27 @@ def float_out(*types):
return float64,
def upgrade_to_float(*types):
"""
This upgrade the types to float32 or float64 to don't loose any precision.
Upgrade any int types to float32 or float64 to avoid losing any precision.
"""
conv = {int8: float32,
int16: float32,
int32: float64,
int64: float64}
return Scalar(Scalar.upcast(*[conv.get(type, type) for type in types])),
def same_out_nocomplex(type):
if type in complex_types:
raise TypeError('complex argument not supported')
return type,
def int_out_nocomplex(*types):
for type in types:
if type in complex_types:
raise TypeError('complex argument not supported')
return int64,
def float_out_nocomplex(*types):
for type in types:
if type in complex_types:
raise TypeError('complex argument not supported')
return float64,
class ScalarOp(Op):
......@@ -997,7 +1011,6 @@ class Abs(UnaryScalarOp):
return "%(z)s = fabs(%(x)s);" % locals()
if type in complex_types:
return "%(z)s = sqrt(%(x)s.real*%(x)s.real + %(x)s.imag*%(x)s.imag);" % locals()
#complex, other?
raise NotImplementedError('type not supported', type)
abs_ = Abs(same_out)
......@@ -1010,8 +1023,19 @@ class Sgn(UnaryScalarOp):
def c_code(self, node, name, (x, ), (z, ), sub):
#casting is done by compiler
#TODO: use copysign
type = node.inputs[0].type
if type in float_types:
return "%(z)s = (%(x)s >= 0) ? (%(x)s == 0) ? 0.0 : 1.0 : -1.0;" % locals()
sgn = Sgn(same_out, name = 'sgn')
if type in int_types:
return "%(z)s = (%(x)s >= 0) ? (%(x)s == 0) ? 0 : 1 : -1;" % locals()
raise TypeError() #complex has no sgn
def c_code_cache_version(self):
s = super(Sgn, self).c_code_cache_version()
if s:
return (3,) + s
else: #if parent is unversioned, we are too
return s
sgn = Sgn(same_out_nocomplex, name = 'sgn')
class Ceil(UnaryScalarOp):
def impl(self, x):
......@@ -1020,7 +1044,7 @@ class Ceil(UnaryScalarOp):
return None,
def c_code(self, node, name, (x,), (z,), sub):
return "%(z)s = ceil(%(x)s);" % locals()
ceil = Ceil(same_out, name = 'ceil')
ceil = Ceil(same_out_nocomplex, name = 'ceil')
class Floor(UnaryScalarOp):
def impl(self, x):
......@@ -1029,14 +1053,14 @@ class Floor(UnaryScalarOp):
return None,
def c_code(self, node, name, (x,), (z,), sub):
return "%(z)s = floor(%(x)s);" % locals()
floor = Floor(same_out, name = 'ceil')
floor = Floor(same_out_nocomplex, name = 'ceil')
class IRound(UnaryScalarOp):
def impl(self, x):
return numpy.asarray(numpy.round(x), dtype = 'int64')
def c_code(self, node, name, (x, ), (z, ), sub):
return "%(z)s = round(%(x)s);" % locals()
iround = IRound(int_out)
iround = IRound(int_out_nocomplex)
class Neg(UnaryScalarOp):
def impl(self, x):
......@@ -1080,6 +1104,8 @@ class Log(UnaryScalarOp):
#todo: the version using log2 seems to be very slightly faster
# on some machines for some reason, check if it's worth switching
#return "%(z)s = log2(%(x)s) * 0.69314718055994529;" % locals()
if node.inputs[0].type in complex_types:
raise NotImplementedError('type not supported', type)
return "%(z)s = log(%(x)s);" % locals()
log = Log(upgrade_to_float, name = 'log')
......@@ -1096,6 +1122,8 @@ class Log2(UnaryScalarOp):
#backport
#return gz / (x * math.log(2.0)) if x.type in grad_types else None,
def c_code(self, node, name, (x, ), (z, ), sub):
if node.inputs[0].type in complex_types:
raise NotImplementedError('type not supported', type)
return "%(z)s = log2(%(x)s);" % locals()
log2 = Log2(upgrade_to_float, name = 'log2')
......@@ -1105,28 +1133,31 @@ class Log10(UnaryScalarOp):
return numpy.log10(x)
def grad(self, (x, ), (gz, )):
if x.type in grad_types:
return gz / (x * math.log(10.0)),
return gz / (x * numpy.log(10.0)),
else:
return None
#backport
#return gz / (x * math.log(10.0)) if x.type in grad_types else None,
#return gz / (x * numpy.log(10.0)) if x.type in grad_types else None,
def c_code(self, node, name, (x, ), (z, ), sub):
if node.inputs[0].type in complex_types:
raise NotImplementedError('type not supported', type)
return "%(z)s = log10(%(x)s);" % locals()
log10 = Log10(upgrade_to_float, name = 'log10')
class Exp(UnaryScalarOp):
def impl(self, x):
return math.exp(x)
return numpy.exp(x)
def grad(self, (x, ), (gz, )):
if x.type in grad_types:
return gz * exp(x),
else:
return None,
#backport
#return gz * exp(x) if x.type in grad_types else None,
def c_code(self, node, name, (x, ), (z, ), sub):
if node.inputs[0].type in complex_types:
raise NotImplementedError('type not supported', type)
return "%(z)s = exp(%(x)s);" % locals()
exp = Exp(upgrade_to_float, name = 'exp')
......@@ -1147,7 +1178,7 @@ sqr = Sqr(same_out, name = 'sqr')
class Sqrt(UnaryScalarOp):
def impl(self, x):
return math.sqrt(x)
return numpy.sqrt(x)
def grad(self, (x, ), (gz, )):
if x.type in grad_types:
return (gz * 0.5) / sqrt(x),
......@@ -1156,12 +1187,14 @@ class Sqrt(UnaryScalarOp):
#backport
#return (gz * 0.5) / sqrt(x) if x.type in grad_types else None,
def c_code(self, node, name, (x, ), (z, ), sub):
if node.inputs[0].type in complex_types:
raise NotImplementedError('type not supported', type)
return "%(z)s = sqrt(%(x)s);" % locals()
sqrt = Sqrt(upgrade_to_float, name = 'sqrt')
class Cos(UnaryScalarOp):
def impl(self, x):
return math.cos(x)
return numpy.cos(x)
def grad(self, (x, ), (gz, )):
if x.type in grad_types:
return -gz * sin(x),
......@@ -1170,12 +1203,14 @@ class Cos(UnaryScalarOp):
#backport
# return -gz * sin(x) if x.type in grad_types else None,
def c_code(self, node, name, (x, ), (z, ), sub):
if node.inputs[0].type in complex_types:
raise NotImplementedError('type not supported', type)
return "%(z)s = cos(%(x)s);" % locals()
cos = Cos(upgrade_to_float, name = 'cos')
class Sin(UnaryScalarOp):
def impl(self, x):
return math.sin(x)
return numpy.sin(x)
def grad(self, (x, ), (gz, )):
if x.type in grad_types:
return gz * cos(x),
......@@ -1184,12 +1219,14 @@ class Sin(UnaryScalarOp):
#backport
# return gz * cos(x) if x.type in grad_types else None,
def c_code(self, node, name, (x, ), (z, ), sub):
if node.inputs[0].type in complex_types:
raise NotImplementedError('type not supported', type)
return "%(z)s = sin(%(x)s);" % locals()
sin = Sin(upgrade_to_float, name = 'sin')
class Tan(UnaryScalarOp):
def impl(self, x):
return math.tan(x)
return numpy.tan(x)
def grad(self, (x, ), (gz, )):
if x.type in grad_types:
return gz / sqr(cos(x)),
......@@ -1198,6 +1235,8 @@ class Tan(UnaryScalarOp):
#backport
#return gz / sqr(cos(x)) if x.type in grad_types else None,
def c_code(self, node, name, (x, ), (z, ), sub):
if node.inputs[0].type in complex_types:
raise NotImplementedError('type not supported', type)
return "%(z)s = tan(%(x)s);" % locals()
tan = Tan(upgrade_to_float, name = 'tan')
......@@ -1206,7 +1245,7 @@ class Cosh(UnaryScalarOp):
cosh(x) = (exp(x) + exp(-x)) / 2
"""
def impl(self, x):
return math.cosh(x)
return numpy.cosh(x)
def grad(self, (x, ), (gz, )):
if x.type in grad_types:
return gz * sinh(x),
......@@ -1215,6 +1254,8 @@ class Cosh(UnaryScalarOp):
#backport
#return gz * sinh(x) if x.type in grad_types else None,
def c_code(self, node, name, (x, ), (z, ), sub):
if node.inputs[0].type in complex_types:
raise NotImplementedError('type not supported', type)
return "%(z)s = cosh(%(x)s);" % locals()
cosh = Cosh(upgrade_to_float, name = 'cosh')
......@@ -1223,7 +1264,7 @@ class Sinh(UnaryScalarOp):
sinh(x) = (exp(x) - exp(-x)) / 2
"""
def impl(self, x):
return math.sinh(x)
return numpy.sinh(x)
def grad(self, (x, ), (gz, )):
if x.type in grad_types:
return gz * cosh(x),
......@@ -1232,6 +1273,8 @@ class Sinh(UnaryScalarOp):
#backport
#return gz * cosh(x) if x.type in grad_types else None,
def c_code(self, node, name, (x, ), (z, ), sub):
if node.inputs[0].type in complex_types:
raise NotImplementedError('type not supported', type)
return "%(z)s = sinh(%(x)s);" % locals()
sinh = Sinh(upgrade_to_float, name = 'sinh')
......@@ -1241,7 +1284,7 @@ class Tanh(UnaryScalarOp):
= (exp(2*x) - 1) / (exp(2*x) + 1)
"""
def impl(self, x):
return math.tanh(x)
return numpy.tanh(x)
def grad(self, (x, ), (gz, )):
if x.type in grad_types:
return gz * (1 - sqr(tanh(x))),
......@@ -1250,6 +1293,8 @@ class Tanh(UnaryScalarOp):
#backport
#return gz * (1 - sqr(tanh(x))) if x.type in grad_types else None,
def c_code(self, node, name, (x, ), (z, ), sub):
if node.inputs[0].type in complex_types:
raise NotImplementedError('type not supported', type)
return "%(z)s = tanh(%(x)s);" % locals()
tanh = Tanh(upgrade_to_float, name = 'tanh')
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论