提交 fd69ca93 authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Merge pull request #1694 from nouiz/prod_no_zeros_in_inputs

Make the no_zeros_in_inputs parameter of Prod() available to the user.
......@@ -816,7 +816,7 @@ Reductions
* an *int* - computed along this axis
* a *list of ints* - computed along these axes
.. function:: prod(x, axis=None, dtype=None, keepdims=False, acc_dtype=None)
.. function:: prod(x, axis=None, dtype=None, keepdims=False, acc_dtype=None, no_zeros_in_input=False)
:Parameter: *x* - symbolic Tensor (or compatible)
:Parameter: *axis* - axis or axes along which to compute the product
......@@ -844,6 +844,21 @@ Reductions
- for float dtypes, we use at least float64;
- for complex dtypes, we use at least complex128.
:Parameter: *no_zeros_in_input* - The grad of prod is complicated
as we need to handle 3 different cases: no zeros in the
reduced input group, exactly one zero, or more than one zero.
This could slow you down, but more importantly, we currently
don't support the second derivative of the 3 cases. So you
cannot take the second derivative of the default prod().
To remove the handling of the special cases of 0, and so get
a small speed up and allow the second derivative, set
``no_zeros_in_input`` to ``True``. It defaults to ``False``.
**It is the user's responsibility to make sure there are no zeros
in the inputs. If there are, the grad will be wrong.**
:Returns: product of every term in *x* along *axis*
axis can be:
......
......@@ -2748,7 +2748,8 @@ pprint.assign(Sum(), printing.FunctionPrinter('sum'))
@constructor
def prod(input, axis=None, dtype=None, keepdims=False, acc_dtype=None):
def prod(input, axis=None, dtype=None, keepdims=False, acc_dtype=None,
no_zeros_in_input=False):
"""
Computes the product along the given axis(es) of a tensor `input`
......@@ -2762,7 +2763,8 @@ def prod(input, axis=None, dtype=None, keepdims=False, acc_dtype=None):
For full documentation see ``tensor.elemwise.Prod``.
"""
out = elemwise.Prod(axis, dtype=dtype, acc_dtype=acc_dtype)(input)
out = elemwise.Prod(axis, dtype=dtype, acc_dtype=acc_dtype,
no_zeros_in_input=no_zeros_in_input)(input)
if keepdims:
out = makeKeepDims(input, out, axis)
......
......@@ -581,6 +581,38 @@ class test_Prod(unittest.TestCase):
#unittest_tools.verify_grad(fn5, [x_val])
def test_prod_no_zeros_in_input(self):
    """Exercise ``Prod`` built with ``no_zeros_in_input=True``.

    Checks the forward value along axis 1, the first and second
    gradients of the full product against hard-coded expected values,
    and then runs ``verify_grad`` on the op and on its gradient.
    """
    x = theano.tensor.dmatrix()
    # The test matrix contains no zero entries.
    x_val = numpy.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype='float32')

    # Forward pass: product of each row.
    row_prod = Prod(axis=1, no_zeros_in_input=True)(x)
    forward_fn = theano.function([x], row_prod, mode=self.mode)
    assert numpy.allclose(forward_fn(x_val), [6, 120, 504])

    # Product over every element, then its first and second gradients.
    full_prod = Prod(no_zeros_in_input=True)(x)
    first_grad = theano.grad(full_prod, x)
    second_grad = theano.grad(first_grad.sum(), x)

    expected_first = [[362880., 181440., 120960.],
                      [90720., 72576., 60480.],
                      [51840., 45360., 40320.]]
    first_fn = theano.function([x], first_grad, mode=self.mode)
    assert numpy.allclose(first_fn(x_val), expected_first)

    expected_second = [[663696., 422568., 301872.],
                       [233964., 190800., 161016.],
                       [139248., 122652., 109584.]]
    second_fn = theano.function([x], second_grad, mode=self.mode)
    assert numpy.allclose(second_fn(x_val), expected_second)

    # Gradient verification of the op itself, per-axis and full.
    unittest_tools.verify_grad(Prod(axis=1, no_zeros_in_input=True),
                               [x_val],
                               mode=self.mode)
    unittest_tools.verify_grad(Prod(no_zeros_in_input=True), [x_val],
                               mode=self.mode)

    # Gradient verification of the first gradient, i.e. a check of the
    # second derivative.
    def second_deriv(x):
        return theano.grad(Prod(no_zeros_in_input=True)(x), x)

    unittest_tools.verify_grad(second_deriv, [x_val],
                               mode=self.mode)
def test_prod_without_zeros(self):
x = theano.tensor.dmatrix()
x_val = numpy.array([[1, 2, 3], [0, 5, 6], [0, 0, 9]], dtype='float32')
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论