提交 0e72be06 authored 作者: Eric Larsen's avatar Eric Larsen 提交者: Frederic

prise en compte output_grad ou inputs de fonction 'mult' complexes

上级 fb75e585
......@@ -1170,27 +1170,21 @@ class Mul(ScalarOp):
def grad(self, inputs, (gz, )):
retval = []
input_type = theano.tensor.as_tensor_variable(inputs).type
if input_type in discrete_types:
retval = None
elif input_type in complex_types or gz.type in complex_types:
for input in inputs:
if input.type in continuous_types:
if gz.type in complex_types:
# zr+zi = (xr + xi)(yr + yi)
# zr+zi = (xr*yr - xi*yi) + (xr yi + xi yr )
otherprod = mul(*(utils.difference(inputs, [input])))
yr = real(otherprod)
yi = imag(otherprod)
if input.type in complex_types:
retval += [complex(yr * real(gz) + yi * imag(gz),
yr * imag(gz) - yi * real(gz))]
else:
retval += [cast(yr * real(gz) + yi * imag(gz),
input.type.dtype)]
else:
retval += [cast(mul(*([gz] + utils.difference(inputs,
[input]))),
input.type.dtype)]
retval += [mul(*([gz] +
utils.difference(inputs, [input])))]
else:
retval += [None]
for input in inputs:
retval += [cast(mul(*([gz] +
utils.difference(inputs, [input]))),
input_type.dtype)]
return retval
mul = Mul(upcast_out, name='mul')
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论