提交 002872ad，作者：Pascal Lamblin

Remove None from return values of grad.

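Also change the checks to test the dtype of the Op's own output, rather than the dtype of the inputs or of the incoming gradient, because the output dtype can depend on different things and need not match either of them (for example, the output can be float even if the inputs are ints). The idea is that if the Op's output is continuous, then the gradient should be propagated to the inputs, regardless of whether they are continuous or discrete. However, if the output is discrete, then the gradient with respect to each input is a zero of a continuous dtype: floatX for a discrete input, and the input's own dtype otherwise.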
上级 37ce26a3
...@@ -1568,7 +1568,8 @@ class IntDiv(BinaryScalarOp): ...@@ -1568,7 +1568,8 @@ class IntDiv(BinaryScalarOp):
return (2,) return (2,)
def grad(self, inputs, g_output): def grad(self, inputs, g_output):
return [None] * len(inputs) return [inp.zeros_like(dtype=theano.config.floatX)
for inp in inputs]
int_div = IntDiv(upcast_out, name='int_div') int_div = IntDiv(upcast_out, name='int_div')
...@@ -1654,7 +1655,8 @@ class Mod(BinaryScalarOp): ...@@ -1654,7 +1655,8 @@ class Mod(BinaryScalarOp):
""") % locals() """) % locals()
def grad(self, (x, y), (gz, )): def grad(self, (x, y), (gz, )):
return None, None return [x.zeros_like(dtype=theano.config.floatX),
y.zeros_like(dtype=theano.config.floatX)]
mod = Mod(upcast_out, name='mod') mod = Mod(upcast_out, name='mod')
...@@ -1892,10 +1894,13 @@ class Abs(UnaryScalarOp): ...@@ -1892,10 +1894,13 @@ class Abs(UnaryScalarOp):
return numpy.abs(x) return numpy.abs(x)
def grad(self, (x, ), (gz, )): def grad(self, (x, ), (gz, )):
if x.type in float_types + complex_types: if self(x).type in discrete_types:
return gz * x / abs(x), # formula works for complex and real if x.type in discrete_types:
return [x.zeros_like(dtype=theano.config.floatX)]
else: else:
return None, return [x.zeros_like()]
return gz * x / abs(x), # formula works for complex and real
def c_code(self, node, name, (x, ), (z, ), sub): def c_code(self, node, name, (x, ), (z, ), sub):
type = node.inputs[0].type type = node.inputs[0].type
...@@ -2096,10 +2101,13 @@ class Neg(UnaryScalarOp): ...@@ -2096,10 +2101,13 @@ class Neg(UnaryScalarOp):
return -x return -x
def grad(self, (x,), (gz,)): def grad(self, (x,), (gz,)):
if x.type in continuous_types: if self(x).type in discrete_types:
return -gz, if x.type in discrete_types:
return [x.zeros_like(dtype=theano.config.floatX)]
else: else:
return None, return [x.zeros_like()]
return -gz,
def c_code(self, node, name, (x,), (z,), sub): def c_code(self, node, name, (x,), (z,), sub):
return "%(z)s = -%(x)s;" % locals() return "%(z)s = -%(x)s;" % locals()
...@@ -2114,10 +2122,13 @@ class Inv(UnaryScalarOp): ...@@ -2114,10 +2122,13 @@ class Inv(UnaryScalarOp):
def grad(self, (x,), (gz,)): def grad(self, (x,), (gz,)):
if x.type in complex_types: if x.type in complex_types:
raise NotImplementedError() raise NotImplementedError()
if x.type in float_types: if self(x).type in discrete_types:
return -gz / (x * x), if x.type in discrete_types:
return [x.zeros_like(dtype=theano.config.floatX)]
else: else:
return None, return [x.zeros_like()]
return -gz / (x * x),
def c_code(self, node, name, (x,), (z,), sub): def c_code(self, node, name, (x,), (z,), sub):
return "%(z)s = 1.0 / %(x)s;" % locals() return "%(z)s = 1.0 / %(x)s;" % locals()
...@@ -2135,10 +2146,13 @@ class Log(UnaryScalarOp): ...@@ -2135,10 +2146,13 @@ class Log(UnaryScalarOp):
def grad(self, (x,), (gz,)): def grad(self, (x,), (gz,)):
if x.type in complex_types: if x.type in complex_types:
raise NotImplementedError() raise NotImplementedError()
if x.type in float_types: if self(x).type in discrete_types:
return gz / x, if x.type in discrete_types:
return [x.zeros_like(dtype=theano.config.floatX)]
else: else:
return None, return [x.zeros_like()]
return gz / x,
def c_code(self, node, name, (x,), (z,), sub): def c_code(self, node, name, (x,), (z,), sub):
#todo: the version using log2 seems to be very slightly faster #todo: the version using log2 seems to be very slightly faster
...@@ -2161,10 +2175,13 @@ class Log2(UnaryScalarOp): ...@@ -2161,10 +2175,13 @@ class Log2(UnaryScalarOp):
def grad(self, (x,), (gz,)): def grad(self, (x,), (gz,)):
if x.type in complex_types: if x.type in complex_types:
raise NotImplementedError() raise NotImplementedError()
if x.type in float_types: if self(x).type in discrete_types:
return gz / (x * math.log(2.0)), if x.type in discrete_types:
return [x.zeros_like(dtype=theano.config.floatX)]
else: else:
return None, return [x.zeros_like()]
return gz / (x * math.log(2.0)),
def c_code(self, node, name, (x, ), (z, ), sub): def c_code(self, node, name, (x, ), (z, ), sub):
if node.inputs[0].type in complex_types: if node.inputs[0].type in complex_types:
...@@ -2184,10 +2201,13 @@ class Log10(UnaryScalarOp): ...@@ -2184,10 +2201,13 @@ class Log10(UnaryScalarOp):
def grad(self, (x,), (gz,)): def grad(self, (x,), (gz,)):
if x.type in complex_types: if x.type in complex_types:
raise NotImplementedError() raise NotImplementedError()
if x.type in float_types: if self(x).type in discrete_types:
return gz / (x * numpy.log(10.0)), if x.type in discrete_types:
return [x.zeros_like(dtype=theano.config.floatX)]
else: else:
return None return [x.zeros_like()]
return gz / (x * numpy.log(10.0)),
def c_code(self, node, name, (x, ), (z, ), sub): def c_code(self, node, name, (x, ), (z, ), sub):
if node.inputs[0].type in complex_types: if node.inputs[0].type in complex_types:
...@@ -2204,9 +2224,13 @@ class Log1p(UnaryScalarOp): ...@@ -2204,9 +2224,13 @@ class Log1p(UnaryScalarOp):
def grad(self, (x,), (gz,)): def grad(self, (x,), (gz,)):
if gz.type in complex_types: if gz.type in complex_types:
raise NotImplementedError() raise NotImplementedError()
if gz.type in float_types: if self(x).type in discrete_types:
if x.type in discrete_types:
return [x.zeros_like(dtype=theano.config.floatX)]
else:
return [x.zeros_like()]
return [gz / (1 + x)] return [gz / (1 + x)]
return [None]
def c_code(self, node, name, (x, ), (z, ), sub): def c_code(self, node, name, (x, ), (z, ), sub):
if node.inputs[0].type in complex_types: if node.inputs[0].type in complex_types:
...@@ -2225,10 +2249,13 @@ class Exp(UnaryScalarOp): ...@@ -2225,10 +2249,13 @@ class Exp(UnaryScalarOp):
def grad(self, (x, ), (gz, )): def grad(self, (x, ), (gz, )):
if x.type in complex_types: if x.type in complex_types:
raise NotImplementedError() raise NotImplementedError()
elif x.type in float_types: if self(x).type in discrete_types:
return gz * exp(x), if x.type in discrete_types:
return [x.zeros_like(dtype=theano.config.floatX)]
else: else:
return None, return [x.zeros_like()]
return gz * exp(x),
def c_code(self, node, name, (x, ), (z, ), sub): def c_code(self, node, name, (x, ), (z, ), sub):
if node.inputs[0].type in complex_types: if node.inputs[0].type in complex_types:
...@@ -2244,10 +2271,13 @@ class Exp2(UnaryScalarOp): ...@@ -2244,10 +2271,13 @@ class Exp2(UnaryScalarOp):
def grad(self, (x, ), (gz, )): def grad(self, (x, ), (gz, )):
if x.type in complex_types: if x.type in complex_types:
raise NotImplementedError() raise NotImplementedError()
elif x.type in float_types: if self(x).type in discrete_types:
return gz * exp2(x) * log(numpy.cast[x.type](2)), if x.type in discrete_types:
return [x.zeros_like(dtype=theano.config.floatX)]
else: else:
return None, return [x.zeros_like()]
return gz * exp2(x) * log(numpy.cast[x.type](2)),
def c_code(self, node, name, (x, ), (z, ), sub): def c_code(self, node, name, (x, ), (z, ), sub):
if node.inputs[0].type in complex_types: if node.inputs[0].type in complex_types:
...@@ -2263,10 +2293,13 @@ class Expm1(UnaryScalarOp): ...@@ -2263,10 +2293,13 @@ class Expm1(UnaryScalarOp):
def grad(self, (x, ), (gz, )): def grad(self, (x, ), (gz, )):
if x.type in complex_types: if x.type in complex_types:
raise NotImplementedError() raise NotImplementedError()
elif x.type in float_types: if self(x).type in discrete_types:
return gz * exp(x), if x.type in discrete_types:
return [x.zeros_like(dtype=theano.config.floatX)]
else: else:
return None, return [x.zeros_like()]
return gz * exp(x),
def c_code(self, node, name, (x, ), (z, ), sub): def c_code(self, node, name, (x, ), (z, ), sub):
if node.inputs[0].type in complex_types: if node.inputs[0].type in complex_types:
...@@ -2282,10 +2315,13 @@ class Sqr(UnaryScalarOp): ...@@ -2282,10 +2315,13 @@ class Sqr(UnaryScalarOp):
def grad(self, (x, ), (gz, )): def grad(self, (x, ), (gz, )):
if gz.type in complex_types: if gz.type in complex_types:
raise NotImplementedError() raise NotImplementedError()
if x.type in float_types: if self(x).type in discrete_types:
return gz * x * 2, if x.type in discrete_types:
return [x.zeros_like(dtype=theano.config.floatX)]
else: else:
return None, return [x.zeros_like()]
return gz * x * 2,
def c_code(self, node, name, (x, ), (z, ), sub): def c_code(self, node, name, (x, ), (z, ), sub):
return "%(z)s = %(x)s * %(x)s;" % locals() return "%(z)s = %(x)s * %(x)s;" % locals()
...@@ -2299,10 +2335,13 @@ class Sqrt(UnaryScalarOp): ...@@ -2299,10 +2335,13 @@ class Sqrt(UnaryScalarOp):
def grad(self, (x,), (gz,)): def grad(self, (x,), (gz,)):
if gz.type in complex_types: if gz.type in complex_types:
raise NotImplementedError() raise NotImplementedError()
if x.type in float_types: if self(x).type in discrete_types:
return (gz * 0.5) / sqrt(x), if x.type in discrete_types:
return [x.zeros_like(dtype=theano.config.floatX)]
else: else:
return None, return [x.zeros_like()]
return (gz * 0.5) / sqrt(x),
def c_code(self, node, name, (x,), (z,), sub): def c_code(self, node, name, (x,), (z,), sub):
if node.inputs[0].type in complex_types: if node.inputs[0].type in complex_types:
...@@ -2318,10 +2357,13 @@ class Deg2Rad(UnaryScalarOp): ...@@ -2318,10 +2357,13 @@ class Deg2Rad(UnaryScalarOp):
def grad(self, (x,), (gz,)): def grad(self, (x,), (gz,)):
if gz.type in complex_types: if gz.type in complex_types:
raise NotImplementedError() raise NotImplementedError()
if x.type in float_types: if self(x).type in discrete_types:
return gz * numpy.asarray(numpy.pi / 180, gz.type), if x.type in discrete_types:
return [x.zeros_like(dtype=theano.config.floatX)]
else: else:
return None, return [x.zeros_like()]
return gz * numpy.asarray(numpy.pi / 180, gz.type),
def c_code(self, node, name, (x,), (z,), sub): def c_code(self, node, name, (x,), (z,), sub):
if node.inputs[0].type in complex_types: if node.inputs[0].type in complex_types:
...@@ -2337,10 +2379,13 @@ class Rad2Deg(UnaryScalarOp): ...@@ -2337,10 +2379,13 @@ class Rad2Deg(UnaryScalarOp):
def grad(self, (x,), (gz,)): def grad(self, (x,), (gz,)):
if gz.type in complex_types: if gz.type in complex_types:
raise NotImplementedError() raise NotImplementedError()
if x.type in float_types: if self(x).type in discrete_types:
return gz * numpy.asarray(180. / numpy.pi, gz.type), if x.type in discrete_types:
return [x.zeros_like(dtype=theano.config.floatX)]
else: else:
return None, return [x.zeros_like()]
return gz * numpy.asarray(180. / numpy.pi, gz.type),
def c_code(self, node, name, (x,), (z,), sub): def c_code(self, node, name, (x,), (z,), sub):
if node.inputs[0].type in complex_types: if node.inputs[0].type in complex_types:
...@@ -2359,10 +2404,13 @@ class Cos(UnaryScalarOp): ...@@ -2359,10 +2404,13 @@ class Cos(UnaryScalarOp):
def grad(self, (x, ), (gz, )): def grad(self, (x, ), (gz, )):
if gz.type in complex_types: if gz.type in complex_types:
raise NotImplementedError() raise NotImplementedError()
if x.type in float_types: if self(x).type in discrete_types:
return -gz * sin(x), if x.type in discrete_types:
return [x.zeros_like(dtype=theano.config.floatX)]
else: else:
return None, return [x.zeros_like()]
return -gz * sin(x),
def c_code(self, node, name, (x, ), (z, ), sub): def c_code(self, node, name, (x, ), (z, ), sub):
if node.inputs[0].type in complex_types: if node.inputs[0].type in complex_types:
...@@ -2378,10 +2426,13 @@ class ArcCos(UnaryScalarOp): ...@@ -2378,10 +2426,13 @@ class ArcCos(UnaryScalarOp):
def grad(self, (x,), (gz,)): def grad(self, (x,), (gz,)):
if gz.type in complex_types: if gz.type in complex_types:
raise NotImplementedError() raise NotImplementedError()
if x.type in float_types: if self(x).type in discrete_types:
return - gz / sqrt(numpy.cast[x.type](1) - sqr(x)), if x.type in discrete_types:
return [x.zeros_like(dtype=theano.config.floatX)]
else: else:
return None, return [x.zeros_like()]
return - gz / sqrt(numpy.cast[x.type](1) - sqr(x)),
def c_code(self, node, name, (x,), (z,), sub): def c_code(self, node, name, (x,), (z,), sub):
if node.inputs[0].type in complex_types: if node.inputs[0].type in complex_types:
...@@ -2400,10 +2451,13 @@ class Sin(UnaryScalarOp): ...@@ -2400,10 +2451,13 @@ class Sin(UnaryScalarOp):
def grad(self, (x, ), (gz, )): def grad(self, (x, ), (gz, )):
if x.type in complex_types: if x.type in complex_types:
raise NotImplementedError() raise NotImplementedError()
if x.type in float_types: if self(x).type in discrete_types:
return gz * cos(x), if x.type in discrete_types:
return [x.zeros_like(dtype=theano.config.floatX)]
else: else:
return None, return [x.zeros_like()]
return gz * cos(x),
def c_code(self, node, name, (x, ), (z, ), sub): def c_code(self, node, name, (x, ), (z, ), sub):
if node.inputs[0].type in complex_types: if node.inputs[0].type in complex_types:
...@@ -2419,10 +2473,13 @@ class ArcSin(UnaryScalarOp): ...@@ -2419,10 +2473,13 @@ class ArcSin(UnaryScalarOp):
def grad(self, (x,), (gz,)): def grad(self, (x,), (gz,)):
if gz.type in complex_types: if gz.type in complex_types:
raise NotImplementedError() raise NotImplementedError()
if x.type in float_types: if self(x).type in discrete_types:
return gz / sqrt(numpy.cast[x.type](1) - sqr(x)), if x.type in discrete_types:
return [x.zeros_like(dtype=theano.config.floatX)]
else: else:
return None, return [x.zeros_like()]
return gz / sqrt(numpy.cast[x.type](1) - sqr(x)),
def c_code(self, node, name, (x,), (z,), sub): def c_code(self, node, name, (x,), (z,), sub):
if node.inputs[0].type in complex_types: if node.inputs[0].type in complex_types:
...@@ -2438,10 +2495,13 @@ class Tan(UnaryScalarOp): ...@@ -2438,10 +2495,13 @@ class Tan(UnaryScalarOp):
def grad(self, (x,), (gz,)): def grad(self, (x,), (gz,)):
if x.type in complex_types: if x.type in complex_types:
raise NotImplementedError() raise NotImplementedError()
if x.type in float_types: if self(x).type in discrete_types:
return gz / sqr(cos(x)), if x.type in discrete_types:
return [x.zeros_like(dtype=theano.config.floatX)]
else: else:
return None, return [x.zeros_like()]
return gz / sqr(cos(x)),
def c_code(self, node, name, (x, ), (z, ), sub): def c_code(self, node, name, (x, ), (z, ), sub):
if node.inputs[0].type in complex_types: if node.inputs[0].type in complex_types:
...@@ -2457,10 +2517,13 @@ class ArcTan(UnaryScalarOp): ...@@ -2457,10 +2517,13 @@ class ArcTan(UnaryScalarOp):
def grad(self, (x,), (gz,)): def grad(self, (x,), (gz,)):
if gz.type in complex_types: if gz.type in complex_types:
raise NotImplementedError() raise NotImplementedError()
if x.type in float_types: if self(x).type in discrete_types:
return gz / (numpy.cast[x.type](1) + sqr(x)), if x.type in discrete_types:
return [x.zeros_like(dtype=theano.config.floatX)]
else: else:
return None, return [x.zeros_like()]
return gz / (numpy.cast[x.type](1) + sqr(x)),
def c_code(self, node, name, (x,), (z,), sub): def c_code(self, node, name, (x,), (z,), sub):
if node.inputs[0].type in complex_types: if node.inputs[0].type in complex_types:
...@@ -2476,11 +2539,22 @@ class ArcTan2(BinaryScalarOp): ...@@ -2476,11 +2539,22 @@ class ArcTan2(BinaryScalarOp):
def grad(self, (y, x), (gz,)): def grad(self, (y, x), (gz,)):
if gz.type in complex_types: if gz.type in complex_types:
raise NotImplementedError() raise NotImplementedError()
if x.type in float_types and y.type in float_types: else:
if self(x, y).type in discrete_types:
if x.type in discrete_types:
gx = x.zeros_like(dtype=theano.config.floatX)
else:
gx = x.zeros_like()
if y.type in discrete_types:
gy = y.zeros_like(dtype=theano.config.floatX)
else:
gy = y.zeros_like()
return [gx, gy]
# If the output is float, the gradient should flow,
# even if the inputs are ints
return [gz * x / (sqr(x) + sqr(y)), return [gz * x / (sqr(x) + sqr(y)),
gz * neg(y) / (sqr(x) + sqr(y))] gz * neg(y) / (sqr(x) + sqr(y))]
else:
return None,
def c_code(self, node, name, (y, x), (z,), sub): def c_code(self, node, name, (y, x), (z,), sub):
if (node.inputs[0].type in complex_types or if (node.inputs[0].type in complex_types or
...@@ -2500,10 +2574,13 @@ class Cosh(UnaryScalarOp): ...@@ -2500,10 +2574,13 @@ class Cosh(UnaryScalarOp):
def grad(self, (x, ), (gz, )): def grad(self, (x, ), (gz, )):
if x.type in complex_types: if x.type in complex_types:
raise NotImplementedError() raise NotImplementedError()
if x.type in float_types: if self(x).type in discrete_types:
return gz * sinh(x), if x.type in discrete_types:
return [x.zeros_like(dtype=theano.config.floatX)]
else: else:
return None, return [x.zeros_like()]
return gz * sinh(x),
def c_code(self, node, name, (x, ), (z, ), sub): def c_code(self, node, name, (x, ), (z, ), sub):
if node.inputs[0].type in complex_types: if node.inputs[0].type in complex_types:
...@@ -2519,10 +2596,13 @@ class ArcCosh(UnaryScalarOp): ...@@ -2519,10 +2596,13 @@ class ArcCosh(UnaryScalarOp):
def grad(self, (x, ), (gz, )): def grad(self, (x, ), (gz, )):
if x.type in complex_types: if x.type in complex_types:
raise NotImplementedError() raise NotImplementedError()
if x.type in float_types: if self(x).type in discrete_types:
return gz / sqrt(sqr(x) - numpy.cast[x.type](1)), if x.type in discrete_types:
return [x.zeros_like(dtype=theano.config.floatX)]
else: else:
return None, return [x.zeros_like()]
return gz / sqrt(sqr(x) - numpy.cast[x.type](1)),
def c_code(self, node, name, (x, ), (z, ), sub): def c_code(self, node, name, (x, ), (z, ), sub):
if node.inputs[0].type in complex_types: if node.inputs[0].type in complex_types:
...@@ -2541,10 +2621,13 @@ class Sinh(UnaryScalarOp): ...@@ -2541,10 +2621,13 @@ class Sinh(UnaryScalarOp):
def grad(self, (x, ), (gz, )): def grad(self, (x, ), (gz, )):
if x.type in complex_types: if x.type in complex_types:
raise NotImplementedError() raise NotImplementedError()
if x.type in float_types: if self(x).type in discrete_types:
return gz * cosh(x), if x.type in discrete_types:
return [x.zeros_like(dtype=theano.config.floatX)]
else: else:
return None, return [x.zeros_like()]
return gz * cosh(x),
def c_code(self, node, name, (x, ), (z, ), sub): def c_code(self, node, name, (x, ), (z, ), sub):
if node.inputs[0].type in complex_types: if node.inputs[0].type in complex_types:
...@@ -2560,10 +2643,13 @@ class ArcSinh(UnaryScalarOp): ...@@ -2560,10 +2643,13 @@ class ArcSinh(UnaryScalarOp):
def grad(self, (x, ), (gz, )): def grad(self, (x, ), (gz, )):
if x.type in complex_types: if x.type in complex_types:
raise NotImplementedError() raise NotImplementedError()
if x.type in float_types: if self(x).type in discrete_types:
return gz / sqrt(sqr(x) + numpy.cast[x.type](1)), if x.type in discrete_types:
return [x.zeros_like(dtype=theano.config.floatX)]
else: else:
return None, return [x.zeros_like()]
return gz / sqrt(sqr(x) + numpy.cast[x.type](1)),
def c_code(self, node, name, (x, ), (z, ), sub): def c_code(self, node, name, (x, ), (z, ), sub):
if node.inputs[0].type in complex_types: if node.inputs[0].type in complex_types:
...@@ -2583,10 +2669,13 @@ class Tanh(UnaryScalarOp): ...@@ -2583,10 +2669,13 @@ class Tanh(UnaryScalarOp):
def grad(self, (x, ), (gz, )): def grad(self, (x, ), (gz, )):
if x.type in complex_types: if x.type in complex_types:
raise NotImplementedError() raise NotImplementedError()
if x.type in float_types: if self(x).type in discrete_types:
return gz * (1 - sqr(tanh(x))), if x.type in discrete_types:
return [x.zeros_like(dtype=theano.config.floatX)]
else: else:
return None, return [x.zeros_like()]
return gz * (1 - sqr(tanh(x))),
def c_code(self, node, name, (x, ), (z, ), sub): def c_code(self, node, name, (x, ), (z, ), sub):
if node.inputs[0].type in complex_types: if node.inputs[0].type in complex_types:
...@@ -2602,10 +2691,13 @@ class ArcTanh(UnaryScalarOp): ...@@ -2602,10 +2691,13 @@ class ArcTanh(UnaryScalarOp):
def grad(self, (x, ), (gz, )): def grad(self, (x, ), (gz, )):
if x.type in complex_types: if x.type in complex_types:
raise NotImplementedError() raise NotImplementedError()
if x.type in float_types: if self(x).type in discrete_types:
return gz / (numpy.cast[x.type](1) - sqr(x)), if x.type in discrete_types:
return [x.zeros_like(dtype=theano.config.floatX)]
else: else:
return None, return [x.zeros_like()]
return gz / (numpy.cast[x.type](1) - sqr(x)),
def c_code(self, node, name, (x, ), (z, ), sub): def c_code(self, node, name, (x, ), (z, ), sub):
if node.inputs[0].type in complex_types: if node.inputs[0].type in complex_types:
......
...@@ -2,11 +2,12 @@ ...@@ -2,11 +2,12 @@
#as scipy is not always available, we treat them separatly #as scipy is not always available, we treat them separatly
import numpy import numpy
import theano
from theano.scalar.basic import (UnaryScalarOp, BinaryScalarOp, from theano.scalar.basic import (UnaryScalarOp, BinaryScalarOp,
exp, upgrade_to_float, exp, upgrade_to_float,
float_types) float_types)
from theano.scalar.basic import (upgrade_to_float_no_complex, from theano.scalar.basic import (upgrade_to_float_no_complex,
complex_types, complex_types, discrete_types,
upcast) upcast)
imported_scipy_special = False imported_scipy_special = False
...@@ -32,12 +33,15 @@ class Erf(UnaryScalarOp): ...@@ -32,12 +33,15 @@ class Erf(UnaryScalarOp):
gz, = grads gz, = grads
if x.type in complex_types: if x.type in complex_types:
raise NotImplementedError() raise NotImplementedError()
elif x.type in float_types: if self(x).type in discrete_types:
if x.type in discrete_types:
return [x.zeros_like(dtype=theano.config.floatX)]
else:
return [x.zeros_like()]
cst = numpy.asarray(2. / numpy.sqrt(numpy.pi), cst = numpy.asarray(2. / numpy.sqrt(numpy.pi),
dtype=upcast(x.type.dtype, gz.type.dtype)) dtype=upcast(x.type.dtype, gz.type.dtype))
return gz * cst * exp(-x * x), return gz * cst * exp(-x * x),
else:
return None,
def c_code(self, node, name, inp, out, sub): def c_code(self, node, name, inp, out, sub):
x, = inp x, = inp
...@@ -60,12 +64,15 @@ class Erfc(UnaryScalarOp): ...@@ -60,12 +64,15 @@ class Erfc(UnaryScalarOp):
gz, = grads gz, = grads
if x.type in complex_types: if x.type in complex_types:
raise NotImplementedError() raise NotImplementedError()
elif x.type in float_types: if self(x).type in discrete_types:
if x.type in discrete_types:
return [x.zeros_like(dtype=theano.config.floatX)]
else:
return [x.zeros_like()]
cst = numpy.asarray(2. / numpy.sqrt(numpy.pi), cst = numpy.asarray(2. / numpy.sqrt(numpy.pi),
dtype=upcast(x.type.dtype, gz.type.dtype)) dtype=upcast(x.type.dtype, gz.type.dtype))
return - gz * cst * exp(-x * x), return - gz * cst * exp(-x * x),
else:
return None,
def c_code(self, node, name, inp, out, sub): def c_code(self, node, name, inp, out, sub):
x, = inp x, = inp
...@@ -99,12 +106,15 @@ class Erfinv(UnaryScalarOp): ...@@ -99,12 +106,15 @@ class Erfinv(UnaryScalarOp):
gz, = grads gz, = grads
if x.type in complex_types: if x.type in complex_types:
raise NotImplementedError() raise NotImplementedError()
elif x.type in float_types: if self(x).type in discrete_types:
if x.type in discrete_types:
return [x.zeros_like(dtype=theano.config.floatX)]
else:
return [x.zeros_like()]
cst = numpy.asarray(numpy.sqrt(numpy.pi) / 2., cst = numpy.asarray(numpy.sqrt(numpy.pi) / 2.,
dtype=upcast(x.type.dtype, gz.type.dtype)) dtype=upcast(x.type.dtype, gz.type.dtype))
return gz * cst * exp(erfinv(x) ** 2), return gz * cst * exp(erfinv(x) ** 2),
else:
return None,
# TODO: erfinv() is not provided by the C standard library # TODO: erfinv() is not provided by the C standard library
#def c_code(self, node, name, inp, out, sub): #def c_code(self, node, name, inp, out, sub):
...@@ -129,12 +139,15 @@ class Erfcinv(UnaryScalarOp): ...@@ -129,12 +139,15 @@ class Erfcinv(UnaryScalarOp):
gz, = grads gz, = grads
if x.type in complex_types: if x.type in complex_types:
raise NotImplementedError() raise NotImplementedError()
elif x.type in float_types: if self(x).type in discrete_types:
if x.type in discrete_types:
return [x.zeros_like(dtype=theano.config.floatX)]
else:
return [x.zeros_like()]
cst = numpy.asarray(numpy.sqrt(numpy.pi) / 2., cst = numpy.asarray(numpy.sqrt(numpy.pi) / 2.,
dtype=upcast(x.type.dtype, gz.type.dtype)) dtype=upcast(x.type.dtype, gz.type.dtype))
return - gz * cst * exp(erfcinv(x) ** 2), return - gz * cst * exp(erfcinv(x) ** 2),
else:
return None,
# TODO: erfcinv() is not provided by the C standard library # TODO: erfcinv() is not provided by the C standard library
#def c_code(self, node, name, inp, out, sub): #def c_code(self, node, name, inp, out, sub):
...@@ -159,6 +172,14 @@ class Gamma(UnaryScalarOp): ...@@ -159,6 +172,14 @@ class Gamma(UnaryScalarOp):
super(Gamma, self).impl(x) super(Gamma, self).impl(x)
def grad(self, (x, ), (gz, )): def grad(self, (x, ), (gz, )):
if x.type in complex_types:
raise NotImplementedError()
if self(x).type in discrete_types:
if x.type in discrete_types:
return [x.zeros_like(dtype=theano.config.floatX)]
else:
return [x.zeros_like()]
return gz * gamma(x) * psi(x), return gz * gamma(x) * psi(x),
def c_code(self, node, name, (x, ), (z, ), sub): def c_code(self, node, name, (x, ), (z, ), sub):
...@@ -190,6 +211,14 @@ class GammaLn(UnaryScalarOp): ...@@ -190,6 +211,14 @@ class GammaLn(UnaryScalarOp):
def grad(self, inp, grads): def grad(self, inp, grads):
x, = inp x, = inp
gz, = grads gz, = grads
if x.type in complex_types:
raise NotImplementedError()
if self(x).type in discrete_types:
if x.type in discrete_types:
return [x.zeros_like(dtype=theano.config.floatX)]
else:
return [x.zeros_like()]
return [gz * psi(x)] return [gz * psi(x)]
def c_code(self, node, name, inp, out, sub): def c_code(self, node, name, inp, out, sub):
...@@ -224,7 +253,6 @@ class Psi(UnaryScalarOp): ...@@ -224,7 +253,6 @@ class Psi(UnaryScalarOp):
def grad(self, inputs, outputs_gradients): def grad(self, inputs, outputs_gradients):
raise NotImplementedError() raise NotImplementedError()
return [None]
def c_support_code(self): def c_support_code(self):
return ( return (
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论