提交 2730a0a7 authored 作者: fsavard's avatar fsavard

Changed gradient for tensor.Prod to account for special cases where inputs…

Changed the gradient of tensor.Prod to account for the special cases where the inputs contain zeros. An intermediate (very slow) version is kept in a commented-out block, to be removed in the next commit.
上级 7e675cad
...@@ -2,7 +2,7 @@ import sys ...@@ -2,7 +2,7 @@ import sys
import elemwise_cgen as cgen import elemwise_cgen as cgen
import numpy, theano import numpy, theano
from theano import gof from theano import gof, Variable
from theano.gof import Op from theano.gof import Op
from theano import scalar from theano import scalar
from theano.scalar import Scalar from theano.scalar import Scalar
...@@ -1138,9 +1138,20 @@ class Prod(CAReduce): ...@@ -1138,9 +1138,20 @@ class Prod(CAReduce):
difference that this defines the gradient of prod wrt its tensor difference that this defines the gradient of prod wrt its tensor
input. input.
""" """
def __init__(self, axis=None, no_zeros_in_input=False):
    """Set up a product reduction over ``axis``.

    :param axis: forwarded unchanged to ``CAReduce.__init__``.
    :param no_zeros_in_input: stored on the op; when true, ``grad``
        returns only the plain ``gz * prod / input`` expression instead
        of the zero-aware one.
    """
    CAReduce.__init__(self, scalar.mul, axis)
    self.no_zeros_in_input = no_zeros_in_input
def __eq__(self, other):
    """Compare on type, scalar op, axis and the ``no_zeros_in_input`` flag."""
    same_type = type(self) == type(other)
    return (same_type
            and self.scalar_op == other.scalar_op
            and self.axis == other.axis
            and self.no_zeros_in_input == other.no_zeros_in_input)
def __hash__(self):
    """Hash on the same fields ``__eq__`` compares (except the type).

    The axis only contributes when it is not None; it is converted to a
    tuple first so that it is hashable.
    """
    combined = hash(self.scalar_op) ^ hash(self.no_zeros_in_input)
    if self.axis is not None:
        combined ^= hash(tuple(self.axis))
    return combined
def _output_dtype(self, idtype): def _output_dtype(self, idtype):
# we want to protect against overflow # we want to protect against overflow
return dict( return dict(
...@@ -1152,31 +1163,56 @@ class Prod(CAReduce): ...@@ -1152,31 +1163,56 @@ class Prod(CAReduce):
uint32='uint64', uint32='uint64',
).get(idtype, idtype) ).get(idtype, idtype)
def grad(self, inputs, output_grads):
    """Build the symbolic gradient of the product w.r.t. its input.

    :param inputs: one-element sequence holding the tensor being reduced.
    :param output_grads: one-element sequence holding the gradient ``gz``
        flowing back into the product's output.
    :returns: a one-element sequence with the gradient w.r.t. the input;
        ``[None]`` for integer inputs.

    Each group of elements reduced together falls into one of three cases:

    - no zero in the group: the gradient is ``gz * prod / input``;
    - exactly one zero: only the zero element gets a gradient, equal to
      ``gz`` times the product of the group's other elements;
    - more than one zero: the gradient is zero everywhere in the group.

    When ``self.no_zeros_in_input`` is set only the first expression is
    built; it divides by the input and is therefore not valid at zeros.
    """
    # Explicit unpacking instead of tuple parameters: same positional call
    # convention, but also valid syntax on Python 3.
    prod_in, = inputs
    gz, = output_grads

    if prod_in.dtype[0:3] in ('int', 'uin'):
        return [None]

    gz = as_tensor_variable(gz)
    axis = self.axis
    if axis is None:
        axis = range(prod_in.type.ndim)
    if axis == ():
        # empty axis tuple: the incoming gradient is returned as-is
        return gz,

    # dimshuffle pattern: 'x' at every reduced axis, and the running index
    # of the reduced result's dimensions everywhere else, so that reduced
    # quantities can be combined elementwise with prod_in.
    new_dims = []
    i = 0
    for j, _ in enumerate(prod_in.type.broadcastable):
        if j in axis:
            new_dims.append('x')
        else:
            new_dims.append(i)
            i += 1

    prod_out = self(prod_in).dimshuffle(new_dims)
    gz = gz.dimshuffle(new_dims)
    grad_case_without_zeros = (gz * prod_out / prod_in)

    if self.no_zeros_in_input:
        # The caller asserted there are no zeros, so the cheap expression
        # is enough. It does NOT handle zeros (it divides by the input).
        return [grad_case_without_zeros]
    else:
        T = theano.tensor

        # count the zeros in each reduced group
        where_zeros = T.eq(prod_in, 0.0)
        sum_where_zeros = T.sum(where_zeros, axis=self.axis)

        # mask selecting the single zero element of each one-zero group
        groups_with_single_zero = T.eq(sum_where_zeros, 1.0).dimshuffle(new_dims)
        where_single_zero = groups_with_single_zero * where_zeros

        # the zero-skipping product is only needed for one-zero groups
        # whose incoming gradient is non-zero; everything else is masked
        # to 0 so that ProdWithoutZeros skips it
        where_gz_not_zero = T.neq(gz, 0.0)
        where_to_take_prod_without_zeros = \
            groups_with_single_zero * where_gz_not_zero
        prod_without_zeros_in = where_to_take_prod_without_zeros * prod_in

        # TODO: put lazy switch here, if it'd work
        # this is pretty efficient already (no multiplication if 0), but
        # it'd be even better if we had a lazy if per element
        prod_without_zeros = ProdWithoutZeros(axis=self.axis)(prod_without_zeros_in)
        prod_without_zeros = prod_without_zeros.dimshuffle(new_dims)

        groups_without_zeros = T.eq(sum_where_zeros, 0.0).dimshuffle(new_dims)

        # zero-free groups use the plain formula; otherwise only the lone
        # zero of a one-zero group receives a non-zero gradient
        final_grad = T.switch(groups_without_zeros, grad_case_without_zeros,
                              T.switch(where_single_zero, prod_without_zeros, 0.0) * gz)

        return [final_grad]
def __str__(self): def __str__(self):
if self.axis is None: if self.axis is None:
...@@ -1184,4 +1220,117 @@ class Prod(CAReduce): ...@@ -1184,4 +1220,117 @@ class Prod(CAReduce):
else: else:
return "Prod{%s}" % ", ".join(map(str, self.axis)) return "Prod{%s}" % ", ".join(map(str, self.axis))
def c_code_cache_version(self):
    """Return an empty version tuple.

    NOTE(review): presumably an empty tuple tells the framework not to
    cache this op's compiled C code -- confirm against the gof docs.
    """
    return tuple()
class MulWithoutZeros(scalar.BinaryScalarOp):
    """Binary multiplication that skips zero operands.

    If either operand is zero, the other operand is returned unchanged;
    otherwise the ordinary product is returned. The Python implementation
    (``impl``) and the generated C code (``c_code``) follow the same rule.
    """
    identity = 1.
    commutative = True
    associative = True

    def impl(self, *inputs):
        """Return the zero-skipping product of the two scalar inputs."""
        x, y = inputs[0], inputs[1]
        if x == 0.:
            return y
        if y == 0.:
            return x
        # Fixed: this used to read inputs[1] * inputs[2], which raises
        # IndexError for a binary op as soon as both inputs are non-zero.
        return x * y

    def c_code(self, node, name, inputs, outputs, sub):
        """Return the C statement assigning the zero-skipping product.

        ``inputs`` holds the two C input names and ``outputs`` the single
        C output name. They are unpacked explicitly rather than with tuple
        parameters, which keeps the positional call convention and is also
        valid syntax on Python 3.
        """
        x, y = inputs
        z, = outputs
        return ("%(z)s = ((%(x)s == 0) ? (%(y)s) : " +
                "((%(y)s == 0) ? (%(x)s) : ((%(y)s)*(%(x)s))) );") % {
                    'x': x, 'y': y, 'z': z}


# Shared scalar-op instance used by ProdWithoutZeros.
mul_without_zeros = MulWithoutZeros(scalar.upcast_out, name='mul_without_zeros')
class ProdWithoutZeros(CAReduce):
    """CAReduce built on ``mul_without_zeros``.

    The test in this file expects, for ``[[1,2,3],[0,5,6],[0,0,9]]`` with
    ``axis=1``, the result ``[6, 30, 9]``: zeros are skipped rather than
    multiplied in.
    """

    def __init__(self, axis = None):
        CAReduce.__init__(self, mul_without_zeros, axis)

    def _output_dtype(self, idtype):
        """Map small integer dtypes to wider ones; others pass through."""
        # we want to protect against overflow
        widened = {
            'int8': 'int32',
            'int16': 'int32',
            'int32': 'int64',
            'uint8': 'uint32',
            'uint16': 'uint32',
            'uint32': 'uint64',
        }
        return widened.get(idtype, idtype)

    def __str__(self):
        """Return the op name, followed by the axes in braces if any."""
        if self.axis is not None:
            axes = [str(a) for a in self.axis]
            return "ProdWithoutZeros{%s}" % ", ".join(axes)
        return "ProdWithoutZeros"
"""
class ProdGrad(Op):
'''
grad is more complex than it might seem, as the cases where:
- all inputs in a group are non-zero
- only one input in the group is zero
- more than one input in the group is zero
must be handled separately
What follows is a very restricted and inefficient way to compute this
gradient (inefficient in that it's coded in Python and it doesn't share
computations which might be shared with the computation of the products
themselves; restricted in that it only handles a very simple case).
'''
def __init__(self,axis=1):
self.axis = axis
def __eq__(self, other):
return type(self) == type(other) and self.axis == other.axis
def __hash__(self):
return hash(type(self)) ^ hash(self.axis)
def __str__(self):
return '%s{%s}' % (self.__class__.__name__, self.axis)
def make_node(self, prod_in, gz):
# make_node should only be called by the grad function of Prod,
# so this assert should not fail.
assert isinstance(prod_in, Variable) and prod_in.ndim==2
return Apply(self, [prod_in, gz], [prod_in.type()])
def perform(self, node, (prod_in, gz), (gx_stg,)):
gx = numpy.zeros_like(prod_in)
# only supported case, for the moment,
# therefore is when we have a 2d tensor (matrix)
# where we take the prod over rows
if self.axis is None or self.axis != 1 and self.axis != (1,) and self.axis != [1,]:
raise NotImplementedError("tensor.Prod default gradient only handles the case where the input is a matrix and product is taken over rows. If you're sure the input does not contain any zeros, a simpler and more efficient gradient may be used; set the Prod constructor option 'no_zeros_in_input'.")
for row_idx in range(prod_in.shape[0]):
row = prod_in[row_idx,:]
zero_count = 0
first_zero_position = 0
prod = gz[row_idx]
prod_without_zeros = prod
for el_idx, el in enumerate(row):
if el == 0.:
zero_count += 1
first_zero_position = el_idx
prod = 0.
else:
prod *= el
prod_without_zeros *= el
if zero_count == 0:
gx[row_idx, :] = prod / prod_in[row_idx,:]
elif zero_count == 1:
# all elements (but the one = 0) get a gradient value of 0
gx[row_idx, first_zero_position] = prod_without_zeros
else:
# case with more than 1 zeros, where everything ends up 0.
# just leave it the way it was initialized (with zeros)
pass
gx_stg[0] = gx
"""
...@@ -258,49 +258,99 @@ class test_Prod(unittest.TestCase): ...@@ -258,49 +258,99 @@ class test_Prod(unittest.TestCase):
def setUp(self):
    """Seed the random number generator through ``unittest_tools``."""
    unittest_tools.seed_rng()
def test_prod_grad(self):
    """Check Prod(axis=0) and its gradient against hand-computed values."""
    values = numpy.asarray([[1,2,3],[4,5,6],[7,8,9]], dtype='float32')
    x = theano.tensor.dmatrix()
    col_prod = Prod(axis=0)(x)

    # sanity check: the forward pass gives the column products
    forward = theano.function([x], [col_prod])
    expected_prod = numpy.array([ 28., 80., 162.])
    assert numpy.allclose(forward(values), expected_prod)

    # very basic case for the product; no broadcasting in x
    grad_expr = theano.tensor.grad(col_prod.sum(), x)
    grad_of = theano.function([x], grad_expr)
    expected_grad = numpy.asarray([[28.,40.,54.],[7.,16.,27.],[4.,10.,18.]])
    assert numpy.allclose(grad_of(values), expected_grad)

    # now with the input transposed before the reduction
    transposed = x.dimshuffle(1, 0)
    row_prod = Prod(axis=0)(transposed)
    row_prod_total = row_prod.sum()
    grad_t_expr = theano.tensor.grad(row_prod_total, x)
    both = theano.function([x], [row_prod, grad_t_expr])
    prod_ret, grad_ret = both(values)

    assert numpy.allclose(prod_ret, numpy.array([ 6., 120., 504.]))
    expected_grad_t = numpy.asarray([[6.,3.,2.],[30.,24.,20.],[72.,63.,56.]])
    assert numpy.allclose(grad_ret, expected_grad_t)
def test_verify_grad(self):
    """Numerically verify Prod(axis=1)'s gradient on an input without zeros.

    Inputs containing zeros are covered by test_verify_grad_with_zeros.
    """
    x_val = numpy.asarray([[1,2,3],[4,5,6],[7,8,9]], dtype='float32')

    # the op on its own
    unittest_tools.verify_grad(Prod(axis=1), [x_val])

    # second time, with some added complexity
    # verify_grad takes the sum of the matrices anyway
    def fn(x2):
        return theano.tensor.sqr(Prod(axis=1)(x2))

    unittest_tools.verify_grad(fn, [x_val])
def test_verify_grad_with_zeros(self):
    """Numerically verify Prod(axis=1)'s gradient when the input has zeros.

    The rows cover the three special cases: no zero, exactly one zero,
    and more than one zero.
    """
    x_val = numpy.asarray([[1.,2.,3.],[0.,5.,6.],[0.,0.,9.]], dtype='float32')
    x = theano.tensor.dmatrix()
    p = Prod(axis=1)(x)

    # sanity check on the forward pass
    fn = theano.function([x], [p])
    assert numpy.allclose(fn(x_val), [6.,0.,0.])

    # now with verify_grad
    unittest_tools.verify_grad(Prod(axis=1), [x_val])

    # TODO(review): the original carried a disabled second check running
    # verify_grad on theano.tensor.sqr(Prod(axis=1)(x)); re-enable it or
    # drop this note once the reason it was disabled is confirmed.
def test_prod_without_zeros(self):
    """ProdWithoutZeros must skip zeros along either axis of a matrix."""
    x = theano.tensor.dmatrix()
    x_val = numpy.array([[1,2,3],[0,5,6],[0,0,9]], dtype='float32')

    # (axis, expected zero-skipping product along that axis)
    cases = [
        (1, [6,30,9]),
        (0, [1, 10, 162]),
    ]
    for axis, expected in cases:
        reduced = ProdWithoutZeros(axis=axis)(x)
        compiled = theano.function([x], reduced)
        assert numpy.allclose(compiled(x_val), expected)
def test_other_grad_tests(self):
    """Check Prod gradients against hand-computed values on inputs with zeros."""
    x = theano.tensor.dmatrix()
    x_val1 = numpy.array([[1,2,3],[0,5,6],[0,0,9]], dtype='float32')
    x_val2 = numpy.array([[1,2,0],[0,5,6],[7,8,9],[9,10,0]], dtype='float32')
    # fixed: this used to be the duplicated assignment "rng = rng = ..."
    rng = numpy.random.RandomState(43)

    # product over each row
    p = Prod(axis=1)
    grad_p = theano.tensor.grad(p(x).sum(), x)
    grad_fn = theano.function([x], grad_p)
    assert numpy.allclose(grad_fn(x_val1), [[6.,3.,2.],[30.,0.,0.],[0.,0.,0.]])
    assert numpy.allclose(grad_fn(x_val2), [[0., 0., 2.], [30., 0., 0.], [72., 63., 56.], [0., 0., 90.]])

    # product over each column
    p_axis0 = Prod(axis=0)
    grad_p_axis0 = theano.tensor.grad(p_axis0(x).sum(), x)
    grad_fn_axis0 = theano.function([x], grad_p_axis0)
    assert numpy.allclose(grad_fn_axis0(x_val2), [[0., 400., 0.],[63., 160., 0.], [0., 100., 0.], [0., 80., 0.]])

    # NOTE(review): this relies on a module-level name ``tensor``, while
    # the other tests here call unittest_tools.verify_grad -- confirm the
    # import exists in the full file.
    tensor.verify_grad(p, [x_val1], rng=rng)
if __name__ == '__main__':
    # Runs only the Prod tests listed below; switch back to
    # unittest.main() to run every test in the module.
    selected = [
        'test_verify_grad',
        'test_verify_grad_with_zeros',
        'test_prod_without_zeros',
        'test_other_grad_tests',
    ]
    suite = unittest.TestSuite([test_Prod(name) for name in selected])
    unittest.TextTestRunner().run(suite)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论