提交 d73c19ef authored 作者: Frederic Bastien's avatar Frederic Bastien

added infer_shape fct()

上级 aca6b1da
......@@ -73,6 +73,9 @@ class SoftmaxWithBias(gof.Op):
db = tensor.sum(dx, axis = 0)
return dx, db
def infer_shape(self, node, shape):
return [shape[0]]
def c_headers(self):
return ['<iostream>','<cmath>']
......@@ -231,6 +234,9 @@ class SoftmaxGrad(gof.Op):
def grad(self, *args):
raise NotImplementedError()
def infer_shape(self, node, shape):
return [shape[1]]
def c_code_cache_version(self):
return (3,)
def c_code(self, node, name, (dy, sm), (dx,), sub):
......@@ -330,6 +336,10 @@ class Softmax(gof.Op):
def grad(self, (x,), (g_sm,)):
sm = softmax(x)
return [softmax_grad(g_sm, sm)]
def infer_shape(self, node, shape):
return shape
softmax = Softmax()
@opt.register_specialize
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论