提交 0e72be06 authored 作者: Eric Larsen's avatar Eric Larsen 提交者: Frederic

prise en compte output_grad ou inputs de fonction 'mult' complexes

上级 fb75e585
...@@ -1170,27 +1170,21 @@ class Mul(ScalarOp): ...@@ -1170,27 +1170,21 @@ class Mul(ScalarOp):
def grad(self, inputs, (gz, )): def grad(self, inputs, (gz, )):
retval = [] retval = []
for input in inputs: input_type = theano.tensor.as_tensor_variable(inputs).type
if input.type in continuous_types: if input_type in discrete_types:
if gz.type in complex_types: retval = None
# zr+zi = (xr + xi)(yr + yi) elif input_type in complex_types or gz.type in complex_types:
# zr+zi = (xr*yr - xi*yi) + (xr yi + xi yr ) for input in inputs:
otherprod = mul(*(utils.difference(inputs, [input]))) retval += [mul(*([gz] +
yr = real(otherprod) utils.difference(inputs, [input])))]
yi = imag(otherprod) else:
if input.type in complex_types: for input in inputs:
retval += [complex(yr * real(gz) + yi * imag(gz), retval += [cast(mul(*([gz] +
yr * imag(gz) - yi * real(gz))] utils.difference(inputs, [input]))),
else: input_type.dtype)]
retval += [cast(yr * real(gz) + yi * imag(gz),
input.type.dtype)]
else:
retval += [cast(mul(*([gz] + utils.difference(inputs,
[input]))),
input.type.dtype)]
else:
retval += [None]
return retval return retval
mul = Mul(upcast_out, name='mul') mul = Mul(upcast_out, name='mul')
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论