提交 fd8df6c1 authored 作者: James Bergstra's avatar James Bergstra

added better complex support in scalar/basic

上级 8cf42d56
...@@ -381,13 +381,27 @@ def float_out(*types): ...@@ -381,13 +381,27 @@ def float_out(*types):
return float64, return float64,
def upgrade_to_float(*types): def upgrade_to_float(*types):
""" """
This upgrade the types to float32 or float64 to don't loose any precision. Upgrade any int types to float32 or float64 to avoid losing any precision.
""" """
conv = {int8: float32, conv = {int8: float32,
int16: float32, int16: float32,
int32: float64, int32: float64,
int64: float64} int64: float64}
return Scalar(Scalar.upcast(*[conv.get(type, type) for type in types])), return Scalar(Scalar.upcast(*[conv.get(type, type) for type in types])),
def same_out_nocomplex(type):
if type in complex_types:
raise TypeError('complex argument not supported')
return type,
def int_out_nocomplex(*types):
for type in types:
if type in complex_types:
raise TypeError('complex argument not supported')
return int64,
def float_out_nocomplex(*types):
for type in types:
if type in complex_types:
raise TypeError('complex argument not supported')
return float64,
class ScalarOp(Op): class ScalarOp(Op):
...@@ -997,7 +1011,6 @@ class Abs(UnaryScalarOp): ...@@ -997,7 +1011,6 @@ class Abs(UnaryScalarOp):
return "%(z)s = fabs(%(x)s);" % locals() return "%(z)s = fabs(%(x)s);" % locals()
if type in complex_types: if type in complex_types:
return "%(z)s = sqrt(%(x)s.real*%(x)s.real + %(x)s.imag*%(x)s.imag);" % locals() return "%(z)s = sqrt(%(x)s.real*%(x)s.real + %(x)s.imag*%(x)s.imag);" % locals()
#complex, other?
raise NotImplementedError('type not supported', type) raise NotImplementedError('type not supported', type)
abs_ = Abs(same_out) abs_ = Abs(same_out)
...@@ -1010,8 +1023,19 @@ class Sgn(UnaryScalarOp): ...@@ -1010,8 +1023,19 @@ class Sgn(UnaryScalarOp):
def c_code(self, node, name, (x, ), (z, ), sub): def c_code(self, node, name, (x, ), (z, ), sub):
#casting is done by compiler #casting is done by compiler
#TODO: use copysign #TODO: use copysign
type = node.inputs[0].type
if type in float_types:
return "%(z)s = (%(x)s >= 0) ? (%(x)s == 0) ? 0.0 : 1.0 : -1.0;" % locals() return "%(z)s = (%(x)s >= 0) ? (%(x)s == 0) ? 0.0 : 1.0 : -1.0;" % locals()
sgn = Sgn(same_out, name = 'sgn') if type in int_types:
return "%(z)s = (%(x)s >= 0) ? (%(x)s == 0) ? 0 : 1 : -1;" % locals()
raise TypeError() #complex has no sgn
def c_code_cache_version(self):
s = super(Sgn, self).c_code_cache_version()
if s:
return (3,) + s
else: #if parent is unversioned, we are too
return s
sgn = Sgn(same_out_nocomplex, name = 'sgn')
class Ceil(UnaryScalarOp): class Ceil(UnaryScalarOp):
def impl(self, x): def impl(self, x):
...@@ -1020,7 +1044,7 @@ class Ceil(UnaryScalarOp): ...@@ -1020,7 +1044,7 @@ class Ceil(UnaryScalarOp):
return None, return None,
def c_code(self, node, name, (x,), (z,), sub): def c_code(self, node, name, (x,), (z,), sub):
return "%(z)s = ceil(%(x)s);" % locals() return "%(z)s = ceil(%(x)s);" % locals()
ceil = Ceil(same_out, name = 'ceil') ceil = Ceil(same_out_nocomplex, name = 'ceil')
class Floor(UnaryScalarOp): class Floor(UnaryScalarOp):
def impl(self, x): def impl(self, x):
...@@ -1029,14 +1053,14 @@ class Floor(UnaryScalarOp): ...@@ -1029,14 +1053,14 @@ class Floor(UnaryScalarOp):
return None, return None,
def c_code(self, node, name, (x,), (z,), sub): def c_code(self, node, name, (x,), (z,), sub):
return "%(z)s = floor(%(x)s);" % locals() return "%(z)s = floor(%(x)s);" % locals()
floor = Floor(same_out, name = 'ceil') floor = Floor(same_out_nocomplex, name = 'ceil')
class IRound(UnaryScalarOp): class IRound(UnaryScalarOp):
def impl(self, x): def impl(self, x):
return numpy.asarray(numpy.round(x), dtype = 'int64') return numpy.asarray(numpy.round(x), dtype = 'int64')
def c_code(self, node, name, (x, ), (z, ), sub): def c_code(self, node, name, (x, ), (z, ), sub):
return "%(z)s = round(%(x)s);" % locals() return "%(z)s = round(%(x)s);" % locals()
iround = IRound(int_out) iround = IRound(int_out_nocomplex)
class Neg(UnaryScalarOp): class Neg(UnaryScalarOp):
def impl(self, x): def impl(self, x):
...@@ -1080,6 +1104,8 @@ class Log(UnaryScalarOp): ...@@ -1080,6 +1104,8 @@ class Log(UnaryScalarOp):
#todo: the version using log2 seems to be very slightly faster #todo: the version using log2 seems to be very slightly faster
# on some machines for some reason, check if it's worth switching # on some machines for some reason, check if it's worth switching
#return "%(z)s = log2(%(x)s) * 0.69314718055994529;" % locals() #return "%(z)s = log2(%(x)s) * 0.69314718055994529;" % locals()
if node.inputs[0].type in complex_types:
raise NotImplementedError('type not supported', type)
return "%(z)s = log(%(x)s);" % locals() return "%(z)s = log(%(x)s);" % locals()
log = Log(upgrade_to_float, name = 'log') log = Log(upgrade_to_float, name = 'log')
...@@ -1096,6 +1122,8 @@ class Log2(UnaryScalarOp): ...@@ -1096,6 +1122,8 @@ class Log2(UnaryScalarOp):
#backport #backport
#return gz / (x * math.log(2.0)) if x.type in grad_types else None, #return gz / (x * math.log(2.0)) if x.type in grad_types else None,
def c_code(self, node, name, (x, ), (z, ), sub): def c_code(self, node, name, (x, ), (z, ), sub):
if node.inputs[0].type in complex_types:
raise NotImplementedError('type not supported', type)
return "%(z)s = log2(%(x)s);" % locals() return "%(z)s = log2(%(x)s);" % locals()
log2 = Log2(upgrade_to_float, name = 'log2') log2 = Log2(upgrade_to_float, name = 'log2')
...@@ -1105,28 +1133,31 @@ class Log10(UnaryScalarOp): ...@@ -1105,28 +1133,31 @@ class Log10(UnaryScalarOp):
return numpy.log10(x) return numpy.log10(x)
def grad(self, (x, ), (gz, )): def grad(self, (x, ), (gz, )):
if x.type in grad_types: if x.type in grad_types:
return gz / (x * math.log(10.0)), return gz / (x * numpy.log(10.0)),
else: else:
return None return None
#backport #backport
#return gz / (x * math.log(10.0)) if x.type in grad_types else None, #return gz / (x * numpy.log(10.0)) if x.type in grad_types else None,
def c_code(self, node, name, (x, ), (z, ), sub): def c_code(self, node, name, (x, ), (z, ), sub):
if node.inputs[0].type in complex_types:
raise NotImplementedError('type not supported', type)
return "%(z)s = log10(%(x)s);" % locals() return "%(z)s = log10(%(x)s);" % locals()
log10 = Log10(upgrade_to_float, name = 'log10') log10 = Log10(upgrade_to_float, name = 'log10')
class Exp(UnaryScalarOp): class Exp(UnaryScalarOp):
def impl(self, x): def impl(self, x):
return math.exp(x) return numpy.exp(x)
def grad(self, (x, ), (gz, )): def grad(self, (x, ), (gz, )):
if x.type in grad_types: if x.type in grad_types:
return gz * exp(x), return gz * exp(x),
else: else:
return None, return None,
#backport #backport
#return gz * exp(x) if x.type in grad_types else None, #return gz * exp(x) if x.type in grad_types else None,
def c_code(self, node, name, (x, ), (z, ), sub): def c_code(self, node, name, (x, ), (z, ), sub):
if node.inputs[0].type in complex_types:
raise NotImplementedError('type not supported', type)
return "%(z)s = exp(%(x)s);" % locals() return "%(z)s = exp(%(x)s);" % locals()
exp = Exp(upgrade_to_float, name = 'exp') exp = Exp(upgrade_to_float, name = 'exp')
...@@ -1147,7 +1178,7 @@ sqr = Sqr(same_out, name = 'sqr') ...@@ -1147,7 +1178,7 @@ sqr = Sqr(same_out, name = 'sqr')
class Sqrt(UnaryScalarOp): class Sqrt(UnaryScalarOp):
def impl(self, x): def impl(self, x):
return math.sqrt(x) return numpy.sqrt(x)
def grad(self, (x, ), (gz, )): def grad(self, (x, ), (gz, )):
if x.type in grad_types: if x.type in grad_types:
return (gz * 0.5) / sqrt(x), return (gz * 0.5) / sqrt(x),
...@@ -1156,12 +1187,14 @@ class Sqrt(UnaryScalarOp): ...@@ -1156,12 +1187,14 @@ class Sqrt(UnaryScalarOp):
#backport #backport
#return (gz * 0.5) / sqrt(x) if x.type in grad_types else None, #return (gz * 0.5) / sqrt(x) if x.type in grad_types else None,
def c_code(self, node, name, (x, ), (z, ), sub): def c_code(self, node, name, (x, ), (z, ), sub):
if node.inputs[0].type in complex_types:
raise NotImplementedError('type not supported', type)
return "%(z)s = sqrt(%(x)s);" % locals() return "%(z)s = sqrt(%(x)s);" % locals()
sqrt = Sqrt(upgrade_to_float, name = 'sqrt') sqrt = Sqrt(upgrade_to_float, name = 'sqrt')
class Cos(UnaryScalarOp): class Cos(UnaryScalarOp):
def impl(self, x): def impl(self, x):
return math.cos(x) return numpy.cos(x)
def grad(self, (x, ), (gz, )): def grad(self, (x, ), (gz, )):
if x.type in grad_types: if x.type in grad_types:
return -gz * sin(x), return -gz * sin(x),
...@@ -1170,12 +1203,14 @@ class Cos(UnaryScalarOp): ...@@ -1170,12 +1203,14 @@ class Cos(UnaryScalarOp):
#backport #backport
# return -gz * sin(x) if x.type in grad_types else None, # return -gz * sin(x) if x.type in grad_types else None,
def c_code(self, node, name, (x, ), (z, ), sub): def c_code(self, node, name, (x, ), (z, ), sub):
if node.inputs[0].type in complex_types:
raise NotImplementedError('type not supported', type)
return "%(z)s = cos(%(x)s);" % locals() return "%(z)s = cos(%(x)s);" % locals()
cos = Cos(upgrade_to_float, name = 'cos') cos = Cos(upgrade_to_float, name = 'cos')
class Sin(UnaryScalarOp): class Sin(UnaryScalarOp):
def impl(self, x): def impl(self, x):
return math.sin(x) return numpy.sin(x)
def grad(self, (x, ), (gz, )): def grad(self, (x, ), (gz, )):
if x.type in grad_types: if x.type in grad_types:
return gz * cos(x), return gz * cos(x),
...@@ -1184,12 +1219,14 @@ class Sin(UnaryScalarOp): ...@@ -1184,12 +1219,14 @@ class Sin(UnaryScalarOp):
#backport #backport
# return gz * cos(x) if x.type in grad_types else None, # return gz * cos(x) if x.type in grad_types else None,
def c_code(self, node, name, (x, ), (z, ), sub): def c_code(self, node, name, (x, ), (z, ), sub):
if node.inputs[0].type in complex_types:
raise NotImplementedError('type not supported', type)
return "%(z)s = sin(%(x)s);" % locals() return "%(z)s = sin(%(x)s);" % locals()
sin = Sin(upgrade_to_float, name = 'sin') sin = Sin(upgrade_to_float, name = 'sin')
class Tan(UnaryScalarOp): class Tan(UnaryScalarOp):
def impl(self, x): def impl(self, x):
return math.tan(x) return numpy.tan(x)
def grad(self, (x, ), (gz, )): def grad(self, (x, ), (gz, )):
if x.type in grad_types: if x.type in grad_types:
return gz / sqr(cos(x)), return gz / sqr(cos(x)),
...@@ -1198,6 +1235,8 @@ class Tan(UnaryScalarOp): ...@@ -1198,6 +1235,8 @@ class Tan(UnaryScalarOp):
#backport #backport
#return gz / sqr(cos(x)) if x.type in grad_types else None, #return gz / sqr(cos(x)) if x.type in grad_types else None,
def c_code(self, node, name, (x, ), (z, ), sub): def c_code(self, node, name, (x, ), (z, ), sub):
if node.inputs[0].type in complex_types:
raise NotImplementedError('type not supported', type)
return "%(z)s = tan(%(x)s);" % locals() return "%(z)s = tan(%(x)s);" % locals()
tan = Tan(upgrade_to_float, name = 'tan') tan = Tan(upgrade_to_float, name = 'tan')
...@@ -1206,7 +1245,7 @@ class Cosh(UnaryScalarOp): ...@@ -1206,7 +1245,7 @@ class Cosh(UnaryScalarOp):
cosh(x) = (exp(x) + exp(-x)) / 2 cosh(x) = (exp(x) + exp(-x)) / 2
""" """
def impl(self, x): def impl(self, x):
return math.cosh(x) return numpy.cosh(x)
def grad(self, (x, ), (gz, )): def grad(self, (x, ), (gz, )):
if x.type in grad_types: if x.type in grad_types:
return gz * sinh(x), return gz * sinh(x),
...@@ -1215,6 +1254,8 @@ class Cosh(UnaryScalarOp): ...@@ -1215,6 +1254,8 @@ class Cosh(UnaryScalarOp):
#backport #backport
#return gz * sinh(x) if x.type in grad_types else None, #return gz * sinh(x) if x.type in grad_types else None,
def c_code(self, node, name, (x, ), (z, ), sub): def c_code(self, node, name, (x, ), (z, ), sub):
if node.inputs[0].type in complex_types:
raise NotImplementedError('type not supported', type)
return "%(z)s = cosh(%(x)s);" % locals() return "%(z)s = cosh(%(x)s);" % locals()
cosh = Cosh(upgrade_to_float, name = 'cosh') cosh = Cosh(upgrade_to_float, name = 'cosh')
...@@ -1223,7 +1264,7 @@ class Sinh(UnaryScalarOp): ...@@ -1223,7 +1264,7 @@ class Sinh(UnaryScalarOp):
sinh(x) = (exp(x) - exp(-x)) / 2 sinh(x) = (exp(x) - exp(-x)) / 2
""" """
def impl(self, x): def impl(self, x):
return math.sinh(x) return numpy.sinh(x)
def grad(self, (x, ), (gz, )): def grad(self, (x, ), (gz, )):
if x.type in grad_types: if x.type in grad_types:
return gz * cosh(x), return gz * cosh(x),
...@@ -1232,6 +1273,8 @@ class Sinh(UnaryScalarOp): ...@@ -1232,6 +1273,8 @@ class Sinh(UnaryScalarOp):
#backport #backport
#return gz * cosh(x) if x.type in grad_types else None, #return gz * cosh(x) if x.type in grad_types else None,
def c_code(self, node, name, (x, ), (z, ), sub): def c_code(self, node, name, (x, ), (z, ), sub):
if node.inputs[0].type in complex_types:
raise NotImplementedError('type not supported', type)
return "%(z)s = sinh(%(x)s);" % locals() return "%(z)s = sinh(%(x)s);" % locals()
sinh = Sinh(upgrade_to_float, name = 'sinh') sinh = Sinh(upgrade_to_float, name = 'sinh')
...@@ -1241,7 +1284,7 @@ class Tanh(UnaryScalarOp): ...@@ -1241,7 +1284,7 @@ class Tanh(UnaryScalarOp):
= (exp(2*x) - 1) / (exp(2*x) + 1) = (exp(2*x) - 1) / (exp(2*x) + 1)
""" """
def impl(self, x): def impl(self, x):
return math.tanh(x) return numpy.tanh(x)
def grad(self, (x, ), (gz, )): def grad(self, (x, ), (gz, )):
if x.type in grad_types: if x.type in grad_types:
return gz * (1 - sqr(tanh(x))), return gz * (1 - sqr(tanh(x))),
...@@ -1250,6 +1293,8 @@ class Tanh(UnaryScalarOp): ...@@ -1250,6 +1293,8 @@ class Tanh(UnaryScalarOp):
#backport #backport
#return gz * (1 - sqr(tanh(x))) if x.type in grad_types else None, #return gz * (1 - sqr(tanh(x))) if x.type in grad_types else None,
def c_code(self, node, name, (x, ), (z, ), sub): def c_code(self, node, name, (x, ), (z, ), sub):
if node.inputs[0].type in complex_types:
raise NotImplementedError('type not supported', type)
return "%(z)s = tanh(%(x)s);" % locals() return "%(z)s = tanh(%(x)s);" % locals()
tanh = Tanh(upgrade_to_float, name = 'tanh') tanh = Tanh(upgrade_to_float, name = 'tanh')
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论