提交 f6911c94 authored 作者: James Bergstra's avatar James Bergstra

Added an optimization to remove the advanced-indexing method of expressing cross-entropy
上级 4e603bfc
......@@ -893,6 +893,68 @@ def local_argmax_pushdown(node):
return tensor._max_and_argmax(pre_x+tensor.DimShuffle(pre_bias.broadcastable,
('x',0))(pre_bias), axis)
@opt.register_specialize
@gof.local_optimizer([])
def local_advanced_indexing_crossentropy_onehot(node):
    """Rewrite an advanced-indexing cross-entropy into the fused 1-hot op.

    Two graph shapes are recognised:

    * ``log(softmax(x))[rows, labels]``
    * ``log(softmax(x)[rows, labels])``

    When ``rows`` is ``arange(labels.shape[0])``, ``labels`` has one
    dimension and ``x`` has two, the node is replaced by
    ``-crossentropy_softmax_argmax_1hot_with_bias(x, b, labels)[0]``.

    :param node: the Apply node currently inspected by the optimizer.
    :returns: a one-element list holding the replacement variable, or
        ``None`` when the node does not match the pattern.
    """

    def y_from_indexing(rows):
        # We are looking for something of the form
        #     T.arange(y.shape[0])
        # In that case return 'y', else None.  A non-constant or
        # unexpected start/step raises ValueError, which the caller
        # treats as "no match".
        if not (rows.owner and isinstance(rows.owner.op, tensor.ARange)):
            return None
        start, stop, step = rows.owner.inputs
        if getattr(start, 'data', None) != 0:  # constants will have data
            raise ValueError()
        if getattr(step, 'data', None) != 1:  # constant step will have data
            raise ValueError()
        if not (stop.owner and isinstance(stop.owner.op, tensor.Subtensor)):
            return None
        shape_subtensor = stop.owner
        if not (shape_subtensor.op.idx_list == [0]):
            return None
        shape_var, = shape_subtensor.inputs
        if shape_var.owner and shape_var.owner.op == tensor._shape:
            return shape_var.owner.inputs[0]
        return None

    log = None
    sm = None
    rows = None
    labels = None

    # First case: log(softmax(x))[rows, labels]
    if isinstance(node.op, tensor.AdvancedSubtensor):
        try:
            log, rows, labels = node.inputs
        except ValueError:
            # The indexing does not have exactly (tensor, rows, labels)
            # inputs, so this is not the pattern we are after.
            pass
        # Identity test rather than truthiness: `log` is a symbolic variable.
        if log is not None and log.owner and log.owner.op == tensor.log:
            sm = log.owner.inputs[0]

    # Second case: log(softmax(x)[rows, labels])
    if node.op == tensor.log:
        pre_log = node.inputs[0].owner
        if pre_log and isinstance(pre_log.op, tensor.AdvancedSubtensor):
            try:
                sm, rows, labels = pre_log.inputs
            except ValueError:
                # Wrong number of indexing inputs: leave `sm` as None.
                pass

    if sm is None or not sm.owner:
        return None
    if sm.owner.op not in (softmax, softmax_with_bias):
        return None

    # Local optimizers are handed an Apply node (as `node` is here), so pass
    # the node that produced `sm` rather than the variable itself.
    # NOTE(review): confirm against the definition of local_softmax_with_bias.
    sm_w_bias = local_softmax_with_bias.transform(sm.owner)
    if sm_w_bias:
        assert sm_w_bias[0].owner.op == softmax_with_bias
        x_var, b_var = sm_w_bias[0].owner.inputs
    elif sm.owner.op == softmax_with_bias:
        # Already a biased softmax: keep its bias instead of replacing it
        # with zeros.
        x_var, b_var = sm.owner.inputs
    else:
        x_var = sm.owner.inputs[0]
        b_var = tensor.zeros_like(x_var[0])

    # Check that rows == arange(labels.shape[0])
    try:
        y_var = y_from_indexing(rows)
    except ValueError:
        return None

    if y_var is labels and labels.ndim == 1 and x_var.ndim == 2:
        return [-crossentropy_softmax_argmax_1hot_with_bias(
            x_var, b_var, y_var)[0]]
    return None
def binary_crossentropy(output, target):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论