提交 8fbfaa0a authored 作者: Frederic Bastien's avatar Frederic Bastien

white space fix.

上级 ac4cb749
...@@ -74,7 +74,7 @@ class Scalar(Type): ...@@ -74,7 +74,7 @@ class Scalar(Type):
py_type = self.dtype_specs()[0] py_type = self.dtype_specs()[0]
if strict and not isinstance(data, py_type): if strict and not isinstance(data, py_type):
raise TypeError("%s expected a %s, got %s of type %s" % (self, py_type, data, raise TypeError("%s expected a %s, got %s of type %s" % (self, py_type, data,
type(data)), type(data)),
data) data)
try: try:
converted_data = py_type(data) converted_data = py_type(data)
...@@ -160,11 +160,11 @@ class Scalar(Type): ...@@ -160,11 +160,11 @@ class Scalar(Type):
return """ return """
%(name)s = 0; %(name)s = 0;
""" % locals() """ % locals()
def c_extract(self, name, sub): def c_extract(self, name, sub):
specs = self.dtype_specs() specs = self.dtype_specs()
#TODO: This is the wrong code, but we don't know what to change it to. #TODO: This is the wrong code, but we don't know what to change it to.
# For example, a numpy.uint8 is not a PyInt, so PyInt_Check # For example, a numpy.uint8 is not a PyInt, so PyInt_Check
# is simply the wrong function to # is simply the wrong function to
# call. # call.
# Look at PyArrayScalar api for how to cast to/from PyArrayScalar objects. # Look at PyArrayScalar api for how to cast to/from PyArrayScalar objects.
...@@ -183,7 +183,7 @@ class Scalar(Type): ...@@ -183,7 +183,7 @@ class Scalar(Type):
dtype = specs[1], dtype = specs[1],
check = specs[2], check = specs[2],
conv = specs[3]) conv = specs[3])
def c_sync(self, name, sub): def c_sync(self, name, sub):
specs = self.dtype_specs() specs = self.dtype_specs()
return """ return """
...@@ -371,7 +371,7 @@ class _scalar_py_operators: ...@@ -371,7 +371,7 @@ class _scalar_py_operators:
#def __complex__(self): return AsComplex(self).out #def __complex__(self): return AsComplex(self).out
#BITWISE #BITWISE
def __invert__(self): return invert(self) def __invert__(self): return invert(self)
def __and__(self,other): return and_(self, other) def __and__(self,other): return and_(self, other)
def __or__(self,other): return or_(self, other) def __or__(self,other): return or_(self, other)
def __xor__(self,other): return xor(self, other) def __xor__(self,other): return xor(self, other)
...@@ -449,10 +449,10 @@ class transfer_type(gof.utils.object2): ...@@ -449,10 +449,10 @@ class transfer_type(gof.utils.object2):
upcast = upcast_out(*types) upcast = upcast_out(*types)
retval = [] retval = []
for i in self.transfer: for i in self.transfer:
if i is None: if i is None:
retval += [upcast] retval += [upcast]
else: else:
retval += [types[i]] retval += [types[i]]
return retval return retval
#return [upcast if i is None else types[i] for i in self.transfer] #return [upcast if i is None else types[i] for i in self.transfer]
def __eq__(self, other): def __eq__(self, other):
...@@ -577,7 +577,7 @@ class ScalarOp(Op): ...@@ -577,7 +577,7 @@ class ScalarOp(Op):
def impl(self, *inputs): def impl(self, *inputs):
raise utils.MethodNotDefined("impl", type(self), self.__class__.__name__) raise utils.MethodNotDefined("impl", type(self), self.__class__.__name__)
def grad(self, inputs, output_gradients): def grad(self, inputs, output_gradients):
raise utils.MethodNotDefined("grad", type(self), self.__class__.__name__) raise utils.MethodNotDefined("grad", type(self), self.__class__.__name__)
...@@ -694,17 +694,17 @@ class InRange(LogicalComparison): ...@@ -694,17 +694,17 @@ class InRange(LogicalComparison):
return True return True
def c_code(self, node, name, (x, low, hi), (z, ), sub): def c_code(self, node, name, (x, low, hi), (z, ), sub):
if self.openlow: if self.openlow:
cmp1 = '>' cmp1 = '>'
else: else:
cmp1 = '>=' cmp1 = '>='
#backport #backport
#cmp1 = '>' if self.openlow else '>=' #cmp1 = '>' if self.openlow else '>='
if self.openhi: if self.openhi:
cmp2 = '<' cmp2 = '<'
else: else:
cmp2 = '<=' cmp2 = '<='
#backport #backport
#cmp2 = '<' if self.openhi else '<=' #cmp2 = '<' if self.openhi else '<='
...@@ -717,25 +717,25 @@ inclosedrange = InRange(False, False) ...@@ -717,25 +717,25 @@ inclosedrange = InRange(False, False)
class Switch(ScalarOp): class Switch(ScalarOp):
nin = 3 nin = 3
def impl(self, cond, ift, iff): def impl(self, cond, ift, iff):
if cond: if cond:
return ift return ift
else: else:
return iff return iff
#backport #backport
#return ift if cond else iff #return ift if cond else iff
def c_code(self, node, name, (cond, ift, iff), (z, ), sub): def c_code(self, node, name, (cond, ift, iff), (z, ), sub):
return "%(z)s = %(cond)s ? %(ift)s : %(iff)s;" % locals() return "%(z)s = %(cond)s ? %(ift)s : %(iff)s;" % locals()
def grad(self, (cond, ift, iff), (gz, )): def grad(self, (cond, ift, iff), (gz, )):
if ift.type in continuous_types: if ift.type in continuous_types:
first_part = switch(cond, gz, 0) first_part = switch(cond, gz, 0)
else: else:
first_part = None first_part = None
if iff.type in continuous_types: if iff.type in continuous_types:
second_part = switch(cond, 0, gz) second_part = switch(cond, 0, gz)
else: else:
second_part = None second_part = None
return (None, first_part, second_part) return (None, first_part, second_part)
...@@ -809,16 +809,16 @@ class Maximum(BinaryScalarOp): ...@@ -809,16 +809,16 @@ class Maximum(BinaryScalarOp):
return max(inputs) return max(inputs)
def c_code(self, node, name, (x,y), (z, ), sub): def c_code(self, node, name, (x,y), (z, ), sub):
return "%(z)s = ((%(y)s)>(%(x)s)? (%(y)s):(%(x)s));" %locals() return "%(z)s = ((%(y)s)>(%(x)s)? (%(y)s):(%(x)s));" %locals()
def grad(self, (x, y), (gz, )): def grad(self, (x, y), (gz, )):
assert gz.type not in complex_types assert gz.type not in complex_types
# max is not defined for complex_types # max is not defined for complex_types
gx, gy = None, None gx, gy = None, None
if x.type in float_types: if x.type in float_types:
gx = eq(maximum(x,y), x)*gz gx = eq(maximum(x,y), x)*gz
if y.type in float_types: if y.type in float_types:
gy = eq(maximum(x,y), y)*gz gy = eq(maximum(x,y), y)*gz
return (gx,gy) return (gx,gy)
maximum = Maximum(upcast_out, name = 'maximum') maximum = Maximum(upcast_out, name = 'maximum')
...@@ -829,16 +829,16 @@ class Minimum(BinaryScalarOp): ...@@ -829,16 +829,16 @@ class Minimum(BinaryScalarOp):
return min(inputs) return min(inputs)
def c_code(self, node, name, (x,y), (z, ), sub): def c_code(self, node, name, (x,y), (z, ), sub):
return "%(z)s = ((%(y)s)<(%(x)s)? (%(y)s):(%(x)s));" %locals() return "%(z)s = ((%(y)s)<(%(x)s)? (%(y)s):(%(x)s));" %locals()
def grad(self, (x, y), (gz, )): def grad(self, (x, y), (gz, )):
assert gz.type not in complex_types assert gz.type not in complex_types
# max is not defined for complex_types # max is not defined for complex_types
gx, gy = None, None gx, gy = None, None
if x.type in float_types: if x.type in float_types:
gx = eq(minimum(x,y), x)*gz gx = eq(minimum(x,y), x)*gz
if y.type in float_types: if y.type in float_types:
gy = eq(minimum(x,y), y)*gz gy = eq(minimum(x,y), y)*gz
return (gx,gy) return (gx,gy)
minimum = Minimum(upcast_out, name = 'minimum') minimum = Minimum(upcast_out, name = 'minimum')
...@@ -886,24 +886,24 @@ class Mul(ScalarOp): ...@@ -886,24 +886,24 @@ class Mul(ScalarOp):
else: else:
return z + " = " + " * ".join(inputs) + ";" return z + " = " + " * ".join(inputs) + ";"
def grad(self, inputs, (gz, )): def grad(self, inputs, (gz, )):
retval = [] retval = []
for input in inputs: for input in inputs:
if input.type in continuous_types: if input.type in continuous_types:
if gz.type in complex_types: if gz.type in complex_types:
# zr+zi = (xr + xi)(yr + yi) # zr+zi = (xr + xi)(yr + yi)
# zr+zi = (xr*yr - xi*yi) + (xr yi + xi yr ) # zr+zi = (xr*yr - xi*yi) + (xr yi + xi yr )
otherprod = mul(*(utils.difference(inputs, [input]))) otherprod = mul(*(utils.difference(inputs, [input])))
yr = real(otherprod) yr = real(otherprod)
yi = imag(otherprod) yi = imag(otherprod)
if input.type in complex_types: if input.type in complex_types:
retval += [complex(yr*real(gz)+yi*imag(gz), yr*imag(gz)-yi*real(gz))] retval += [complex(yr*real(gz)+yi*imag(gz), yr*imag(gz)-yi*real(gz))]
else:
retval += [cast(yr*real(gz)+yi*imag(gz), input.type.dtype)]
else:
retval += [cast(mul(*([gz] + utils.difference(inputs, [input]))), input.type.dtype)]
else: else:
retval += [cast(yr*real(gz)+yi*imag(gz), input.type.dtype)] retval += [None]
else: return retval
retval += [cast(mul(*([gz] + utils.difference(inputs, [input]))), input.type.dtype)]
else:
retval += [None]
return retval
mul = Mul(upcast_out, name = 'mul') mul = Mul(upcast_out, name = 'mul')
class Sub(BinaryScalarOp): class Sub(BinaryScalarOp):
...@@ -959,14 +959,14 @@ class TrueDiv(BinaryScalarOp): ...@@ -959,14 +959,14 @@ class TrueDiv(BinaryScalarOp):
if x.type in complex_types: if x.type in complex_types:
raise NotImplementedError() raise NotImplementedError()
if x.type in float_types: if x.type in float_types:
first_part = cast(gz / y, x.type.dtype) first_part = cast(gz / y, x.type.dtype)
else: else:
first_part = None first_part = None
if y.type in float_types: if y.type in float_types:
second_part = cast(-(gz * x) / (y * y), y.type.dtype) second_part = cast(-(gz * x) / (y * y), y.type.dtype)
else: else:
second_part = None second_part = None
return first_part, second_part return first_part, second_part
true_div = TrueDiv(upcast_out, name = 'true_div') true_div = TrueDiv(upcast_out, name = 'true_div')
...@@ -1009,7 +1009,7 @@ class Mod(BinaryScalarOp): ...@@ -1009,7 +1009,7 @@ class Mod(BinaryScalarOp):
x_mod_ymp = "fmod(-%(x)s,%(y)s)"%locals() x_mod_ymp = "fmod(-%(x)s,%(y)s)"%locals()
else: else:
raise NotImplementedError('type not supported', type) raise NotImplementedError('type not supported', type)
return """ return """
if (%(x)s < 0){ if (%(x)s < 0){
if (%(y)s < 0){ if (%(y)s < 0){
...@@ -1043,7 +1043,7 @@ class Pow(BinaryScalarOp): ...@@ -1043,7 +1043,7 @@ class Pow(BinaryScalarOp):
first_part = None first_part = None
if y.type in float_types: if y.type in float_types:
second_part = gz * log(x) * x**y second_part = gz * log(x) * x**y
else: else:
second_part = None second_part = None
...@@ -1078,9 +1078,9 @@ class First(BinaryScalarOp): ...@@ -1078,9 +1078,9 @@ class First(BinaryScalarOp):
return "%(z)s = %(x)s;" % locals() return "%(z)s = %(x)s;" % locals()
def grad(self, (x, y), (gz, )): def grad(self, (x, y), (gz, )):
if x.type in continuous_types: if x.type in continuous_types:
return gz, None return gz, None
else: else:
return None,None return None,None
first = First(transfer_type(0), name = 'first') first = First(transfer_type(0), name = 'first')
class Second(BinaryScalarOp): class Second(BinaryScalarOp):
...@@ -1090,9 +1090,9 @@ class Second(BinaryScalarOp): ...@@ -1090,9 +1090,9 @@ class Second(BinaryScalarOp):
return "%(z)s = %(y)s;" % locals() return "%(z)s = %(y)s;" % locals()
def grad(self, (x, y), (gz, )): def grad(self, (x, y), (gz, )):
if y.type in continuous_types: if y.type in continuous_types:
return None, gz return None, gz
else: else:
return None return None
second = Second(transfer_type(1), name = 'second') second = Second(transfer_type(1), name = 'second')
...@@ -1126,9 +1126,9 @@ class Cast(UnaryScalarOp): ...@@ -1126,9 +1126,9 @@ class Cast(UnaryScalarOp):
return "%s = (%s)%s;" % (z, node.outputs[0].type.dtype_specs()[1], x) return "%s = (%s)%s;" % (z, node.outputs[0].type.dtype_specs()[1], x)
def grad(self, (x, ), (gz, )): def grad(self, (x, ), (gz, )):
if x.type in continuous_types: if x.type in continuous_types:
return [cast(gz, x.type.dtype)] return [cast(gz, x.type.dtype)]
else: else:
return None, return None,
def c_code_cache_version(self): def c_code_cache_version(self):
s = super(Cast, self).c_code_cache_version() s = super(Cast, self).c_code_cache_version()
if s: if s:
...@@ -1163,7 +1163,7 @@ _cast_mapping = { ...@@ -1163,7 +1163,7 @@ _cast_mapping = {
'complex64': convert_to_complex64, 'complex64': convert_to_complex64,
'complex128': convert_to_complex128} 'complex128': convert_to_complex128}
def cast(x, dtype): def cast(x, dtype):
"""Symbolically cast `x` to a Scalar of given `dtype`.""" """Symbolically cast `x` to a Scalar of given `dtype`."""
if dtype == 'floatX': dtype = config.floatX if dtype == 'floatX': dtype = config.floatX
_x = as_scalar(x) _x = as_scalar(x)
...@@ -1250,7 +1250,7 @@ class RoundHalfToEven(UnaryScalarOp): ...@@ -1250,7 +1250,7 @@ class RoundHalfToEven(UnaryScalarOp):
See http://en.wikipedia.org/wiki/Rounding for more detail See http://en.wikipedia.org/wiki/Rounding for more detail
""" """
def impl(self, x): def impl(self, x):
return numpy.round(x) return numpy.round(x)
def c_code___(self, node, name, (x, ), (z, ), sub): def c_code___(self, node, name, (x, ), (z, ), sub):
typ = node.outputs[0].type.dtype typ = node.outputs[0].type.dtype
if not node.outputs[0].type.dtype in ['float32', 'float64']: if not node.outputs[0].type.dtype in ['float32', 'float64']:
...@@ -1260,7 +1260,7 @@ class RoundHalfToEven(UnaryScalarOp): ...@@ -1260,7 +1260,7 @@ class RoundHalfToEven(UnaryScalarOp):
#ifndef ROUNDING_EPSILON #ifndef ROUNDING_EPSILON
#define ROUNDING_EPSILON 0.0000001 #define ROUNDING_EPSILON 0.0000001
#endif #endif
if (%(x)s < 0.0){ if (%(x)s < 0.0){
// We implement the else part like that: -else( -%(x)s); // We implement the else part like that: -else( -%(x)s);
%(typ)s i; %(typ)s i;
...@@ -1328,8 +1328,8 @@ class RoundHalfAwayFromZero(UnaryScalarOp): ...@@ -1328,8 +1328,8 @@ class RoundHalfAwayFromZero(UnaryScalarOp):
See http://en.wikipedia.org/wiki/Rounding for more detail See http://en.wikipedia.org/wiki/Rounding for more detail
""" """
def impl(self, x): def impl(self, x):
return round_half_away_from_zero_vec(x) return round_half_away_from_zero_vec(x)
def c_code(self, node, name, (x, ), (z, ), sub): def c_code(self, node, name, (x, ), (z, ), sub):
if node.outputs[0].type.dtype in ['float32', 'float64']: if node.outputs[0].type.dtype in ['float32', 'float64']:
return "%(z)s = round(%(x)s);" % locals() return "%(z)s = round(%(x)s);" % locals()
...@@ -1342,9 +1342,9 @@ class Neg(UnaryScalarOp): ...@@ -1342,9 +1342,9 @@ class Neg(UnaryScalarOp):
return -x return -x
def grad(self, (x, ), (gz, )): def grad(self, (x, ), (gz, )):
if x.type in continuous_types: if x.type in continuous_types:
return -gz, return -gz,
else: else:
return None, return None,
def c_code(self, node, name, (x, ), (z, ), sub): def c_code(self, node, name, (x, ), (z, ), sub):
return "%(z)s = -%(x)s;" % locals() return "%(z)s = -%(x)s;" % locals()
neg = Neg(same_out, name = 'neg') neg = Neg(same_out, name = 'neg')
...@@ -1392,9 +1392,9 @@ class Log2(UnaryScalarOp): ...@@ -1392,9 +1392,9 @@ class Log2(UnaryScalarOp):
if x.type in complex_types: if x.type in complex_types:
raise NotImplementedError() raise NotImplementedError()
if x.type in float_types: if x.type in float_types:
return gz / (x * math.log(2.0)), return gz / (x * math.log(2.0)),
else: else:
return None, return None,
def c_code(self, node, name, (x, ), (z, ), sub): def c_code(self, node, name, (x, ), (z, ), sub):
if node.inputs[0].type in complex_types: if node.inputs[0].type in complex_types:
...@@ -1410,9 +1410,9 @@ class Log10(UnaryScalarOp): ...@@ -1410,9 +1410,9 @@ class Log10(UnaryScalarOp):
if x.type in complex_types: if x.type in complex_types:
raise NotImplementedError() raise NotImplementedError()
if x.type in float_types: if x.type in float_types:
return gz / (x * numpy.log(10.0)), return gz / (x * numpy.log(10.0)),
else: else:
return None return None
def c_code(self, node, name, (x, ), (z, ), sub): def c_code(self, node, name, (x, ), (z, ), sub):
if node.inputs[0].type in complex_types: if node.inputs[0].type in complex_types:
...@@ -1456,12 +1456,12 @@ class Sqr(UnaryScalarOp): ...@@ -1456,12 +1456,12 @@ class Sqr(UnaryScalarOp):
def impl(self, x): def impl(self, x):
return x*x return x*x
def grad(self, (x, ), (gz, )): def grad(self, (x, ), (gz, )):
if gz.type in complex_types: if gz.type in complex_types:
raise NotImplementedError() raise NotImplementedError()
if x.type in float_types: if x.type in float_types:
return gz * x * 2, return gz * x * 2,
else: else:
return None, return None,
def c_code(self, node, name, (x, ), (z, ), sub): def c_code(self, node, name, (x, ), (z, ), sub):
return "%(z)s = %(x)s * %(x)s;" % locals() return "%(z)s = %(x)s * %(x)s;" % locals()
...@@ -1471,12 +1471,12 @@ class Sqrt(UnaryScalarOp): ...@@ -1471,12 +1471,12 @@ class Sqrt(UnaryScalarOp):
def impl(self, x): def impl(self, x):
return numpy.sqrt(x) return numpy.sqrt(x)
def grad(self, (x, ), (gz, )): def grad(self, (x, ), (gz, )):
if gz.type in complex_types: if gz.type in complex_types:
raise NotImplementedError() raise NotImplementedError()
if x.type in float_types: if x.type in float_types:
return (gz * 0.5) / sqrt(x), return (gz * 0.5) / sqrt(x),
else: else:
return None, return None,
def c_code(self, node, name, (x, ), (z, ), sub): def c_code(self, node, name, (x, ), (z, ), sub):
if node.inputs[0].type in complex_types: if node.inputs[0].type in complex_types:
raise NotImplementedError('type not supported', type) raise NotImplementedError('type not supported', type)
...@@ -1487,12 +1487,12 @@ class Cos(UnaryScalarOp): ...@@ -1487,12 +1487,12 @@ class Cos(UnaryScalarOp):
def impl(self, x): def impl(self, x):
return numpy.cos(x) return numpy.cos(x)
def grad(self, (x, ), (gz, )): def grad(self, (x, ), (gz, )):
if gz.type in complex_types: if gz.type in complex_types:
raise NotImplementedError() raise NotImplementedError()
if x.type in float_types: if x.type in float_types:
return -gz * sin(x), return -gz * sin(x),
else: else:
return None, return None,
def c_code(self, node, name, (x, ), (z, ), sub): def c_code(self, node, name, (x, ), (z, ), sub):
if node.inputs[0].type in complex_types: if node.inputs[0].type in complex_types:
raise NotImplementedError('type not supported', type) raise NotImplementedError('type not supported', type)
...@@ -1503,12 +1503,12 @@ class Sin(UnaryScalarOp): ...@@ -1503,12 +1503,12 @@ class Sin(UnaryScalarOp):
def impl(self, x): def impl(self, x):
return numpy.sin(x) return numpy.sin(x)
def grad(self, (x, ), (gz, )): def grad(self, (x, ), (gz, )):
if x.type in complex_types: if x.type in complex_types:
raise NotImplementedError() raise NotImplementedError()
if x.type in float_types: if x.type in float_types:
return gz * cos(x), return gz * cos(x),
else: else:
return None, return None,
def c_code(self, node, name, (x, ), (z, ), sub): def c_code(self, node, name, (x, ), (z, ), sub):
if node.inputs[0].type in complex_types: if node.inputs[0].type in complex_types:
raise NotImplementedError('type not supported', type) raise NotImplementedError('type not supported', type)
...@@ -1519,12 +1519,12 @@ class Tan(UnaryScalarOp): ...@@ -1519,12 +1519,12 @@ class Tan(UnaryScalarOp):
def impl(self, x): def impl(self, x):
return numpy.tan(x) return numpy.tan(x)
def grad(self, (x, ), (gz, )): def grad(self, (x, ), (gz, )):
if x.type in complex_types: if x.type in complex_types:
raise NotImplementedError() raise NotImplementedError()
if x.type in float_types: if x.type in float_types:
return gz / sqr(cos(x)), return gz / sqr(cos(x)),
else: else:
return None, return None,
def c_code(self, node, name, (x, ), (z, ), sub): def c_code(self, node, name, (x, ), (z, ), sub):
if node.inputs[0].type in complex_types: if node.inputs[0].type in complex_types:
raise NotImplementedError('type not supported', type) raise NotImplementedError('type not supported', type)
...@@ -1538,12 +1538,12 @@ class Cosh(UnaryScalarOp): ...@@ -1538,12 +1538,12 @@ class Cosh(UnaryScalarOp):
def impl(self, x): def impl(self, x):
return numpy.cosh(x) return numpy.cosh(x)
def grad(self, (x, ), (gz, )): def grad(self, (x, ), (gz, )):
if x.type in complex_types: if x.type in complex_types:
raise NotImplementedError() raise NotImplementedError()
if x.type in float_types: if x.type in float_types:
return gz * sinh(x), return gz * sinh(x),
else: else:
return None, return None,
def c_code(self, node, name, (x, ), (z, ), sub): def c_code(self, node, name, (x, ), (z, ), sub):
if node.inputs[0].type in complex_types: if node.inputs[0].type in complex_types:
raise NotImplementedError('type not supported', type) raise NotImplementedError('type not supported', type)
...@@ -1557,12 +1557,12 @@ class Sinh(UnaryScalarOp): ...@@ -1557,12 +1557,12 @@ class Sinh(UnaryScalarOp):
def impl(self, x): def impl(self, x):
return numpy.sinh(x) return numpy.sinh(x)
def grad(self, (x, ), (gz, )): def grad(self, (x, ), (gz, )):
if x.type in complex_types: if x.type in complex_types:
raise NotImplementedError() raise NotImplementedError()
if x.type in float_types: if x.type in float_types:
return gz * cosh(x), return gz * cosh(x),
else: else:
return None, return None,
def c_code(self, node, name, (x, ), (z, ), sub): def c_code(self, node, name, (x, ), (z, ), sub):
if node.inputs[0].type in complex_types: if node.inputs[0].type in complex_types:
raise NotImplementedError('type not supported', type) raise NotImplementedError('type not supported', type)
...@@ -1577,12 +1577,12 @@ class Tanh(UnaryScalarOp): ...@@ -1577,12 +1577,12 @@ class Tanh(UnaryScalarOp):
def impl(self, x): def impl(self, x):
return numpy.tanh(x) return numpy.tanh(x)
def grad(self, (x, ), (gz, )): def grad(self, (x, ), (gz, )):
if x.type in complex_types: if x.type in complex_types:
raise NotImplementedError() raise NotImplementedError()
if x.type in float_types: if x.type in float_types:
return gz * (1 - sqr(tanh(x))), return gz * (1 - sqr(tanh(x))),
else: else:
return None, return None,
def c_code(self, node, name, (x, ), (z, ), sub): def c_code(self, node, name, (x, ), (z, ), sub):
if node.inputs[0].type in complex_types: if node.inputs[0].type in complex_types:
raise NotImplementedError('type not supported', type) raise NotImplementedError('type not supported', type)
...@@ -1656,7 +1656,7 @@ class Complex(BinaryScalarOp): ...@@ -1656,7 +1656,7 @@ class Complex(BinaryScalarOp):
def impl(self, x, y): def impl(self, x, y):
return numpy.complex(x, y) return numpy.complex(x, y)
def grad(self, (x,y), (gz,)): def grad(self, (x,y), (gz,)):
return [cast(real(gz), x.type.dtype), return [cast(real(gz), x.type.dtype),
cast(imag(gz), y.type.dtype)] cast(imag(gz), y.type.dtype)]
complex = Complex(name='complex') complex = Complex(name='complex')
...@@ -1683,7 +1683,7 @@ class ComplexFromPolar(BinaryScalarOp): ...@@ -1683,7 +1683,7 @@ class ComplexFromPolar(BinaryScalarOp):
def grad(self, (r,theta), (gz,)): def grad(self, (r,theta), (gz,)):
gr = cos(theta) * real(gz) + sin(theta) * imag(gz) gr = cos(theta) * real(gz) + sin(theta) * imag(gz)
gtheta = -real(gz) * r * sin(theta) + imag(gz) * r * cos(theta) gtheta = -real(gz) * r * sin(theta) + imag(gz) * r * cos(theta)
return [cast(gr, r.type.dtype), return [cast(gr, r.type.dtype),
cast(gtheta, theta.type.dtype)] cast(gtheta, theta.type.dtype)]
complex_from_polar = ComplexFromPolar(name='complex_from_polar') complex_from_polar = ComplexFromPolar(name='complex_from_polar')
...@@ -1707,7 +1707,7 @@ class Composite(ScalarOp): ...@@ -1707,7 +1707,7 @@ class Composite(ScalarOp):
def make_new_inplace(self, output_types_preference = None, name = None): def make_new_inplace(self, output_types_preference = None, name = None):
""" """
This op.__init__ fct don't have the same parameter as other scalar op. This op.__init__ fct don't have the same parameter as other scalar op.
This break the insert_inplace_optimizer optimization. This break the insert_inplace_optimizer optimization.
This fct allow fix patch this. This fct allow fix patch this.
""" """
out = self.__class__(self.inputs,self.outputs) out = self.__class__(self.inputs,self.outputs)
...@@ -1828,12 +1828,12 @@ class Composite(ScalarOp): ...@@ -1828,12 +1828,12 @@ class Composite(ScalarOp):
#The use of a dummy id is safe as the code is in a separate block. #The use of a dummy id is safe as the code is in a separate block.
#It won't generate conflicting variable name. #It won't generate conflicting variable name.
d['id']='_DUMMY_ID_' d['id']='_DUMMY_ID_'
return self._c_code % d return self._c_code % d
def c_code_cache_version(self): def c_code_cache_version(self):
return (1,)+tuple([x.op.c_code_cache_version() for x in self.env.toposort()]) return (1,)+tuple([x.op.c_code_cache_version() for x in self.env.toposort()])
def c_support_code(self): def c_support_code(self):
str = "" str = ""
for node in self.env.toposort(): for node in self.env.toposort():
...@@ -1864,11 +1864,9 @@ class Composite(ScalarOp): ...@@ -1864,11 +1864,9 @@ class Composite(ScalarOp):
d.pop('env') d.pop('env')
d.pop('_impls') d.pop('_impls')
return d return d
def __setstate__(self, d): def __setstate__(self, d):
self.__dict__.update(d) self.__dict__.update(d)
#we must call init to set env and _impls again. #we must call init to set env and _impls again.
#otherwise self.perform won't work. #otherwise self.perform won't work.
self.__init__(self.inputs, self.outputs) self.__init__(self.inputs, self.outputs)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论