提交 cd56f8d0 · 作者: John Salvatier

ProdWithoutZero: grad_not_implemented and test err

上级 599756f1
......@@ -2140,7 +2140,9 @@ class ProdWithoutZeros(CAReduceDtype):
CAReduceDtype.__init__(self, mul_without_zeros, axis=axis,
dtype=dtype, acc_dtype=acc_dtype)
def grad(self, inp, grads):
    """Gradient of ProdWithoutZeros.

    The derivative is deliberately left unimplemented.  Rather than
    raising NotImplementedError eagerly, return a null gradient via
    ``grad_not_implemented`` so that ``theano.grad`` only fails (with
    ``NullTypeGradError``) if the caller actually requests this
    derivative through the graph.
    """
    a, = inp
    # grad_not_implemented builds a NullType variable that carries this
    # message; differentiating through it raises NullTypeGradError.
    a_grad = theano.gradient.grad_not_implemented(
        self, 0, a,
        "2nd derivatives of `product(a)` is not currently supported. "
        "If `a` is guaranteed to contain no zeros, use "
        "`product(a, no_zeros_in_input=True)`.")
    return [a_grad]
......@@ -6,6 +6,7 @@ import unittest
import numpy
from nose.plugins.skip import SkipTest
from nose.plugins.attrib import attr
from nose.tools import raises
import theano
from theano import gof, scalar, config
......@@ -671,6 +672,13 @@ class test_Prod(unittest.TestCase):
fn_a0 = theano.function([x], pwz_a0, mode=self.mode)
assert numpy.allclose(fn_a0(x_val), [1, 10, 162])
@raises(theano.gradient.NullTypeGradError)
def test_prod_without_zeros_grad(self):
    """ProdWithoutZeros declares its gradient as not implemented:
    requesting it through theano.grad must raise NullTypeGradError
    (asserted by the @raises decorator)."""
    x = theano.tensor.dmatrix()
    pwz_a1 = ProdWithoutZeros(axis=0)(x)
    # Expected to raise here: ProdWithoutZeros.grad returns a
    # grad_not_implemented (NullType) result for its input.
    pwz_grad = theano.grad(theano.tensor.sum(pwz_a1), x)
    # NOTE(review): presumably unreachable if theano.grad raises above;
    # kept as committed — confirm against theano's grad semantics.
    fn_a1 = theano.function([x], pwz_grad, mode=self.mode)
@attr('slow')
def test_other_grad_tests(self):
x = theano.tensor.dmatrix()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论