提交 002872ad authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Remove None from return values of grad.

Also change checks to verify the output dtype of the Op itself, not of the inputs or gradient, because it can depend on different things. The idea is that if the Op's output is continuous, then the gradient should be propagated to the inputs, regardless of whether they are continuous or discrete. However, if the output is discrete, then the gradient wrt the inputs will be a continuous zero.
上级 37ce26a3
差异被折叠。
......@@ -2,11 +2,12 @@
#as scipy is not always available, we treat them separatly
import numpy
import theano
from theano.scalar.basic import (UnaryScalarOp, BinaryScalarOp,
exp, upgrade_to_float,
float_types)
from theano.scalar.basic import (upgrade_to_float_no_complex,
complex_types,
complex_types, discrete_types,
upcast)
imported_scipy_special = False
......@@ -32,12 +33,15 @@ class Erf(UnaryScalarOp):
gz, = grads
if x.type in complex_types:
raise NotImplementedError()
elif x.type in float_types:
cst = numpy.asarray(2. / numpy.sqrt(numpy.pi),
dtype=upcast(x.type.dtype, gz.type.dtype))
return gz * cst * exp(-x * x),
else:
return None,
if self(x).type in discrete_types:
if x.type in discrete_types:
return [x.zeros_like(dtype=theano.config.floatX)]
else:
return [x.zeros_like()]
cst = numpy.asarray(2. / numpy.sqrt(numpy.pi),
dtype=upcast(x.type.dtype, gz.type.dtype))
return gz * cst * exp(-x * x),
def c_code(self, node, name, inp, out, sub):
x, = inp
......@@ -60,12 +64,15 @@ class Erfc(UnaryScalarOp):
gz, = grads
if x.type in complex_types:
raise NotImplementedError()
elif x.type in float_types:
cst = numpy.asarray(2. / numpy.sqrt(numpy.pi),
dtype=upcast(x.type.dtype, gz.type.dtype))
return - gz * cst * exp(-x * x),
else:
return None,
if self(x).type in discrete_types:
if x.type in discrete_types:
return [x.zeros_like(dtype=theano.config.floatX)]
else:
return [x.zeros_like()]
cst = numpy.asarray(2. / numpy.sqrt(numpy.pi),
dtype=upcast(x.type.dtype, gz.type.dtype))
return - gz * cst * exp(-x * x),
def c_code(self, node, name, inp, out, sub):
x, = inp
......@@ -99,12 +106,15 @@ class Erfinv(UnaryScalarOp):
gz, = grads
if x.type in complex_types:
raise NotImplementedError()
elif x.type in float_types:
cst = numpy.asarray(numpy.sqrt(numpy.pi) / 2.,
dtype=upcast(x.type.dtype, gz.type.dtype))
return gz * cst * exp(erfinv(x) ** 2),
else:
return None,
if self(x).type in discrete_types:
if x.type in discrete_types:
return [x.zeros_like(dtype=theano.config.floatX)]
else:
return [x.zeros_like()]
cst = numpy.asarray(numpy.sqrt(numpy.pi) / 2.,
dtype=upcast(x.type.dtype, gz.type.dtype))
return gz * cst * exp(erfinv(x) ** 2),
# TODO: erfinv() is not provided by the C standard library
#def c_code(self, node, name, inp, out, sub):
......@@ -129,12 +139,15 @@ class Erfcinv(UnaryScalarOp):
gz, = grads
if x.type in complex_types:
raise NotImplementedError()
elif x.type in float_types:
cst = numpy.asarray(numpy.sqrt(numpy.pi) / 2.,
dtype=upcast(x.type.dtype, gz.type.dtype))
return - gz * cst * exp(erfcinv(x) ** 2),
else:
return None,
if self(x).type in discrete_types:
if x.type in discrete_types:
return [x.zeros_like(dtype=theano.config.floatX)]
else:
return [x.zeros_like()]
cst = numpy.asarray(numpy.sqrt(numpy.pi) / 2.,
dtype=upcast(x.type.dtype, gz.type.dtype))
return - gz * cst * exp(erfcinv(x) ** 2),
# TODO: erfcinv() is not provided by the C standard library
#def c_code(self, node, name, inp, out, sub):
......@@ -159,6 +172,14 @@ class Gamma(UnaryScalarOp):
super(Gamma, self).impl(x)
def grad(self, (x, ), (gz, )):
if x.type in complex_types:
raise NotImplementedError()
if self(x).type in discrete_types:
if x.type in discrete_types:
return [x.zeros_like(dtype=theano.config.floatX)]
else:
return [x.zeros_like()]
return gz * gamma(x) * psi(x),
def c_code(self, node, name, (x, ), (z, ), sub):
......@@ -190,6 +211,14 @@ class GammaLn(UnaryScalarOp):
def grad(self, inp, grads):
x, = inp
gz, = grads
if x.type in complex_types:
raise NotImplementedError()
if self(x).type in discrete_types:
if x.type in discrete_types:
return [x.zeros_like(dtype=theano.config.floatX)]
else:
return [x.zeros_like()]
return [gz * psi(x)]
def c_code(self, node, name, inp, out, sub):
......@@ -224,7 +253,6 @@ class Psi(UnaryScalarOp):
def grad(self, inputs, outputs_gradients):
raise NotImplementedError()
return [None]
def c_support_code(self):
return (
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论