提交 e526d59c authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Merge pull request #3005 from jsalvatier/proderror

improve error message for grad of ProdWithoutZeros
...@@ -2139,3 +2139,10 @@ class ProdWithoutZeros(CAReduceDtype): ...@@ -2139,3 +2139,10 @@ class ProdWithoutZeros(CAReduceDtype):
def __init__(self, axis=None, dtype=None, acc_dtype=None): def __init__(self, axis=None, dtype=None, acc_dtype=None):
CAReduceDtype.__init__(self, mul_without_zeros, axis=axis, CAReduceDtype.__init__(self, mul_without_zeros, axis=axis,
dtype=dtype, acc_dtype=acc_dtype) dtype=dtype, acc_dtype=acc_dtype)
def grad(self, inp, grads):
a, = inp
a_grad = theano.gradient.grad_not_implemented(self, 0, a,
"2nd derivatives of `product(a)` is not currently supported."
"If `a` is guarenteed to contains no zeros, use `product(a, no_zeros_in_input=True)`."
)
return [a_grad]
...@@ -6,6 +6,7 @@ import unittest ...@@ -6,6 +6,7 @@ import unittest
import numpy import numpy
from nose.plugins.skip import SkipTest from nose.plugins.skip import SkipTest
from nose.plugins.attrib import attr from nose.plugins.attrib import attr
from nose.tools import raises
import theano import theano
from theano import gof, scalar, config from theano import gof, scalar, config
...@@ -671,6 +672,13 @@ class test_Prod(unittest.TestCase): ...@@ -671,6 +672,13 @@ class test_Prod(unittest.TestCase):
fn_a0 = theano.function([x], pwz_a0, mode=self.mode) fn_a0 = theano.function([x], pwz_a0, mode=self.mode)
assert numpy.allclose(fn_a0(x_val), [1, 10, 162]) assert numpy.allclose(fn_a0(x_val), [1, 10, 162])
@raises(theano.gradient.NullTypeGradError)
def test_prod_without_zeros_grad(self):
x = theano.tensor.dmatrix()
pwz_a1 = ProdWithoutZeros(axis=0)(x)
pwz_grad = theano.grad(theano.tensor.sum(pwz_a1), x)
fn_a1 = theano.function([x], pwz_grad, mode=self.mode)
@attr('slow') @attr('slow')
def test_other_grad_tests(self): def test_other_grad_tests(self):
x = theano.tensor.dmatrix() x = theano.tensor.dmatrix()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论