提交 31c590fa authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Separate softmax(x) in softmax_with_bias(x, b) in another case.

上级 e858937f
......@@ -917,18 +917,15 @@ def _check_rows_is_arange_len_labels(rows, labels):
if rows.owner and isinstance(rows.owner.op, tensor.ARange):
start, stop, step = rows.owner.inputs
#print "SSS", start, stop, step
if getattr(start, 'data', None) != 0: #constants will have data
return False
if getattr(step, 'data', None) != 1: # constant step will have data
return False
if stop.owner and isinstance(stop.owner.op, tensor.Subtensor):
#print "GOT SUBTENSOR"
shape_subtensor = stop.owner
if shape_subtensor.op.idx_list == [0]:
shape_var, = shape_subtensor.inputs
#print "GOT SHAPE VAR", shape_var
if shape_var.owner and shape_var.owner.op == tensor._shape:
return shape_var.owner.inputs[0] is labels
......@@ -997,8 +994,13 @@ def local_advanced_indexing_crossentropy_onehot_grad(node):
except:
return
if sm is not None and sm.owner and sm.owner.op == softmax:
x_var = sm.owner.inputs[0]
if sm is not None and sm.owner and sm.owner.op in (softmax, softmax_with_bias):
sm_w_bias = local_softmax_with_bias.transform(sm.owner)
if sm_w_bias:
assert sm_w_bias[0].owner.op == softmax_with_bias
x_var, b_var = sm_w_bias[0].owner.inputs
else:
x_var = sm.owner.inputs[0]
else:
return
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论