提交 c60c48f6 authored 作者: Frederic's avatar Frederic

Revert the grad of Mul. The new version fixed a probably not happening case, but brokens.

I added assert to have an error for the case not implemented.
上级 36ea2465
...@@ -1170,18 +1170,33 @@ class Mul(ScalarOp): ...@@ -1170,18 +1170,33 @@ class Mul(ScalarOp):
def grad(self, inputs, (gz, )): def grad(self, inputs, (gz, )):
retval = [] retval = []
input_type = theano.tensor.as_tensor_variable(inputs).type
if input_type in discrete_types: # The following 3 lines verify that gz is complex when the
retval = None # output is complex. The rest of this function make this supposition.
elif input_type in complex_types or gz.type in complex_types: output_type = self.output_types([i.type for i in inputs])[0]
for input in inputs: if output_type in complex_types:
retval += [mul(*([gz] + assert gz.type in complex_types
utils.difference(inputs, [input])))]
else: for input in inputs:
for input in inputs: if input.type in continuous_types:
retval += [cast(mul(*([gz] + if gz.type in complex_types:
utils.difference(inputs, [input]))), # zr+zi = (xr + xi)(yr + yi)
input_type.dtype)] # zr+zi = (xr*yr - xi*yi) + (xr yi + xi yr )
otherprod = mul(*(utils.difference(inputs, [input])))
yr = real(otherprod)
yi = imag(otherprod)
if input.type in complex_types:
retval += [complex(yr * real(gz) + yi * imag(gz),
yr * imag(gz) - yi * real(gz))]
else:
retval += [cast(yr * real(gz) + yi * imag(gz),
input.type.dtype)]
else:
retval += [cast(mul(*([gz] + utils.difference(inputs,
[input]))),
input.type.dtype)]
else:
retval += [None]
return retval return retval
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论