提交 f06506db authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Broaden the search pattern of local_advanced_indexing_crossentropy_onehot_grad,…

Broaden the search pattern of local_advanced_indexing_crossentropy_onehot_grad, to be able to deal with expressions generated by grad(mean(nll))
上级 100d96d0
...@@ -1004,7 +1004,7 @@ def local_advanced_indexing_crossentropy_onehot_grad(node): ...@@ -1004,7 +1004,7 @@ def local_advanced_indexing_crossentropy_onehot_grad(node):
else: else:
return return
# Two cases are supported: # Two base cases are supported:
# 1. AdvancedIncSubtensor( # 1. AdvancedIncSubtensor(
# zeros_like(softmax(x)), # zeros_like(softmax(x)),
# -1. / AdvancedSubtensor(softmax(x), arange(y.shape[0]), y), # -1. / AdvancedSubtensor(softmax(x), arange(y.shape[0]), y),
...@@ -1020,6 +1020,8 @@ def local_advanced_indexing_crossentropy_onehot_grad(node): ...@@ -1020,6 +1020,8 @@ def local_advanced_indexing_crossentropy_onehot_grad(node):
# / softmax(x) # / softmax(x)
# which arises from the gradient of log(softmax(x))[arange(y.shape[0]), y] # which arises from the gradient of log(softmax(x))[arange(y.shape[0]), y]
# #
# TODO: explain variants of case 1.
# TODO: explain other variants of case 2.
# In some cases, in case 2., instead of "-1. like (AdvancedSubtensor...)", # In some cases, in case 2., instead of "-1. like (AdvancedSubtensor...)",
# we can have "-1. like ([-1] * AdvancedSubtensor...)". This case will be # we can have "-1. like ([-1] * AdvancedSubtensor...)". This case will be
# recognized too, but other variants, even with the same shape, might not # recognized too, but other variants, even with the same shape, might not
...@@ -1058,9 +1060,41 @@ def local_advanced_indexing_crossentropy_onehot_grad(node): ...@@ -1058,9 +1060,41 @@ def local_advanced_indexing_crossentropy_onehot_grad(node):
return return
#else: OK #else: OK
if denom.owner and isinstance(denom.owner.op, tensor.AdvancedSubtensor): if not denom.owner:
return
adv_subtensor = None
if isinstance(denom.owner.op, tensor.AdvancedSubtensor):
adv_subtensor = denom
mult_factor = 1
elif denom.owner.op == tensor.mul:
# Try to find the AdvancedSubtensor node mentioned above
# For now, we support only the case where the other inputs
# of the "mul" node are of integer type, so we are sure it
# does not affect the gradient computation.
for i, input in enumerate(denom.owner.inputs):
if input.owner and isinstance(input.owner.op, tensor.AdvancedSubtensor):
adv_subtensor = input
other_inputs = [in_ for (j, in_) in enumerate(denom.owner.inputs) if j!=i]
if len(other_inputs) == 1:
mult_factor = other_inputs[0]
else:
mult_factor = tensor.mul(*[other_inputs])
# Check that mult_factor is of integer type
if mult_factor.dtype.startswith('int')\
or mult_factor.dtype.startswith('uint'):
#OK
break
else:
# That subtensor was not right
adv_subtensor = None
else:
return
if adv_subtensor is not None:
try: try:
maybe_sm, maybe_rows, maybe_labels = denom.owner.inputs maybe_sm, maybe_rows, maybe_labels = adv_subtensor.owner.inputs
except: except:
return return
...@@ -1069,8 +1103,6 @@ def local_advanced_indexing_crossentropy_onehot_grad(node): ...@@ -1069,8 +1103,6 @@ def local_advanced_indexing_crossentropy_onehot_grad(node):
#else: OK #else: OK
else: else:
return return
else:
return
# Check that rows is arange(labels.shape[0]) # Check that rows is arange(labels.shape[0])
if not _check_rows_is_arange_len_labels(rows, labels): if not _check_rows_is_arange_len_labels(rows, labels):
...@@ -1115,6 +1147,7 @@ def local_advanced_indexing_crossentropy_onehot_grad(node): ...@@ -1115,6 +1147,7 @@ def local_advanced_indexing_crossentropy_onehot_grad(node):
if incr.owner and incr.owner.op == tensor.fill: if incr.owner and incr.owner.op == tensor.fill:
model, value = incr.owner.inputs model, value = incr.owner.inputs
adv_subtensor = None adv_subtensor = None
mult_factor = 1
if model.owner and isinstance(model.owner.op, tensor.AdvancedSubtensor): if model.owner and isinstance(model.owner.op, tensor.AdvancedSubtensor):
adv_subtensor = model adv_subtensor = model
else: else:
...@@ -1137,7 +1170,17 @@ def local_advanced_indexing_crossentropy_onehot_grad(node): ...@@ -1137,7 +1170,17 @@ def local_advanced_indexing_crossentropy_onehot_grad(node):
return return
#else: OK #else: OK
if not numpy.all(value.data == -1): # In the base case, value is the constant '-1'
if hasattr(value, 'data') and numpy.all(value.data == -1):
mult_factor = 1
# In the case of -1/denom, if denom is of integer type
elif value.owner and value.owner.op == tensor.true_div:
val_num, val_denom = value.owner.inputs
if hasattr(val_num, 'data') and numpy.all(val_num.data == -1):
if val_denom.dtype.startswith('int')\
or val_denom.dtype.startswith('uint'):
mult_factor = val_denom
else:
return return
else: else:
...@@ -1161,7 +1204,12 @@ def local_advanced_indexing_crossentropy_onehot_grad(node): ...@@ -1161,7 +1204,12 @@ def local_advanced_indexing_crossentropy_onehot_grad(node):
# Dimension check before substitution # Dimension check before substitution
if labels.ndim == 1 and x_var.ndim == 2: if labels.ndim == 1 and x_var.ndim == 2:
return [crossentropy_softmax_1hot_with_bias_dx(tensor.ones_like(sm[:,0]), sm, labels)] if mult_factor is not None:
out_grad = tensor.fill(x_var[:,0], 1./mult_factor)
return [crossentropy_softmax_1hot_with_bias_dx(out_grad, sm, labels)]
else:
print 'mult_factor is None?'
return
else: else:
return return
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论