提交 571d3ab3 authored 作者: Olivier Delalleau's avatar Olivier Delalleau

Merged

...@@ -17,7 +17,8 @@ def info(*msg): ...@@ -17,7 +17,8 @@ def info(*msg):
_logger.info('INFO theano.gradient: '+' '.join(msg)) _logger.info('INFO theano.gradient: '+' '.join(msg))
_msg_retType = 'op.grad(...) returned a non-list' _msg_retType = 'op.grad(...) returned a non-list'
_msg_badlen = 'op.grad(...) returned wrong number of gradients' _msg_badlen = ('op.grad(...) returned wrong number of gradients (Op, '
'number of gradients, number of inputs)')
def grad_sources_inputs(sources, graph_inputs, warn_type=True): def grad_sources_inputs(sources, graph_inputs, warn_type=True):
""" """
......
...@@ -688,8 +688,10 @@ class CrossentropySoftmax1HotWithBiasDx (gof.Op): ...@@ -688,8 +688,10 @@ class CrossentropySoftmax1HotWithBiasDx (gof.Op):
dx[i] = dy[i] * sm[i] #vector scale dx[i] = dy[i] * sm[i] #vector scale
dx[i, y_idx[i]] -= dy[i] #scalar decrement dx[i, y_idx[i]] -= dy[i] #scalar decrement
output_storage[0][0] = dx output_storage[0][0] = dx
def grad(self, *args): def grad(self, (dy, sm, y_idx), (g_dx, )):
raise NotImplementedError() # Note: currently we do not care about computing the gradient of dy,
# since we usually should not need it.
return [None, dy * g_dx, None]
def c_code_cache_version(self): def c_code_cache_version(self):
return (2,) return (2,)
def c_code(self, node, name, (dnll, sm, y_idx), (dx,), sub): def c_code(self, node, name, (dnll, sm, y_idx), (dx,), sub):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论