提交 e526d59c authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Merge pull request #3005 from jsalvatier/proderror

improve error message for grad of ProdWithoutZeros
......@@ -2139,3 +2139,10 @@ class ProdWithoutZeros(CAReduceDtype):
def __init__(self, axis=None, dtype=None, acc_dtype=None):
CAReduceDtype.__init__(self, mul_without_zeros, axis=axis,
dtype=dtype, acc_dtype=acc_dtype)
def grad(self, inp, grads):
a, = inp
a_grad = theano.gradient.grad_not_implemented(self, 0, a,
"2nd derivatives of `product(a)` is not currently supported."
"If `a` is guarenteed to contains no zeros, use `product(a, no_zeros_in_input=True)`."
)
return [a_grad]
......@@ -6,6 +6,7 @@ import unittest
import numpy
from nose.plugins.skip import SkipTest
from nose.plugins.attrib import attr
from nose.tools import raises
import theano
from theano import gof, scalar, config
......@@ -671,6 +672,13 @@ class test_Prod(unittest.TestCase):
fn_a0 = theano.function([x], pwz_a0, mode=self.mode)
assert numpy.allclose(fn_a0(x_val), [1, 10, 162])
@raises(theano.gradient.NullTypeGradError)
def test_prod_without_zeros_grad(self):
x = theano.tensor.dmatrix()
pwz_a1 = ProdWithoutZeros(axis=0)(x)
pwz_grad = theano.grad(theano.tensor.sum(pwz_a1), x)
fn_a1 = theano.function([x], pwz_grad, mode=self.mode)
@attr('slow')
def test_other_grad_tests(self):
x = theano.tensor.dmatrix()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论