提交 01924edd authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Add infer_shape to CrossentropySoftmaxArgmax1HotWithBias

上级 3964a349
......@@ -502,10 +502,17 @@ class CrossentropySoftmaxArgmax1HotWithBias(gof.Op):
# store the nll
nll[i] = -row[y_idx[i]] + m + numpy.log(sum_j)
output_storage[0][0] = nll
output_storage[1][0] = sm
output_storage[2][0] = am
def infer_shape(self, node, (x_shp, b_shp, idx_shp)):
nll_shp = (x_shp[0],)
sm_shp = x_shp
am_shp = idx_shp
return [nll_shp, sm_shp, am_shp]
def grad(self, (x, b, y_idx), (g_nll, g_sm, g_am)):
if g_am is not None:
raise NotImplementedError()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论