提交 dff4d08c authored 作者: James Bergstra's avatar James Bergstra

Fixed scalar abs to produce real valued output from complex input

上级 4c9ef827
...@@ -625,7 +625,15 @@ class Identity(UnaryScalarOp): ...@@ -625,7 +625,15 @@ class Identity(UnaryScalarOp):
identity = Identity(same_out, name = 'identity') identity = Identity(same_out, name = 'identity')
class Abs(UnaryScalarOp): class Abs(UnaryScalarOp):
#TODO: for complex input, output is some flavour of float def make_node(self, x):
inputs = [as_scalar(input) for input in [x]]
if inputs[0].type == complex64:
outputs = [float32()]
elif inputs[0].type == complex128:
outputs = [float64()]
else:
outputs = [t() for t in self.output_types([input.type for input in inputs])]
return Apply(self, inputs, outputs)
def impl(self, x): def impl(self, x):
return numpy.abs(x) return numpy.abs(x)
def grad(self, (x, ), (gz, )): def grad(self, (x, ), (gz, )):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论