提交 e631efb7 authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Add (zero) gradient for any() and all().

上级 27ba127c
...@@ -1653,6 +1653,10 @@ class All(CAReduce): ...@@ -1653,6 +1653,10 @@ class All(CAReduce):
ret = super(All, self).make_node(input) ret = super(All, self).make_node(input)
return ret return ret
def grad(self, inp, grads):
x, = inp
return [x.zeros_like(theano.config.floatX)]
class Any(CAReduce): class Any(CAReduce):
""" Applies `bitwise or` to all the values of a tensor along the """ Applies `bitwise or` to all the values of a tensor along the
...@@ -1679,6 +1683,10 @@ class Any(CAReduce): ...@@ -1679,6 +1683,10 @@ class Any(CAReduce):
ret = super(Any, self).make_node(input) ret = super(Any, self).make_node(input)
return ret return ret
def grad(self, inp, grads):
x, = inp
return [x.zeros_like(theano.config.floatX)]
class CAReduceDtype(CAReduce): class CAReduceDtype(CAReduce):
""" """
......
...@@ -1093,6 +1093,37 @@ class T_prod_without_zeros_dtype(unittest.TestCase): ...@@ -1093,6 +1093,37 @@ class T_prod_without_zeros_dtype(unittest.TestCase):
idx += 1 idx += 1
class TestBitOpReduceGrad(unittest.TestCase):
def setUp(self):
self.rng = numpy.random.RandomState(unittest_tools.fetch_seed())
def test_all_grad(self):
x = tensor.bmatrix('x')
x_all = x.all()
gx = theano.grad(x_all, x)
f = theano.function([x], gx)
x_random = self.rng.binomial(n=1, p=0.5, size=(5, 7)).astype('int8')
for x_val in (x_random,
numpy.zeros_like(x_random),
numpy.ones_like(x_random)):
gx_val = f(x_val)
assert gx_val.shape == x_val.shape
assert numpy.all(gx_val == 0)
def test_any_grad(self):
x = tensor.bmatrix('x')
x_all = x.any()
gx = theano.grad(x_all, x)
f = theano.function([x], gx)
x_random = self.rng.binomial(n=1, p=0.5, size=(5, 7)).astype('int8')
for x_val in (x_random,
numpy.zeros_like(x_random),
numpy.ones_like(x_random)):
gx_val = f(x_val)
assert gx_val.shape == x_val.shape
assert numpy.all(gx_val == 0)
class TestElemwise(unittest_tools.InferShapeTester): class TestElemwise(unittest_tools.InferShapeTester):
def test_infer_shape(self): def test_infer_shape(self):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论