提交 f15fa46c authored 作者: James Bergstra's avatar James Bergstra

FIX: CAReduce casts gradient to input dtype

上级 76345743
...@@ -1594,8 +1594,9 @@ class Sum(CAReduceDtype): ...@@ -1594,8 +1594,9 @@ class Sum(CAReduceDtype):
else: else:
new_dims.append(i) new_dims.append(i)
i += 1 i += 1
return Elemwise(scalar.second)( ds_op = DimShuffle(gz.type.broadcastable, new_dims)
x, DimShuffle(gz.type.broadcastable, new_dims)(gz)), gx = Elemwise(scalar.second)(x, ds_op(gz).astype(x.dtype))
return [gx]
def R_op(self, inputs, eval_points): def R_op(self, inputs, eval_points):
# There is just one element in inputs and eval_points, the axis are # There is just one element in inputs and eval_points, the axis are
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论