提交 002872ad authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Remove None from return values of grad.

Also change checks to verify the output dtype of the Op itself, not of the inputs or gradient, because it can depend on different things. The idea is that if the Op's output is continuous, then the gradient should be propagated to the inputs, regardless of whether they are continuous or discrete. However, if the output is discrete, then the gradient wrt the inputs will be a continuous zero.
上级 37ce26a3
差异被折叠。
...@@ -2,11 +2,12 @@ ...@@ -2,11 +2,12 @@
#as scipy is not always available, we treat them separatly #as scipy is not always available, we treat them separatly
import numpy import numpy
import theano
from theano.scalar.basic import (UnaryScalarOp, BinaryScalarOp, from theano.scalar.basic import (UnaryScalarOp, BinaryScalarOp,
exp, upgrade_to_float, exp, upgrade_to_float,
float_types) float_types)
from theano.scalar.basic import (upgrade_to_float_no_complex, from theano.scalar.basic import (upgrade_to_float_no_complex,
complex_types, complex_types, discrete_types,
upcast) upcast)
imported_scipy_special = False imported_scipy_special = False
...@@ -32,12 +33,15 @@ class Erf(UnaryScalarOp): ...@@ -32,12 +33,15 @@ class Erf(UnaryScalarOp):
gz, = grads gz, = grads
if x.type in complex_types: if x.type in complex_types:
raise NotImplementedError() raise NotImplementedError()
elif x.type in float_types: if self(x).type in discrete_types:
if x.type in discrete_types:
return [x.zeros_like(dtype=theano.config.floatX)]
else:
return [x.zeros_like()]
cst = numpy.asarray(2. / numpy.sqrt(numpy.pi), cst = numpy.asarray(2. / numpy.sqrt(numpy.pi),
dtype=upcast(x.type.dtype, gz.type.dtype)) dtype=upcast(x.type.dtype, gz.type.dtype))
return gz * cst * exp(-x * x), return gz * cst * exp(-x * x),
else:
return None,
def c_code(self, node, name, inp, out, sub): def c_code(self, node, name, inp, out, sub):
x, = inp x, = inp
...@@ -60,12 +64,15 @@ class Erfc(UnaryScalarOp): ...@@ -60,12 +64,15 @@ class Erfc(UnaryScalarOp):
gz, = grads gz, = grads
if x.type in complex_types: if x.type in complex_types:
raise NotImplementedError() raise NotImplementedError()
elif x.type in float_types: if self(x).type in discrete_types:
if x.type in discrete_types:
return [x.zeros_like(dtype=theano.config.floatX)]
else:
return [x.zeros_like()]
cst = numpy.asarray(2. / numpy.sqrt(numpy.pi), cst = numpy.asarray(2. / numpy.sqrt(numpy.pi),
dtype=upcast(x.type.dtype, gz.type.dtype)) dtype=upcast(x.type.dtype, gz.type.dtype))
return - gz * cst * exp(-x * x), return - gz * cst * exp(-x * x),
else:
return None,
def c_code(self, node, name, inp, out, sub): def c_code(self, node, name, inp, out, sub):
x, = inp x, = inp
...@@ -99,12 +106,15 @@ class Erfinv(UnaryScalarOp): ...@@ -99,12 +106,15 @@ class Erfinv(UnaryScalarOp):
gz, = grads gz, = grads
if x.type in complex_types: if x.type in complex_types:
raise NotImplementedError() raise NotImplementedError()
elif x.type in float_types: if self(x).type in discrete_types:
if x.type in discrete_types:
return [x.zeros_like(dtype=theano.config.floatX)]
else:
return [x.zeros_like()]
cst = numpy.asarray(numpy.sqrt(numpy.pi) / 2., cst = numpy.asarray(numpy.sqrt(numpy.pi) / 2.,
dtype=upcast(x.type.dtype, gz.type.dtype)) dtype=upcast(x.type.dtype, gz.type.dtype))
return gz * cst * exp(erfinv(x) ** 2), return gz * cst * exp(erfinv(x) ** 2),
else:
return None,
# TODO: erfinv() is not provided by the C standard library # TODO: erfinv() is not provided by the C standard library
#def c_code(self, node, name, inp, out, sub): #def c_code(self, node, name, inp, out, sub):
...@@ -129,12 +139,15 @@ class Erfcinv(UnaryScalarOp): ...@@ -129,12 +139,15 @@ class Erfcinv(UnaryScalarOp):
gz, = grads gz, = grads
if x.type in complex_types: if x.type in complex_types:
raise NotImplementedError() raise NotImplementedError()
elif x.type in float_types: if self(x).type in discrete_types:
if x.type in discrete_types:
return [x.zeros_like(dtype=theano.config.floatX)]
else:
return [x.zeros_like()]
cst = numpy.asarray(numpy.sqrt(numpy.pi) / 2., cst = numpy.asarray(numpy.sqrt(numpy.pi) / 2.,
dtype=upcast(x.type.dtype, gz.type.dtype)) dtype=upcast(x.type.dtype, gz.type.dtype))
return - gz * cst * exp(erfcinv(x) ** 2), return - gz * cst * exp(erfcinv(x) ** 2),
else:
return None,
# TODO: erfcinv() is not provided by the C standard library # TODO: erfcinv() is not provided by the C standard library
#def c_code(self, node, name, inp, out, sub): #def c_code(self, node, name, inp, out, sub):
...@@ -159,6 +172,14 @@ class Gamma(UnaryScalarOp): ...@@ -159,6 +172,14 @@ class Gamma(UnaryScalarOp):
super(Gamma, self).impl(x) super(Gamma, self).impl(x)
def grad(self, (x, ), (gz, )): def grad(self, (x, ), (gz, )):
if x.type in complex_types:
raise NotImplementedError()
if self(x).type in discrete_types:
if x.type in discrete_types:
return [x.zeros_like(dtype=theano.config.floatX)]
else:
return [x.zeros_like()]
return gz * gamma(x) * psi(x), return gz * gamma(x) * psi(x),
def c_code(self, node, name, (x, ), (z, ), sub): def c_code(self, node, name, (x, ), (z, ), sub):
...@@ -190,6 +211,14 @@ class GammaLn(UnaryScalarOp): ...@@ -190,6 +211,14 @@ class GammaLn(UnaryScalarOp):
def grad(self, inp, grads): def grad(self, inp, grads):
x, = inp x, = inp
gz, = grads gz, = grads
if x.type in complex_types:
raise NotImplementedError()
if self(x).type in discrete_types:
if x.type in discrete_types:
return [x.zeros_like(dtype=theano.config.floatX)]
else:
return [x.zeros_like()]
return [gz * psi(x)] return [gz * psi(x)]
def c_code(self, node, name, inp, out, sub): def c_code(self, node, name, inp, out, sub):
...@@ -224,7 +253,6 @@ class Psi(UnaryScalarOp): ...@@ -224,7 +253,6 @@ class Psi(UnaryScalarOp):
def grad(self, inputs, outputs_gradients): def grad(self, inputs, outputs_gradients):
raise NotImplementedError() raise NotImplementedError()
return [None]
def c_support_code(self): def c_support_code(self):
return ( return (
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论