提交 54b4dea2 authored 作者: Ian Goodfellow's avatar Ian Goodfellow

fix to Sum.grad

上级 f6278392
...@@ -1596,6 +1596,12 @@ class Sum(CAReduceDtype): ...@@ -1596,6 +1596,12 @@ class Sum(CAReduceDtype):
def grad(self, inp, grads): def grad(self, inp, grads):
x, = inp x, = inp
out = self(*inp)
if out.dtype.find('int') != -1:
return [x.zeros_like().astype(theano.config.floatX)]
gz, = grads gz, = grads
gz = as_tensor_variable(gz) gz = as_tensor_variable(gz)
axis = self.axis axis = self.axis
......
...@@ -671,6 +671,7 @@ class T_mean_dtype(unittest.TestCase): ...@@ -671,6 +671,7 @@ class T_mean_dtype(unittest.TestCase):
assert x.dtype == dtype, (x, x.dtype, dtype) assert x.dtype == dtype, (x, x.dtype, dtype)
def test_mean_custom_dtype(self): def test_mean_custom_dtype(self):
""" """
Test the ability to provide your own output dtype for a mean. Test the ability to provide your own output dtype for a mean.
""" """
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论