Commit b275d065, authored by fsavard

Corrected DEBUG_MODE and FAST_COMPILE tests for grad of Prod. Added comments on…

Corrected DEBUG_MODE and FAST_COMPILE tests for grad of Prod. Added comments on mechanism for this grad.
Parent 42d40e33
...@@ -1169,9 +1169,57 @@ class Prod(CAReduce): ...@@ -1169,9 +1169,57 @@ class Prod(CAReduce):
).get(idtype, idtype) ).get(idtype, idtype)
def grad(self, (prod_in, ), (gz, )): def grad(self, (prod_in, ), (gz, )):
'''
The grad of this Op could be very easy, it is was not for the case
where zeros are present in a given "group" (ie. elements reduced
together to form the product).
If no zeros are found in the elements of the product, then the
partial derivative of the product relative to one of the elements
(one of the inputs) is simply the product of the other elements.
That's easy to see from the chain rule.
Now the trick (with no zeros) is to take the overall product, then
for every original element, the partial derivative is given by
this product divided by the element itself (which equals the product
of the other terms). This is easy to do by broadcasting the original
product.
(Note that we also need to broadcast-multiply by the "incoming gradient",
ie. the gradient of the cost relative to the output/product).
-----
With zeros, things get more complicated. For a given group, we have 3
cases:
* No zeros in the group. Use previous trick.
* If only one zero is present, then the gradient for that element is
non-zero, but is zero for all others.
* If more than one zero is present, then all the derivatives are zero.
For the last two cases (with 1 or more zeros), we can't use the division
trick, as this gives divisions by 0.
Implementing that case-by-case logic is not as trivial, so a bunch of
hacks are piled down here to do it. Notably, for the "only one zero"
case, there's a special Op that computes the product of the elements
in the group, minus the zero (see ProdWithoutZero). The trick is then
to use the division trick for groups with no zero, to use the
ProdWithoutZeros op where there's only one zero, and to output a
derivative of zero for any element part of a group with more than
one zero.
I do this by first counting the number of zeros in each group (see
the "T.eq()" bits), then taking this or that behavior (see T.switch)
based on the result of this count.
'''
if prod_in.dtype[0:3] in ('int','uin'): if prod_in.dtype[0:3] in ('int','uin'):
return [None] return [None]
# Prepare the broadcasting that is used everywhere to broadcast
# over the original groups (ie. broadcast over the elements of a given
# product)
gz = as_tensor_variable(gz) gz = as_tensor_variable(gz)
axis = self.axis axis = self.axis
if axis is None: if axis is None:
...@@ -1187,9 +1235,13 @@ class Prod(CAReduce): ...@@ -1187,9 +1235,13 @@ class Prod(CAReduce):
new_dims.append(i) new_dims.append(i)
i += 1 i += 1
# result of the product, broadcastable over groups
prod_out = self(prod_in).dimshuffle(new_dims) prod_out = self(prod_in).dimshuffle(new_dims)
# incoming gradient, broadcastable over groups
gz = gz.dimshuffle(new_dims) gz = gz.dimshuffle(new_dims)
# division trick if we don't have zeros. This will contain
# NaNs to be eliminated in the T.switch if we do have zeros.
grad_case_without_zeros = (gz * prod_out / prod_in) grad_case_without_zeros = (gz * prod_out / prod_in)
if self.no_zeros_in_input: if self.no_zeros_in_input:
...@@ -1200,11 +1252,20 @@ class Prod(CAReduce): ...@@ -1200,11 +1252,20 @@ class Prod(CAReduce):
where_zeros = T.eq(prod_in, 0.0) where_zeros = T.eq(prod_in, 0.0)
sum_where_zeros = T.sum(where_zeros, axis=self.axis) sum_where_zeros = T.sum(where_zeros, axis=self.axis)
groups_with_single_zero = T.eq(sum_where_zeros, 1.0).dimshuffle(new_dims) groups_with_single_zero = T.eq(sum_where_zeros, 1).dimshuffle(new_dims)
# tensor with 0 everywhere except for those places where
# a 0 part of a group with a single zero was to be found
where_single_zero = groups_with_single_zero * where_zeros where_single_zero = groups_with_single_zero * where_zeros
# further optimization to avoid computing ProdWithoutZeros
# if the incoming gradient is 0
where_gz_not_zero = T.neq(gz, 0.0) where_gz_not_zero = T.neq(gz, 0.0)
# only take ProdWithoutZeros for the groups with single zeros
# with non-null incoming gradient
where_to_take_prod_without_zeros = \ where_to_take_prod_without_zeros = \
groups_with_single_zero * where_gz_not_zero groups_with_single_zero * where_gz_not_zero
# preprocess the original input so that we set 0 everywhere
# except for groups that contain a single zero, to avoid computing
# multiplications on other groups
prod_without_zeros_in = where_to_take_prod_without_zeros * prod_in prod_without_zeros_in = where_to_take_prod_without_zeros * prod_in
# TODO: put lazy switch here, if it'd work # TODO: put lazy switch here, if it'd work
# this is pretty efficient already (no multiplication if 0), but # this is pretty efficient already (no multiplication if 0), but
...@@ -1212,7 +1273,8 @@ class Prod(CAReduce): ...@@ -1212,7 +1273,8 @@ class Prod(CAReduce):
prod_without_zeros = ProdWithoutZeros(axis=self.axis)(prod_without_zeros_in) prod_without_zeros = ProdWithoutZeros(axis=self.axis)(prod_without_zeros_in)
prod_without_zeros = prod_without_zeros.dimshuffle(new_dims) prod_without_zeros = prod_without_zeros.dimshuffle(new_dims)
groups_without_zeros = T.eq(sum_where_zeros, 0.0).dimshuffle(new_dims) groups_without_zeros = T.eq(sum_where_zeros, 0).dimshuffle(new_dims)
final_grad = T.switch(groups_without_zeros, grad_case_without_zeros, final_grad = T.switch(groups_without_zeros, grad_case_without_zeros,
T.switch(where_single_zero, prod_without_zeros, 0.0) * gz) T.switch(where_single_zero, prod_without_zeros, 0.0) * gz)
...@@ -1228,19 +1290,29 @@ class Prod(CAReduce): ...@@ -1228,19 +1290,29 @@ class Prod(CAReduce):
return () return ()
class MulWithoutZeros(scalar.BinaryScalarOp): class MulWithoutZeros(scalar.BinaryScalarOp):
identity = 1. # "identity" here is zero, as in Reduce we don't want to start
# with reducing (1, something_else): this leads to the erronous
# case where a vector of zeros is reduced by binary reductions
# of (1, 0), which always ends up as 1 (ie. the result for
# the c version, for the product of [0,0,0], is 1.0)
identity = 0.
commutative = True commutative = True
associative = True associative = True
def impl(self, *inputs): def impl(self, x, y):
if inputs[0] == 0.: print x,y
return inputs[1] if x == 0:
if inputs[1] == 0.: return y
return inputs[0] if y == 0:
return inputs[1] * inputs[2] return x
return x*y
def c_code(self, node, name, (x,y), (z, ), sub): def c_code(self, node, name, (x,y), (z, ), sub):
return ("%(z)s = ((%(x)s == 0) ? (%(y)s) : " + \ return ("%(z)s = ((%(x)s == 0) ? (%(y)s) : " + \
"((%(y)s == 0) ? (%(x)s) : ((%(y)s)*(%(x)s))) );") % locals() "((%(y)s == 0) ? (%(x)s) : ((%(y)s)*(%(x)s))) );") % locals()
def c_code_cache_version(self):
return (1,)
mul_without_zeros = MulWithoutZeros(scalar.upcast_out, name = 'mul_without_zeros') mul_without_zeros = MulWithoutZeros(scalar.upcast_out, name = 'mul_without_zeros')
class ProdWithoutZeros(CAReduce): class ProdWithoutZeros(CAReduce):
......
...@@ -258,23 +258,28 @@ class test_Prod(unittest.TestCase): ...@@ -258,23 +258,28 @@ class test_Prod(unittest.TestCase):
def setUp(self): def setUp(self):
unittest_tools.seed_rng() unittest_tools.seed_rng()
# we want to allow nans in the matrices, so we disable this DEBUG_MODE check
mode = theano.compile.mode.get_default_mode()
mode = copy(mode)
mode.check_isfinite = False
self.mode = mode
def test_verify_grad(self): def test_verify_grad(self):
# including zeros, as the case with zeros is important # including zeros, as the case with zeros is important
# (and special cases: 1 zero in the row, more than 1 zero in the row) # (and special cases: 1 zero in the row, more than 1 zero in the row)
x_val = numpy.asarray([[1,2,3],[4,5,6],[7,8,9]], dtype='float32') x_val = numpy.asarray([[1,2,3],[4,5,6],[7,8,9]], dtype='float32')
x = theano.tensor.dmatrix() x = theano.tensor.dmatrix()
# now with verify_grad # now with verify_grad
unittest_tools.verify_grad(Prod(axis=1), [x_val]) unittest_tools.verify_grad(Prod(axis=1), [x_val], mode=self.mode)
# second time, with some added complexity # second time, with some added complexity
# verify_grad takes the sum of the matrices anyway # verify_grad takes the sum of the matrices anyway
def fn(x2): def fn(x2):
return theano.tensor.sqr(Prod(axis=1)(x2)) return theano.tensor.sqr(Prod(axis=1)(x2))
unittest_tools.verify_grad(fn, [x_val]) unittest_tools.verify_grad(fn, [x_val], mode=self.mode)
def test_verify_grad_with_zeros(self): def test_verify_grad_with_zeros(self):
...@@ -287,18 +292,18 @@ class test_Prod(unittest.TestCase): ...@@ -287,18 +292,18 @@ class test_Prod(unittest.TestCase):
x2 = theano.tensor.dmatrix() x2 = theano.tensor.dmatrix()
p = Prod(axis=1)(x) p = Prod(axis=1)(x)
p2 = Prod(axis=1)(x2) p2 = Prod(axis=1)(x2)
fn = theano.function([x,x2],[p-p2]) fn = theano.function([x,x2],[p-p2], mode=self.mode)
#print "hand computed diff for each row" #print "hand computed diff for each row"
x2_val = numpy.asarray([[1., 2., 3.003], [0.003,5.,6], [0.,0.,9.01]]) x2_val = numpy.asarray([[1., 2., 3.003], [0.003,5.,6], [0.,0.,9.01]])
#print fn(x_val, x2_val) #print fn(x_val, x2_val)
fn2 = theano.function([x],[theano.tensor.grad(p.sum(),x)]) fn2 = theano.function([x],[theano.tensor.grad(p.sum(),x)], mode=self.mode)
#print "real grad" #print "real grad"
#print fn2(x_val) #print fn2(x_val)
fn3 = theano.function([x],[p]) fn3 = theano.function([x],[p], mode=self.mode)
assert numpy.allclose(fn3(x_val), [6.,0.,0.]) assert numpy.allclose(fn3(x_val), [6.,0.,0.])
# now with verify_grad # now with verify_grad
unittest_tools.verify_grad(Prod(axis=1), [x_val]) unittest_tools.verify_grad(Prod(axis=1), [x_val], mode=self.mode)
# second time, with some added complexity # second time, with some added complexity
# verify_grad takes the sum of the matrices anyway # verify_grad takes the sum of the matrices anyway
...@@ -318,11 +323,11 @@ class test_Prod(unittest.TestCase): ...@@ -318,11 +323,11 @@ class test_Prod(unittest.TestCase):
x = theano.tensor.dmatrix() x = theano.tensor.dmatrix()
x_val = numpy.array([[1,2,3],[0,5,6],[0,0,9]], dtype='float32') x_val = numpy.array([[1,2,3],[0,5,6],[0,0,9]], dtype='float32')
pwz = ProdWithoutZeros(axis=1)(x) pwz = ProdWithoutZeros(axis=1)(x)
fn = theano.function([x], pwz) fn = theano.function([x], pwz, mode=self.mode)
assert numpy.allclose(fn(x_val), [6,30,9]) assert numpy.allclose(fn(x_val), [6,30,9])
pwz_a0 = ProdWithoutZeros(axis=0)(x) pwz_a0 = ProdWithoutZeros(axis=0)(x)
fn_a0 = theano.function([x], pwz_a0) fn_a0 = theano.function([x], pwz_a0, mode=self.mode)
assert numpy.allclose(fn_a0(x_val), [1, 10, 162]) assert numpy.allclose(fn_a0(x_val), [1, 10, 162])
def test_other_grad_tests(self): def test_other_grad_tests(self):
...@@ -333,24 +338,33 @@ class test_Prod(unittest.TestCase): ...@@ -333,24 +338,33 @@ class test_Prod(unittest.TestCase):
p = Prod(axis=1) p = Prod(axis=1)
grad_p = theano.tensor.grad(p(x).sum(), x) grad_p = theano.tensor.grad(p(x).sum(), x)
grad_fn = theano.function([x], grad_p) grad_fn = theano.function([x], grad_p, mode=self.mode)
assert numpy.allclose(grad_fn(x_val1), [[6.,3.,2.],[30.,0.,0.],[0.,0.,0.]]) assert numpy.allclose(grad_fn(x_val1), [[6.,3.,2.],[30.,0.,0.],[0.,0.,0.]])
assert numpy.allclose(grad_fn(x_val2), [[0., 0., 2.], [30., 0., 0.], [72., 63., 56.], [0., 0., 90.]]) assert numpy.allclose(grad_fn(x_val2), [[0., 0., 2.], [30., 0., 0.], [72., 63., 56.], [0., 0., 90.]])
p_axis0 = Prod(axis=0) p_axis0 = Prod(axis=0)
grad_p_axis0 = theano.tensor.grad(p_axis0(x).sum(), x) grad_p_axis0 = theano.tensor.grad(p_axis0(x).sum(), x)
grad_fn_axis0 = theano.function([x], grad_p_axis0) grad_fn_axis0 = theano.function([x], grad_p_axis0, mode=self.mode)
assert numpy.allclose(grad_fn_axis0(x_val2), [[0., 400., 0.],[63., 160., 0.], [0., 100., 0.], [0., 80., 0.]]) assert numpy.allclose(grad_fn_axis0(x_val2), [[0., 400., 0.],[63., 160., 0.], [0., 100., 0.], [0., 80., 0.]])
tensor.verify_grad(p, [x_val1], rng=rng) tensor.verify_grad(p, [x_val1], rng=rng, mode=self.mode)
def test_mul_without_zeros_zeros(self):
a = numpy.zeros((3,3))
x = theano.tensor.dmatrix()
mul1 = ProdWithoutZeros(axis=0)(x)
fn_debug = theano.function([x], mul1, mode='DEBUG_MODE')
fn_debug(a)
if __name__ == '__main__':
    # Run the whole suite.  (A committed debugging snippet previously
    # replaced this with a TestSuite running only a single test, which
    # silently disabled the rest of the tests.)
    unittest.main()
Markdown format
0%
You are adding 0 people to this discussion. Please proceed with caution.
Please finish editing this comment first!
Register or sign in to post a comment