提交 00587f67 authored 作者: James Bergstra's avatar James Bergstra

merge

...@@ -328,7 +328,7 @@ uint_types = uint8, uint16, uint32, uint64 ...@@ -328,7 +328,7 @@ uint_types = uint8, uint16, uint32, uint64
float_types = float32, float64 float_types = float32, float64
complex_types = complex64, complex128 complex_types = complex64, complex128
grad_types = float_types + complex_types # these are the types for which gradients can be defined. continuous_types = float_types + complex_types
class _scalar_py_operators: class _scalar_py_operators:
...@@ -698,22 +698,18 @@ class Switch(ScalarOp): ...@@ -698,22 +698,18 @@ class Switch(ScalarOp):
def c_code(self, node, name, (cond, ift, iff), (z, ), sub): def c_code(self, node, name, (cond, ift, iff), (z, ), sub):
return "%(z)s = %(cond)s ? %(ift)s : %(iff)s;" % locals() return "%(z)s = %(cond)s ? %(ift)s : %(iff)s;" % locals()
def grad(self, (cond, ift, iff), (gz, )): def grad(self, (cond, ift, iff), (gz, )):
if ift.type in grad_types: if ift.type in continuous_types:
first_part = switch(cond, gz, 0) first_part = switch(cond, gz, 0)
else: else:
first_part = None first_part = None
if iff.type in grad_types: if iff.type in continuous_types:
second_part = switch(cond, 0, gz) second_part = switch(cond, 0, gz)
else: else:
second_part = None second_part = None
return (None, first_part, second_part) return (None, first_part, second_part)
#return (None,
# switch(cond, gz, 0) if ift.type in grad_types else None,
# switch(cond, 0, gz) if iff.type in grad_types else None)
def output_types(self, (cond_t, ift_t, iff_t)): def output_types(self, (cond_t, ift_t, iff_t)):
return upcast_out(ift_t, iff_t) return upcast_out(ift_t, iff_t)
switch = Switch() switch = Switch()
...@@ -786,10 +782,12 @@ class Maximum(BinaryScalarOp): ...@@ -786,10 +782,12 @@ class Maximum(BinaryScalarOp):
return "%(z)s = ((%(y)s)>(%(x)s)? (%(y)s):(%(x)s));" %locals() return "%(z)s = ((%(y)s)>(%(x)s)? (%(y)s):(%(x)s));" %locals()
def grad(self, (x, y), (gz, )): def grad(self, (x, y), (gz, )):
assert gz.type not in complex_types
# max is not defined for complex_types
gx, gy = None, None gx, gy = None, None
if x.type in grad_types: if x.type in float_types:
gx = eq(maximum(x,y), x)*gz gx = eq(maximum(x,y), x)*gz
if y.type in grad_types: if y.type in float_types:
gy = eq(maximum(x,y), y)*gz gy = eq(maximum(x,y), y)*gz
return (gx,gy) return (gx,gy)
...@@ -804,10 +802,12 @@ class Minimum(BinaryScalarOp): ...@@ -804,10 +802,12 @@ class Minimum(BinaryScalarOp):
return "%(z)s = ((%(y)s)<(%(x)s)? (%(y)s):(%(x)s));" %locals() return "%(z)s = ((%(y)s)<(%(x)s)? (%(y)s):(%(x)s));" %locals()
def grad(self, (x, y), (gz, )): def grad(self, (x, y), (gz, )):
assert gz.type not in complex_types
# max is not defined for complex_types
gx, gy = None, None gx, gy = None, None
if x.type in grad_types: if x.type in float_types:
gx = eq(minimum(x,y), x)*gz gx = eq(minimum(x,y), x)*gz
if y.type in grad_types: if y.type in float_types:
gy = eq(minimum(x,y), y)*gz gy = eq(minimum(x,y), y)*gz
return (gx,gy) return (gx,gy)
...@@ -825,13 +825,24 @@ class Add(ScalarOp): ...@@ -825,13 +825,24 @@ class Add(ScalarOp):
else: else:
return z + " = " + " + ".join(inputs) + ";" return z + " = " + " + ".join(inputs) + ";"
def grad(self, inputs, (gz, )): def grad(self, inputs, (gz, )):
retval = [] retval = []
for i in inputs: if gz.type in complex_types:
if i.type in grad_types: for i in inputs:
retval += [cast(gz, i.type.dtype)] if i.type in complex_types:
retval += [cast(gz, i.type.dtype)]
elif i.type in float_types:
retval += [cast(real(gz), i.type.dtype)]
else:
retval += [None]
elif gz.type in float_types:
for i in inputs:
if i.type in float_types:
retval += [cast(gz, i.type.dtype)]
else:
retval += [None]
else: else:
retval += [None] retval += [None] * len(inputs)
return retval return retval
add = Add(upcast_out, name = 'add') add = Add(upcast_out, name = 'add')
class Mul(ScalarOp): class Mul(ScalarOp):
...@@ -848,23 +859,22 @@ class Mul(ScalarOp): ...@@ -848,23 +859,22 @@ class Mul(ScalarOp):
def grad(self, inputs, (gz, )): def grad(self, inputs, (gz, )):
retval = [] retval = []
for input in inputs: for input in inputs:
if input.type in grad_types: if input.type in continuous_types:
if input.type in complex_types: if gz.type in complex_types:
# does casting from real to complex work? # zr+zi = (xr + xi)(yr + yi)
dz_dinput = cast(mul(*(utils.difference(inputs, [input]))), input.type.dtype) # zr+zi = (xr*yr - xi*yi) + (xr yi + xi yr )
x = real(dz_dinput) otherprod = mul(*(utils.difference(inputs, [input])))
y = imag(dz_dinput) yr = real(otherprod)
retval += [complex(x*real(gz)+y*imag(gz), x*imag(gz)-y*real(gz))] yi = imag(otherprod)
if input.type in complex_types:
retval += [complex(yr*real(gz)+yi*imag(gz), yr*imag(gz)-yi*real(gz))]
else:
retval += [cast(yr*real(gz)+yi*imag(gz), input.type.dtype)]
else: else:
retval += [cast(mul(*([gz] + utils.difference(inputs, [input]))), input.type.dtype)] retval += [cast(mul(*([gz] + utils.difference(inputs, [input]))), input.type.dtype)]
else: else:
retval += [None] retval += [None]
return retval return retval
#return [(mul(*([gz] + utils.difference(inputs, [input])))
# if input.type in grad_types else None)
# for input in inputs]
mul = Mul(upcast_out, name = 'mul') mul = Mul(upcast_out, name = 'mul')
class Sub(BinaryScalarOp): class Sub(BinaryScalarOp):
...@@ -873,12 +883,15 @@ class Sub(BinaryScalarOp): ...@@ -873,12 +883,15 @@ class Sub(BinaryScalarOp):
def c_code(self, node, name, (x, y), (z, ), sub): def c_code(self, node, name, (x, y), (z, ), sub):
return "%(z)s = %(x)s - %(y)s;" % locals() return "%(z)s = %(x)s - %(y)s;" % locals()
def grad(self, (x, y), (gz, )): def grad(self, (x, y), (gz, )):
if x.type in grad_types: if gz.type in complex_types:
raise NotImplementedError()
if x.type in float_types:
first_part = cast(gz, x.type.dtype) first_part = cast(gz, x.type.dtype)
else: else:
first_part = None first_part = None
if y.type in grad_types: if y.type in float_types:
second_part = cast(-gz, y.type.dtype) second_part = cast(-gz, y.type.dtype)
else: else:
second_part = None second_part = None
...@@ -911,12 +924,14 @@ class TrueDiv(BinaryScalarOp): ...@@ -911,12 +924,14 @@ class TrueDiv(BinaryScalarOp):
return "%(z)s = ((double)%(x)s) / %(y)s;" % locals() return "%(z)s = ((double)%(x)s) / %(y)s;" % locals()
return "%(z)s = %(x)s / %(y)s;" % locals() return "%(z)s = %(x)s / %(y)s;" % locals()
def grad(self, (x, y), (gz, )): def grad(self, (x, y), (gz, )):
if x.type in grad_types: if x.type in complex_types:
raise NotImplementedError()
if x.type in float_types:
first_part = cast(gz / y, x.type.dtype) first_part = cast(gz / y, x.type.dtype)
else: else:
first_part = None first_part = None
if y.type in grad_types: if y.type in float_types:
second_part = cast(-(gz * x) / (y * y), y.type.dtype) second_part = cast(-(gz * x) / (y * y), y.type.dtype)
else: else:
second_part = None second_part = None
...@@ -982,12 +997,14 @@ class Pow(BinaryScalarOp): ...@@ -982,12 +997,14 @@ class Pow(BinaryScalarOp):
def c_code(self, node, name, (x, y), (z, ), sub): def c_code(self, node, name, (x, y), (z, ), sub):
return "%(z)s = pow(%(x)s, %(y)s);" % locals() return "%(z)s = pow(%(x)s, %(y)s);" % locals()
def grad(self, (x, y), (gz, )): def grad(self, (x, y), (gz, )):
if x.type in grad_types: if gz.type in complex_types:
raise NotImplementedError()
if x.type in float_types:
first_part = gz * y * x**(y - 1) first_part = gz * y * x**(y - 1)
else: else:
first_part = None first_part = None
if y.type in grad_types: if y.type in float_types:
second_part = gz * log(x) * x**y second_part = gz * log(x) * x**y
else: else:
second_part = None second_part = None
...@@ -1008,8 +1025,9 @@ class Clip(ScalarOp): ...@@ -1008,8 +1025,9 @@ class Clip(ScalarOp):
def c_code(self, node, name, (x, min, max), (z, ), sub): def c_code(self, node, name, (x, min, max), (z, ), sub):
return "%(z)s = %(x)s < %(min)s ? %(min)s : %(x)s > %(max)s ? %(max)s : %(x)s;" % locals() return "%(z)s = %(x)s < %(min)s ? %(min)s : %(x)s > %(max)s ? %(max)s : %(x)s;" % locals()
def grad(self, (x, min, max), (gz, )): def grad(self, (x, min, max), (gz, )):
assert gz.type not in complex_types
gx = ((x > min) & (x < max)) * gz gx = ((x > min) & (x < max)) * gz
if x.type in grad_types: if x.type in float_types:
return gx, None, None return gx, None, None
else: else:
return None, None, None return None, None, None
...@@ -1021,12 +1039,10 @@ class First(BinaryScalarOp): ...@@ -1021,12 +1039,10 @@ class First(BinaryScalarOp):
def c_code(self, node, name, (x, y), (z, ), sub): def c_code(self, node, name, (x, y), (z, ), sub):
return "%(z)s = %(x)s;" % locals() return "%(z)s = %(x)s;" % locals()
def grad(self, (x, y), (gz, )): def grad(self, (x, y), (gz, )):
if x.type in grad_types: if x.type in continuous_types:
return gz, None return gz, None
else: else:
return None,None return None,None
#backport
#return gz if x.type in grad_types else None, None
first = First(transfer_type(0), name = 'first') first = First(transfer_type(0), name = 'first')
class Second(BinaryScalarOp): class Second(BinaryScalarOp):
...@@ -1035,13 +1051,11 @@ class Second(BinaryScalarOp): ...@@ -1035,13 +1051,11 @@ class Second(BinaryScalarOp):
def c_code(self, node, name, (x, y), (z, ), sub): def c_code(self, node, name, (x, y), (z, ), sub):
return "%(z)s = %(y)s;" % locals() return "%(z)s = %(y)s;" % locals()
def grad(self, (x, y), (gz, )): def grad(self, (x, y), (gz, )):
if y.type in grad_types: if y.type in continuous_types:
return None, gz return None, gz
else: else:
return None return None
#backport
#return None, gz if y.type in grad_types else None
second = Second(transfer_type(1), name = 'second') second = Second(transfer_type(1), name = 'second')
...@@ -1052,7 +1066,7 @@ class Identity(UnaryScalarOp): ...@@ -1052,7 +1066,7 @@ class Identity(UnaryScalarOp):
def c_code(self, node, name, (x, ), (z, ), sub): def c_code(self, node, name, (x, ), (z, ), sub):
return "%(z)s = %(x)s;" % locals() return "%(z)s = %(x)s;" % locals()
def grad(self, (x, ), (gz, )): def grad(self, (x, ), (gz, )):
if x.type in grad_types: if x.type in continuous_types:
return gz, return gz,
else: else:
return None, return None,
...@@ -1073,7 +1087,7 @@ class Cast(UnaryScalarOp): ...@@ -1073,7 +1087,7 @@ class Cast(UnaryScalarOp):
def c_code(self, node, name, (x, ), (z, ), sub): def c_code(self, node, name, (x, ), (z, ), sub):
return "%s = (%s)%s;" % (z, node.outputs[0].type.dtype_specs()[1], x) return "%s = (%s)%s;" % (z, node.outputs[0].type.dtype_specs()[1], x)
def grad(self, (x, ), (gz, )): def grad(self, (x, ), (gz, )):
if x.type in grad_types: if x.type in continuous_types:
return [cast(gz, x.type.dtype)] return [cast(gz, x.type.dtype)]
else: else:
return None, return None,
...@@ -1134,7 +1148,7 @@ class Abs(UnaryScalarOp): ...@@ -1134,7 +1148,7 @@ class Abs(UnaryScalarOp):
def impl(self, x): def impl(self, x):
return numpy.abs(x) return numpy.abs(x)
def grad(self, (x, ), (gz, )): def grad(self, (x, ), (gz, )):
if x.type in grad_types: if x.type in float_types + complex_types:
return gz * x / abs(x), # formula works for complex and real return gz * x / abs(x), # formula works for complex and real
else: else:
return None, return None,
...@@ -1279,7 +1293,7 @@ class Neg(UnaryScalarOp): ...@@ -1279,7 +1293,7 @@ class Neg(UnaryScalarOp):
def impl(self, x): def impl(self, x):
return -x return -x
def grad(self, (x, ), (gz, )): def grad(self, (x, ), (gz, )):
if x.type in grad_types: if x.type in continuous_types:
return -gz, return -gz,
else: else:
return None, return None,
...@@ -1291,13 +1305,12 @@ class Inv(UnaryScalarOp): ...@@ -1291,13 +1305,12 @@ class Inv(UnaryScalarOp):
def impl(self, x): def impl(self, x):
return 1.0 / x return 1.0 / x
def grad(self, (x, ), (gz, )): def grad(self, (x, ), (gz, )):
if x.type in grad_types: if x.type in complex_types:
return -gz / (x * x), raise NotImplementedError()
else: if x.type in float_types:
return None, return -gz / (x * x),
else:
#backport return None,
#return -gz / (x * x) if x.type in grad_types else None,
def c_code(self, node, name, (x, ), (z, ), sub): def c_code(self, node, name, (x, ), (z, ), sub):
return "%(z)s = 1.0 / %(x)s;" % locals() return "%(z)s = 1.0 / %(x)s;" % locals()
inv = Inv(upgrade_to_float, name = 'inv') inv = Inv(upgrade_to_float, name = 'inv')
...@@ -1307,12 +1320,12 @@ class Log(UnaryScalarOp): ...@@ -1307,12 +1320,12 @@ class Log(UnaryScalarOp):
def impl(self, x): def impl(self, x):
return numpy.log(x) return numpy.log(x)
def grad(self, (x, ), (gz, )): def grad(self, (x, ), (gz, )):
if x.type in grad_types: if x.type in complex_types:
return gz / x, raise NotImplementedError()
else: if x.type in float_types:
return None, return gz / x,
#backport else:
#return gz / x if x.type in grad_types else None, return None,
def c_code(self, node, name, (x, ), (z, ), sub): def c_code(self, node, name, (x, ), (z, ), sub):
#todo: the version using log2 seems to be very slightly faster #todo: the version using log2 seems to be very slightly faster
# on some machines for some reason, check if it's worth switching # on some machines for some reason, check if it's worth switching
...@@ -1327,13 +1340,13 @@ class Log2(UnaryScalarOp): ...@@ -1327,13 +1340,13 @@ class Log2(UnaryScalarOp):
def impl(self, x): def impl(self, x):
return numpy.log2(x) return numpy.log2(x)
def grad(self, (x, ), (gz, )): def grad(self, (x, ), (gz, )):
if x.type in grad_types: if x.type in complex_types:
raise NotImplementedError()
if x.type in float_types:
return gz / (x * math.log(2.0)), return gz / (x * math.log(2.0)),
else: else:
return None, return None,
#backport
#return gz / (x * math.log(2.0)) if x.type in grad_types else None,
def c_code(self, node, name, (x, ), (z, ), sub): def c_code(self, node, name, (x, ), (z, ), sub):
if node.inputs[0].type in complex_types: if node.inputs[0].type in complex_types:
raise NotImplementedError('type not supported', type) raise NotImplementedError('type not supported', type)
...@@ -1345,13 +1358,13 @@ class Log10(UnaryScalarOp): ...@@ -1345,13 +1358,13 @@ class Log10(UnaryScalarOp):
def impl(self, x): def impl(self, x):
return numpy.log10(x) return numpy.log10(x)
def grad(self, (x, ), (gz, )): def grad(self, (x, ), (gz, )):
if x.type in grad_types: if x.type in complex_types:
raise NotImplementedError()
if x.type in float_types:
return gz / (x * numpy.log(10.0)), return gz / (x * numpy.log(10.0)),
else: else:
return None return None
#backport
#return gz / (x * numpy.log(10.0)) if x.type in grad_types else None,
def c_code(self, node, name, (x, ), (z, ), sub): def c_code(self, node, name, (x, ), (z, ), sub):
if node.inputs[0].type in complex_types: if node.inputs[0].type in complex_types:
raise NotImplementedError('type not supported', type) raise NotImplementedError('type not supported', type)
...@@ -1363,7 +1376,11 @@ class Log1p(UnaryScalarOp): ...@@ -1363,7 +1376,11 @@ class Log1p(UnaryScalarOp):
def impl(self, x): def impl(self, x):
return numpy.log1p(x) return numpy.log1p(x)
def grad(self, (x,), (gz,)): def grad(self, (x,), (gz,)):
return [gz / (1+x)] if gz.type in complex_types:
raise NotImplementedError()
if gz.type in float_types:
return [gz / (1+x)]
return [None]
def c_code(self, node, name, (x, ), (z, ), sub): def c_code(self, node, name, (x, ), (z, ), sub):
if node.inputs[0].type in complex_types: if node.inputs[0].type in complex_types:
raise NotImplementedError('type not supported', type) raise NotImplementedError('type not supported', type)
...@@ -1374,12 +1391,12 @@ class Exp(UnaryScalarOp): ...@@ -1374,12 +1391,12 @@ class Exp(UnaryScalarOp):
def impl(self, x): def impl(self, x):
return numpy.exp(x) return numpy.exp(x)
def grad(self, (x, ), (gz, )): def grad(self, (x, ), (gz, )):
if x.type in grad_types: if x.type in complex_types:
return gz * exp(x), raise NotImplementedError()
else: elif x.type in float_types:
return None, return gz * exp(x),
#backport else:
#return gz * exp(x) if x.type in grad_types else None, return None,
def c_code(self, node, name, (x, ), (z, ), sub): def c_code(self, node, name, (x, ), (z, ), sub):
if node.inputs[0].type in complex_types: if node.inputs[0].type in complex_types:
raise NotImplementedError('type not supported', type) raise NotImplementedError('type not supported', type)
...@@ -1390,13 +1407,13 @@ class Sqr(UnaryScalarOp): ...@@ -1390,13 +1407,13 @@ class Sqr(UnaryScalarOp):
def impl(self, x): def impl(self, x):
return x*x return x*x
def grad(self, (x, ), (gz, )): def grad(self, (x, ), (gz, )):
if x.type in grad_types: if gz.type in complex_types:
raise NotImplementedError()
if x.type in float_types:
return gz * x * 2, return gz * x * 2,
else: else:
return None, return None,
#backport
# return gz * x * 2 if x.type in grad_types else None,
def c_code(self, node, name, (x, ), (z, ), sub): def c_code(self, node, name, (x, ), (z, ), sub):
return "%(z)s = %(x)s * %(x)s;" % locals() return "%(z)s = %(x)s * %(x)s;" % locals()
sqr = Sqr(same_out, name = 'sqr') sqr = Sqr(same_out, name = 'sqr')
...@@ -1405,12 +1422,12 @@ class Sqrt(UnaryScalarOp): ...@@ -1405,12 +1422,12 @@ class Sqrt(UnaryScalarOp):
def impl(self, x): def impl(self, x):
return numpy.sqrt(x) return numpy.sqrt(x)
def grad(self, (x, ), (gz, )): def grad(self, (x, ), (gz, )):
if x.type in grad_types: if gz.type in complex_types:
raise NotImplementedError()
if x.type in float_types:
return (gz * 0.5) / sqrt(x), return (gz * 0.5) / sqrt(x),
else: else:
return None, return None,
#backport
#return (gz * 0.5) / sqrt(x) if x.type in grad_types else None,
def c_code(self, node, name, (x, ), (z, ), sub): def c_code(self, node, name, (x, ), (z, ), sub):
if node.inputs[0].type in complex_types: if node.inputs[0].type in complex_types:
raise NotImplementedError('type not supported', type) raise NotImplementedError('type not supported', type)
...@@ -1421,12 +1438,12 @@ class Cos(UnaryScalarOp): ...@@ -1421,12 +1438,12 @@ class Cos(UnaryScalarOp):
def impl(self, x): def impl(self, x):
return numpy.cos(x) return numpy.cos(x)
def grad(self, (x, ), (gz, )): def grad(self, (x, ), (gz, )):
if x.type in grad_types: if gz.type in complex_types:
raise NotImplementedError()
if x.type in float_types:
return -gz * sin(x), return -gz * sin(x),
else: else:
return None, return None,
#backport
# return -gz * sin(x) if x.type in grad_types else None,
def c_code(self, node, name, (x, ), (z, ), sub): def c_code(self, node, name, (x, ), (z, ), sub):
if node.inputs[0].type in complex_types: if node.inputs[0].type in complex_types:
raise NotImplementedError('type not supported', type) raise NotImplementedError('type not supported', type)
...@@ -1437,12 +1454,12 @@ class Sin(UnaryScalarOp): ...@@ -1437,12 +1454,12 @@ class Sin(UnaryScalarOp):
def impl(self, x): def impl(self, x):
return numpy.sin(x) return numpy.sin(x)
def grad(self, (x, ), (gz, )): def grad(self, (x, ), (gz, )):
if x.type in grad_types: if x.type in complex_types:
raise NotImplementedError()
if x.type in float_types:
return gz * cos(x), return gz * cos(x),
else: else:
return None, return None,
#backport
# return gz * cos(x) if x.type in grad_types else None,
def c_code(self, node, name, (x, ), (z, ), sub): def c_code(self, node, name, (x, ), (z, ), sub):
if node.inputs[0].type in complex_types: if node.inputs[0].type in complex_types:
raise NotImplementedError('type not supported', type) raise NotImplementedError('type not supported', type)
...@@ -1453,12 +1470,12 @@ class Tan(UnaryScalarOp): ...@@ -1453,12 +1470,12 @@ class Tan(UnaryScalarOp):
def impl(self, x): def impl(self, x):
return numpy.tan(x) return numpy.tan(x)
def grad(self, (x, ), (gz, )): def grad(self, (x, ), (gz, )):
if x.type in grad_types: if x.type in complex_types:
raise NotImplementedError()
if x.type in float_types:
return gz / sqr(cos(x)), return gz / sqr(cos(x)),
else: else:
return None, return None,
#backport
#return gz / sqr(cos(x)) if x.type in grad_types else None,
def c_code(self, node, name, (x, ), (z, ), sub): def c_code(self, node, name, (x, ), (z, ), sub):
if node.inputs[0].type in complex_types: if node.inputs[0].type in complex_types:
raise NotImplementedError('type not supported', type) raise NotImplementedError('type not supported', type)
...@@ -1472,12 +1489,12 @@ class Cosh(UnaryScalarOp): ...@@ -1472,12 +1489,12 @@ class Cosh(UnaryScalarOp):
def impl(self, x): def impl(self, x):
return numpy.cosh(x) return numpy.cosh(x)
def grad(self, (x, ), (gz, )): def grad(self, (x, ), (gz, )):
if x.type in grad_types: if x.type in complex_types:
raise NotImplementedError()
if x.type in float_types:
return gz * sinh(x), return gz * sinh(x),
else: else:
return None, return None,
#backport
#return gz * sinh(x) if x.type in grad_types else None,
def c_code(self, node, name, (x, ), (z, ), sub): def c_code(self, node, name, (x, ), (z, ), sub):
if node.inputs[0].type in complex_types: if node.inputs[0].type in complex_types:
raise NotImplementedError('type not supported', type) raise NotImplementedError('type not supported', type)
...@@ -1491,12 +1508,12 @@ class Sinh(UnaryScalarOp): ...@@ -1491,12 +1508,12 @@ class Sinh(UnaryScalarOp):
def impl(self, x): def impl(self, x):
return numpy.sinh(x) return numpy.sinh(x)
def grad(self, (x, ), (gz, )): def grad(self, (x, ), (gz, )):
if x.type in grad_types: if x.type in complex_types:
raise NotImplementedError()
if x.type in float_types:
return gz * cosh(x), return gz * cosh(x),
else: else:
return None, return None,
#backport
#return gz * cosh(x) if x.type in grad_types else None,
def c_code(self, node, name, (x, ), (z, ), sub): def c_code(self, node, name, (x, ), (z, ), sub):
if node.inputs[0].type in complex_types: if node.inputs[0].type in complex_types:
raise NotImplementedError('type not supported', type) raise NotImplementedError('type not supported', type)
...@@ -1511,12 +1528,12 @@ class Tanh(UnaryScalarOp): ...@@ -1511,12 +1528,12 @@ class Tanh(UnaryScalarOp):
def impl(self, x): def impl(self, x):
return numpy.tanh(x) return numpy.tanh(x)
def grad(self, (x, ), (gz, )): def grad(self, (x, ), (gz, )):
if x.type in grad_types: if x.type in complex_types:
raise NotImplementedError()
if x.type in float_types:
return gz * (1 - sqr(tanh(x))), return gz * (1 - sqr(tanh(x))),
else: else:
return None, return None,
#backport
#return gz * (1 - sqr(tanh(x))) if x.type in grad_types else None,
def c_code(self, node, name, (x, ), (z, ), sub): def c_code(self, node, name, (x, ), (z, ), sub):
if node.inputs[0].type in complex_types: if node.inputs[0].type in complex_types:
raise NotImplementedError('type not supported', type) raise NotImplementedError('type not supported', type)
...@@ -1536,7 +1553,12 @@ class Imag(UnaryScalarOp): ...@@ -1536,7 +1553,12 @@ class Imag(UnaryScalarOp):
def impl(self, x): def impl(self, x):
return numpy.imag(x) return numpy.imag(x)
def grad(self, (x, ), (gz, )): def grad(self, (x, ), (gz, )):
return [complex(0, gz)] if x.type in complex_types:
return [complex(0, gz)]
elif x.type in float_types:
return [second(x,0)]
else:
return [None]
imag = Imag(real_out, name='imag') imag = Imag(real_out, name='imag')
class Angle(UnaryScalarOp): class Angle(UnaryScalarOp):
...@@ -1558,10 +1580,15 @@ class Angle(UnaryScalarOp): ...@@ -1558,10 +1580,15 @@ class Angle(UnaryScalarOp):
r = abs(c) r = abs(c)
gr = -gtheta * y / (r**2 * sqrt(1 - (y/r)**2)) gr = -gtheta * y / (r**2 * sqrt(1 - (y/r)**2))
gx = gr * x/r gx = gr * x/r
gy = gr * y/r gy = gr * y/r
return [complex(gx, gy)] if c in complex_types:
return [cast(complex(gx, gy), x.type.dtype)]
elif c in float_types:
return [cast(second(x,0), x.type.dtype)]
else:
return [None]
angle = Angle(specific_out(float64), name='angle') angle = Angle(specific_out(float64), name='angle')
class Complex(BinaryScalarOp): class Complex(BinaryScalarOp):
...@@ -1579,8 +1606,9 @@ class Complex(BinaryScalarOp): ...@@ -1579,8 +1606,9 @@ class Complex(BinaryScalarOp):
return [complex64] return [complex64]
def impl(self, x, y): def impl(self, x, y):
return numpy.complex(x, y) return numpy.complex(x, y)
def grad(self, (x,y), (z,)): def grad(self, (x,y), (gz,)):
return [real(z), imag(z)] return [cast(real(gz), x.type.dtype),
cast(imag(gz), y.type.dtype)]
complex = Complex(name='complex') complex = Complex(name='complex')
class Conj(UnaryScalarOp): class Conj(UnaryScalarOp):
...@@ -1606,7 +1634,8 @@ class ComplexFromPolar(BinaryScalarOp): ...@@ -1606,7 +1634,8 @@ class ComplexFromPolar(BinaryScalarOp):
def grad(self, (r,theta), (gz,)): def grad(self, (r,theta), (gz,)):
gr = cos(theta) * real(gz) + sin(theta) * imag(gz) gr = cos(theta) * real(gz) + sin(theta) * imag(gz)
gtheta = -real(gz) * r * sin(theta) + imag(gz) * r * cos(theta) gtheta = -real(gz) * r * sin(theta) + imag(gz) * r * cos(theta)
return [gr, gtheta] return [cast(gr, r.type.dtype),
cast(gtheta, theta.type.dtype)]
complex_from_polar = ComplexFromPolar(name='complex_from_polar') complex_from_polar = ComplexFromPolar(name='complex_from_polar')
......
...@@ -3,6 +3,8 @@ import theano ...@@ -3,6 +3,8 @@ import theano
from theano.tensor import * from theano.tensor import *
from theano.tests import unittest_tools as utt from theano.tests import unittest_tools as utt
from numpy.testing import dec
class TestRealImag(unittest.TestCase): class TestRealImag(unittest.TestCase):
def test0(self): def test0(self):
...@@ -53,6 +55,54 @@ class TestRealImag(unittest.TestCase): ...@@ -53,6 +55,54 @@ class TestRealImag(unittest.TestCase):
mval = numpy.asarray(rng.randn(2,5)) mval = numpy.asarray(rng.randn(2,5))
utt.verify_grad(f, [mval]) utt.verify_grad(f, [mval])
@dec.knownfailureif(True,"Complex grads not enabled")
def test_mul_mixed0(self):
def f(a):
ac = complex(a[0], a[1])
return abs((ac)**2).sum()
rng = numpy.random.RandomState(9333)
aval = numpy.asarray(rng.randn(2,5))
try:
utt.verify_grad(f, [aval])
except utt.verify_grad.E_grad, e:
print e.num_grad.gf
print e.analytic_grad
raise
@dec.knownfailureif(True,"Complex grads not enabled")
def test_mul_mixed1(self):
def f(a):
ac = complex(a[0], a[1])
return abs(ac).sum()
rng = numpy.random.RandomState(9333)
aval = numpy.asarray(rng.randn(2,5))
try:
utt.verify_grad(f, [aval])
except utt.verify_grad.E_grad, e:
print e.num_grad.gf
print e.analytic_grad
raise
@dec.knownfailureif(True,"Complex grads not enabled")
def test_mul_mixed(self):
def f(a,b):
ac = complex(a[0], a[1])
return abs((ac*b)**2).sum()
rng = numpy.random.RandomState(9333)
aval = numpy.asarray(rng.randn(2,5))
bval = rng.randn(5)
try:
utt.verify_grad(f, [aval, bval])
except utt.verify_grad.E_grad, e:
print e.num_grad.gf
print e.analytic_grad
raise
def test_polar_grads(self): def test_polar_grads(self):
def f(m): def f(m):
c = complex_from_polar(abs(m[0]), m[1]) c = complex_from_polar(abs(m[0]), m[1])
...@@ -62,7 +112,6 @@ class TestRealImag(unittest.TestCase): ...@@ -62,7 +112,6 @@ class TestRealImag(unittest.TestCase):
mval = numpy.asarray(rng.randn(2,5)) mval = numpy.asarray(rng.randn(2,5))
utt.verify_grad(f, [mval]) utt.verify_grad(f, [mval])
def test_abs_grad(self): def test_abs_grad(self):
def f(m): def f(m):
c = complex(m[0], m[1]) c = complex(m[0], m[1])
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论