Commit de4b6c60 authored by Pascal Lamblin

Allow optimization to work on xent(softmax(vector))

Parent 6e61aa5e
@@ -1002,6 +1002,21 @@ def local_advanced_indexing_crossentropy_onehot_grad(node):
     else:
         return
+    # If the arg to softmax is a broadcasted vector, d_sm has the form:
+    #   DimShuffle{x,0}(Sum{0}(...))
+    # we consider what's inside of the sum instead
+    vector_softmax = False
+    if d_sm.owner and isinstance(d_sm.owner.op, tensor.DimShuffle):
+        ds_op = d_sm.owner.op
+        if ds_op.input_broadcastable == (False,) and ds_op.new_order == ('x', 0):
+            maybe_sum = d_sm.owner.inputs[0]
+            if maybe_sum.owner and isinstance(maybe_sum.owner.op, tensor.Sum):
+                if sm.broadcastable == (True, False)\
+                        and maybe_sum.owner.op.axis == (0,)\
+                        and len(maybe_sum.owner.inputs) == 1:
+                    vector_softmax = True
+                    d_sm = maybe_sum.owner.inputs[0]
+
     # Two cases are supported:
     # 1. AdvancedIncSubtensor(
     #        zeros_like(softmax(x)),
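The hunk above matches the gradient graph Theano builds when the softmax input is a single vector broadcast to a (1, n) row: the gradient with respect to the broadcastable softmax output comes back wrapped in Sum{0} plus DimShuffle{x,0}, and the rewrite strips that wrapper so the existing matrix-case matching can proceed. Below is a minimal sketch of the use case, assuming a working Theano install; the variable names are illustrative, and whether the rewrite actually fires depends on the exact graph that gets built.

    import numpy
    import theano
    import theano.tensor as T

    x = T.vector('x')    # a single example, not a (batch, n) minibatch
    y = T.lvector('y')   # its class label, as a length-1 integer vector

    # Broadcasting the vector to a (1, n) row gives a softmax output with
    # broadcastable pattern (True, False), the case checked above.
    sm = T.nnet.softmax(x.dimshuffle('x', 0))
    nll = -T.log(sm[T.arange(y.shape[0]), y]).sum()
    g = T.grad(nll, x)   # gradient graph goes through softmax_grad(d_sm, sm)

    f = theano.function([x, y], g)
    print(f(numpy.arange(3, dtype=theano.config.floatX), [1]))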
@@ -1140,9 +1155,11 @@ def local_advanced_indexing_crossentropy_onehot_grad(node):
         if not _is_const(z, 0):
             return
         if z.type not in (dmatrix, fmatrix):
-            return
+            if not (vector_softmax and z.broadcastable == (True, False)):
+                return
         # here we know that we are incrementing a matrix of zeros
-        # Since out_grad and sm are the inputs of softmax_grad,
+        # (or a broadcasted vector).
+        # Since d_sm and sm are the inputs of softmax_grad,
         # if the graph is valid, they have the same shape, so we
         # also know that z has the right shape.
...
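In the vector case, z is no longer a plain dmatrix/fmatrix but a row with broadcastable pattern (True, False), so the type check is relaxed for exactly that case. A small illustration, purely for reference, of the two patterns being distinguished:

    import theano.tensor as T

    m = T.matrix('m')                          # ordinary minibatch case
    row = T.vector('v').dimshuffle('x', 0)     # broadcasted single vector

    print(m.broadcastable)    # (False, False): passes the old type check
    print(row.broadcastable)  # (True, False): accepted only when
                              # vector_softmax is True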