提交 8e116b45 authored 作者: James Bergstra's avatar James Bergstra

Corrected grad of abs() in complex case.

上级 d3145f80
...@@ -1114,11 +1114,9 @@ class Abs(UnaryScalarOp): ...@@ -1114,11 +1114,9 @@ class Abs(UnaryScalarOp):
return numpy.abs(x) return numpy.abs(x)
def grad(self, (x, ), (gz, )): def grad(self, (x, ), (gz, )):
if x.type in grad_types: if x.type in grad_types:
return gz * sgn(x), return gz * x / abs(x), # formula works for complex and real
else: else:
return None, return None,
#backport
#return gz * sgn(x) if x.type in grad_types else None,
def c_code(self, node, name, (x, ), (z, ), sub): def c_code(self, node, name, (x, ), (z, ), sub):
type = node.inputs[0].type type = node.inputs[0].type
if type in int_types: if type in int_types:
......
...@@ -64,3 +64,12 @@ class TestRealImag(unittest.TestCase): ...@@ -64,3 +64,12 @@ class TestRealImag(unittest.TestCase):
mval = numpy.asarray(rng.randn(2,5)) mval = numpy.asarray(rng.randn(2,5))
utt.verify_grad(f, [mval]) utt.verify_grad(f, [mval])
def test_abs_grad(self):
def f(m):
c = complex(m[0], m[1])
return .5 * abs(c)
rng = numpy.random.RandomState(9333)
mval = numpy.asarray(rng.randn(2,5))
utt.verify_grad(f, [mval])
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论