提交 757d0483 authored 作者: James Bergstra's avatar James Bergstra

tensor basic - replaced complex Ops with true Elemwise ones.

上级 437d05c4
......@@ -1629,88 +1629,25 @@ def sinh(a):
def tanh(a):
"""hyperbolic tangent of a"""
class Real(Op):
"""Extract the real elements of a complex ndarray"""
view_map = {0:[0]}
def __eq__(self, other):
return type(self) == type(other)
def __hash__(self):
return hash(type(self))
def make_node(self, x):
_x = as_tensor(x)
y_dtype = _x.type.dtype
if y_dtype == 'complex64':
y_dtype = 'float32'
if y_dtype == 'complex128':
y_dtype = 'float64'
_y = Tensor(y_dtype, _x.type.broadcastable)()
return Apply(self, [_x], [_y])
def perform(self, node, (x,), (y,)):
if str(x.dtype).startswith('complex'):
y[0] = x.real
else:
y[0] = x
def grad(self, inputs, (g_y,)):
#TODO: waiting on a Complex(real=, imag=) op that can merge
#things back into a complex tensor
raise NotImplementedError()
_real = Real()
@constructor
def real(x):
"""Return the real part of real or complex-valued `x`
For real-valued `x`, `x` itself is returned.
"""
_x = as_tensor_variable(x)
if _x.type.dtype.startswith('complex'):
return _real(x)
else:
return _x
@_scal_elemwise
def real(z):
"""Return real component of complex-valued tensor `z`"""
class Imag(Op):
"""Extract the imaginary elements of a complex ndarray"""
view_map = {0:[0]}
def __eq__(self, other):
return type(self) == type(other)
def __hash__(self):
return hash(type(self))
def make_node(self, x):
_x = as_tensor_variable(x)
if not _x.type.dtype.startswith('complex'):
raise TypeError('Imag(x) requires complex x', x)
if _x.type.dtype == 'complex64': y_dtype = 'float32'
elif _x.type.dtype == 'complex128': y_dtype = 'float64'
else:
raise NotImplementedError('what is this?', y_dtype)
_y = Tensor(y_dtype, _x.type.broadcastable)()
return Apply(self, [_x], [_y])
def perform(self, node, (x,), (y,)):
if str(x.dtype).startswith('complex'):
y[0] = x.imag
else:
y[0] = x * 0
def grad(self, inputs, (g_y,)):
# TODO: waiting on a complex(real=, imag=) op that can merge
# things back into a complex tensor
raise NotImplementedError()
_imag = Imag()
@constructor
def imag(x):
"""Return the imaginary part of real or complex-valued `x`
@_scal_elemwise
def imag(z):
"""Return imaginary component of complex-valued tensor `z`"""
For real-valued 'x' this returns `zeros_like(x)`.
"""
_x = as_tensor_variable(x)
if _x.type.dtype.startswith('complex'):
return _imag(x)
else:
return zeros_like(x)
@_scal_elemwise
def angle(z):
"""Return polar-coordinate angle of complex-valued tensor `z`"""
@constructor
def angle(x):
"""Return the angular component of complex-valued `x`"""
raise NotImplementedError()
@_scal_elemwise
def complex(real, imag):
"""Return complex-valued tensor with `real` and `imag` components"""
@_scal_elemwise
def complex_from_polar(abs, angle):
"""Return complex-valued tensor from polar coordinate specification"""
##########################
# Misc
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论