提交 5d367913 authored 作者: Pascal Lamblin's avatar Pascal Lamblin

local_advanced_indexing_crossentropy_onehot_grad now supports cases where the

output gradient is not 1.
上级 abf7bc89
...@@ -1020,13 +1020,18 @@ def local_advanced_indexing_crossentropy_onehot_grad(node): ...@@ -1020,13 +1020,18 @@ def local_advanced_indexing_crossentropy_onehot_grad(node):
# / softmax(x) # / softmax(x)
# which arises from the gradient of log(softmax(x))[arange(y.shape[0]), y] # which arises from the gradient of log(softmax(x))[arange(y.shape[0]), y]
# #
# TODO: explain variants of case 1.
# TODO: explain other variants of case 2.
# In some cases, in case 2., instead of "-1. like (AdvancedSubtensor...)", # In some cases, in case 2., instead of "-1. like (AdvancedSubtensor...)",
# we can have "-1. like ([-1] * AdvancedSubtensor...)". This case will be # we can have "-1. like ([-1] * AdvancedSubtensor...)". This case will be
# recognized too, but other variants, even with the same shape, might not # recognized too, but other variants, even with the same shape, might not
# (yet). # (yet).
# The base cases are realized when the gradient of the
# cost wrt the output is equal to 1. When this gradient
# has another (scalar) value, it typically appears in the
# second argument of AdvancedIncSubtensor. In that case, we
# try to extract it, and feed it as the output gradient of
# crossentropy_softmax_1hot_with_bias_dx.
# #
# N.B. Regarding clients -- This substitution is important for numerical stability, so we # N.B. Regarding clients -- This substitution is important for numerical stability, so we
# perform the substitution even when intermediate values have multiple clients. # perform the substitution even when intermediate values have multiple clients.
...@@ -1052,43 +1057,60 @@ def local_advanced_indexing_crossentropy_onehot_grad(node): ...@@ -1052,43 +1057,60 @@ def local_advanced_indexing_crossentropy_onehot_grad(node):
else: else:
return return
# Check that incr has the form -1./sm[arange(len(y)), y] # In the base case (output gradient = 1), incr is -1./sm[arange(len(y)), y]
# Here, we are looking for the AdvancedSubtensor term (sm[arange(len(y)), y]),
# the remainder of the expression will be used to compute outgrad_factor
# outgrad_factor will be constructed in 3 steps as follows:
# outgrad_factor = +/- 1 (initial sign)
# outgrad_factor *= numerator
# outgrad_factor /= denominator
adv_subtensor = None
outgrad_factor = 1.
# If there's a 'minus' sign before the whole expression, put it in
# outgrad_factor and iterate
if incr.owner and incr.owner.op == tensor.neg:
outgrad_factor = -1.
incr = incr.owner.inputs[0]
if incr.owner and incr.owner.op == tensor.true_div: if incr.owner and incr.owner.op == tensor.true_div:
num, denom = incr.owner.inputs num, denom = incr.owner.inputs
if not (hasattr(num, 'data') and numpy.all(num.data == -1)): # set outgrad_factor according to the numerator,
# it may be divided later
if hasattr(num, 'data') and numpy.all(num.data == -1):
# Base case, num is -1
outgrad_factor *= 1.
elif numpy.all(num.broadcastable):
# Otherwise, it should be a scalar
outgrad_factor *= -num
else:
return return
#else: OK
if not denom.owner: if not denom.owner:
return return
adv_subtensor = None
if isinstance(denom.owner.op, tensor.AdvancedSubtensor): if isinstance(denom.owner.op, tensor.AdvancedSubtensor):
# Base case
adv_subtensor = denom adv_subtensor = denom
mult_factor = 1 outgrad_factor /= 1.
elif denom.owner.op == tensor.mul: elif denom.owner.op == tensor.mul:
# Try to find the AdvancedSubtensor node mentioned above # Try to find the AdvancedSubtensor node mentioned above,
# For now, we support only the case where the other inputs # and a scalar that is equal to the output gradient
# of the "mul" node are of integer type, so we are sure it
# does not affect the gradient computation.
for i, input in enumerate(denom.owner.inputs): for i, input in enumerate(denom.owner.inputs):
if input.owner and isinstance(input.owner.op, tensor.AdvancedSubtensor): if input.owner and isinstance(input.owner.op, tensor.AdvancedSubtensor):
adv_subtensor = input
other_inputs = [in_ for (j, in_) in enumerate(denom.owner.inputs) if j!=i] other_inputs = [in_ for (j, in_) in enumerate(denom.owner.inputs) if j!=i]
if len(other_inputs) == 1: if len(other_inputs) == 1:
mult_factor = other_inputs[0] rest = other_inputs[0]
else: else:
mult_factor = tensor.mul(*[other_inputs]) rest = tensor.mul(*[other_inputs])
# Check that mult_factor is of integer type # Check that rest is a scalar
if mult_factor.dtype.startswith('int')\ if numpy.all(rest.broadcastable):
or mult_factor.dtype.startswith('uint'): adv_subtensor = input
#OK outgrad_factor /= rest
break break
else:
# That subtensor was not right
adv_subtensor = None
else: else:
return return
...@@ -1101,6 +1123,8 @@ def local_advanced_indexing_crossentropy_onehot_grad(node): ...@@ -1101,6 +1123,8 @@ def local_advanced_indexing_crossentropy_onehot_grad(node):
if not (maybe_sm is sm and maybe_rows is rows and maybe_labels is labels): if not (maybe_sm is sm and maybe_rows is rows and maybe_labels is labels):
return return
#else: OK #else: OK
else:
return
else: else:
return return
...@@ -1147,7 +1171,7 @@ def local_advanced_indexing_crossentropy_onehot_grad(node): ...@@ -1147,7 +1171,7 @@ def local_advanced_indexing_crossentropy_onehot_grad(node):
if incr.owner and incr.owner.op == tensor.fill: if incr.owner and incr.owner.op == tensor.fill:
model, value = incr.owner.inputs model, value = incr.owner.inputs
adv_subtensor = None adv_subtensor = None
mult_factor = 1 outgrad_factor = None
if model.owner and isinstance(model.owner.op, tensor.AdvancedSubtensor): if model.owner and isinstance(model.owner.op, tensor.AdvancedSubtensor):
adv_subtensor = model adv_subtensor = model
else: else:
...@@ -1169,17 +1193,16 @@ def local_advanced_indexing_crossentropy_onehot_grad(node): ...@@ -1169,17 +1193,16 @@ def local_advanced_indexing_crossentropy_onehot_grad(node):
if not (maybe_log_sm is log_sm and maybe_rows is rows and maybe_labels is labels): if not (maybe_log_sm is log_sm and maybe_rows is rows and maybe_labels is labels):
return return
#else: OK #else: OK
else:
return
# In the base case, value is the constant '-1' # In the base case, value is the constant '-1'
if hasattr(value, 'data') and numpy.all(value.data == -1): if hasattr(value, 'data') and numpy.all(value.data == -1):
mult_factor = 1 outgrad_factor = 1.
# In the case of -1/denom, if denom is of integer type # Otherwise, it should be a scalar, and the output gradient
elif value.owner and value.owner.op == tensor.true_div: # would be -value
val_num, val_denom = value.owner.inputs elif numpy.all(value.broadcastable):
if hasattr(val_num, 'data') and numpy.all(val_num.data == -1): outgrad_factor = -value
if val_denom.dtype.startswith('int')\
or val_denom.dtype.startswith('uint'):
mult_factor = val_denom
else: else:
return return
...@@ -1204,11 +1227,10 @@ def local_advanced_indexing_crossentropy_onehot_grad(node): ...@@ -1204,11 +1227,10 @@ def local_advanced_indexing_crossentropy_onehot_grad(node):
# Dimension check before substitution # Dimension check before substitution
if labels.ndim == 1 and x_var.ndim == 2: if labels.ndim == 1 and x_var.ndim == 2:
if mult_factor is not None: if outgrad_factor is not None:
out_grad = tensor.fill(x_var[:,0], 1./mult_factor) out_grad = tensor.fill(x_var[:,0], outgrad_factor)
return [crossentropy_softmax_1hot_with_bias_dx(out_grad, sm, labels)] return [crossentropy_softmax_1hot_with_bias_dx(out_grad, sm, labels)]
else: else:
print 'mult_factor is None?'
return return
else: else:
return return
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论