提交 8e116b45 authored 作者: James Bergstra's avatar James Bergstra

Corrected grad of abs() in complex case.

上级 d3145f80
......@@ -1114,11 +1114,9 @@ class Abs(UnaryScalarOp):
return numpy.abs(x)
def grad(self, (x, ), (gz, )):
if x.type in grad_types:
return gz * sgn(x),
return gz * x / abs(x), # formula works for complex and real
else:
return None,
#backport
#return gz * sgn(x) if x.type in grad_types else None,
return None,
def c_code(self, node, name, (x, ), (z, ), sub):
type = node.inputs[0].type
if type in int_types:
......
......@@ -64,3 +64,12 @@ class TestRealImag(unittest.TestCase):
mval = numpy.asarray(rng.randn(2,5))
utt.verify_grad(f, [mval])
def test_abs_grad(self):
def f(m):
c = complex(m[0], m[1])
return .5 * abs(c)
rng = numpy.random.RandomState(9333)
mval = numpy.asarray(rng.randn(2,5))
utt.verify_grad(f, [mval])
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论