提交 127f75e3 authored 作者: James Bergstra's avatar James Bergstra

scalar - Corrected gradient of Mul in complex case.

上级 9c293af2
...@@ -849,6 +849,13 @@ class Mul(ScalarOp): ...@@ -849,6 +849,13 @@ class Mul(ScalarOp):
retval = [] retval = []
for input in inputs: for input in inputs:
if input.type in grad_types: if input.type in grad_types:
if input.type in complex_types:
# does casting from real to complex work?
dz_dinput = cast(mul(*(utils.difference(inputs, [input]))), input.type.dtype)
x = real(dz_dinput)
y = imag(dz_dinput)
retval += [complex(x*real(gz)+y*imag(gz), x*imag(gz)-y*real(gz))]
else:
retval += [cast(mul(*([gz] + utils.difference(inputs, [input]))), input.type.dtype)] retval += [cast(mul(*([gz] + utils.difference(inputs, [input]))), input.type.dtype)]
else: else:
retval += [None] retval += [None]
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论