提交 62e2bf52 authored 作者: Tim Cooijmans's avatar Tim Cooijmans

BatchedDot: trust that gradients are floats

上级 cc8c9e4c
......@@ -3457,12 +3457,7 @@ class BatchedDot(Op):
if ygrad.broadcastable != y.broadcastable:
ygrad = patternbroadcast(ygrad, y.broadcastable)
rval = xgrad, ygrad
for elem in rval:
assert elem.dtype.find('float') != -1
return rval
return xgrad, ygrad
def R_op(self, inputs, eval_points):
# R_op for a \dot b evaluted at c for a and d for b is
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论