提交 dff4d08c authored 作者: James Bergstra's avatar James Bergstra

Fixed scalar abs to produce real valued output from complex input

上级 4c9ef827
......@@ -625,7 +625,15 @@ class Identity(UnaryScalarOp):
identity = Identity(same_out, name = 'identity')
class Abs(UnaryScalarOp):
#TODO: for complex input, output is some flavour of float
def make_node(self, x):
inputs = [as_scalar(input) for input in [x]]
if inputs[0].type == complex64:
outputs = [float32()]
elif inputs[0].type == complex128:
outputs = [float64()]
else:
outputs = [t() for t in self.output_types([input.type for input in inputs])]
return Apply(self, inputs, outputs)
def impl(self, x):
return numpy.abs(x)
def grad(self, (x, ), (gz, )):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论