提交 673a30b3 authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Move logic of y_from_indexing to _check_rows_is_arange_len_labels

上级 2668282c
......@@ -895,32 +895,30 @@ def local_argmax_pushdown(node):
return tensor._max_and_argmax(pre_x+tensor.DimShuffle(pre_bias.broadcastable,
('x',0))(pre_bias), axis)
# Utility function used by the two next optimizations
def _check_rows_is_arange_len_labels(rows, labels):
'''Check that 'rows' is the same node as T.arange(labels.shape[0])'''
if rows.owner and isinstance(rows.owner.op, tensor.ARange):
start, stop, step = rows.owner.inputs
#print "SSS", start, stop, step
if getattr(start, 'data', None) != 0: #constants will have data
return False
if getattr(step, 'data', None) != 1: # constant step will have data
return False
if stop.owner and isinstance(stop.owner.op, tensor.Subtensor):
#print "GOT SUBTENSOR"
shape_subtensor = stop.owner
if shape_subtensor.op.idx_list == [0]:
shape_var, = shape_subtensor.inputs
#print "GOT SHAPE VAR", shape_var
if shape_var.owner and shape_var.owner.op == tensor._shape:
return shape_var.owner.inputs[0] is labels
@opt.register_specialize
@gof.local_optimizer([])
def local_advanced_indexing_crossentropy_onehot(node):
def y_from_indexing(rows):
# we're looking for something of the form
# T.arange(y.shape[0])
# in that case return 'y' else None
if rows.owner and isinstance(rows.owner.op, tensor.ARange):
start, stop, step = rows.owner.inputs
#print "SSS", start, stop, step
if getattr(start, 'data', None) != 0: #constants will have data
raise ValueError()
if getattr(step, 'data', None) != 1: # constant step will have data
raise ValueError()
if stop.owner and isinstance(stop.owner.op, tensor.Subtensor):
#print "GOT SUBTENSOR"
shape_subtensor = stop.owner
if shape_subtensor.op.idx_list == [0]:
shape_var, = shape_subtensor.inputs
#print "GOT SHAPE VAR", shape_var
if shape_var.owner and shape_var.owner.op == tensor._shape:
return shape_var.owner.inputs[0]
log = None
sm = None
# First case: log(softmax(x))[rows, labels]
......@@ -930,7 +928,7 @@ def local_advanced_indexing_crossentropy_onehot(node):
except:
pass
if log and log.owner and log.owner.op == tensor.log:
sm = log.owner.inputs[0]
sm = log.owner.inputs[0]
# Second case: log(softmax(x)[rows, labels])
if node.op == tensor.log:
......@@ -949,14 +947,12 @@ def local_advanced_indexing_crossentropy_onehot(node):
else:
x_var = sm.owner.inputs[0]
b_var = tensor.zeros_like(x_var[0])
# Check that rows == arange(labels.shape[0])
try:
y_var = y_from_indexing(rows)
except ValueError:
return
if y_var is labels and labels.ndim == 1 and x_var.ndim == 2:
return [-crossentropy_softmax_argmax_1hot_with_bias(x_var, b_var, y_var)[0]]
# Check that rows == arange(labels.shape[0])
if _check_rows_is_arange_len_labels(rows, labels):
if labels.ndim == 1 and x_var.ndim == 2:
return [-crossentropy_softmax_argmax_1hot_with_bias(x_var, b_var, labels)[0]]
def binary_crossentropy(output, target):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论