提交 41678c83 authored 作者: James Bergstra's avatar James Bergstra

added Real and Imag Ops to tensor.basic.

上级 fe6a777b
......@@ -1214,6 +1214,83 @@ def sinh(a):
def tanh(a):
"""hyperbolic tangent of a"""
class Real(Op):
"""Extract the real elements of a complex ndarray"""
view_map = {0:[0]}
def __eq__(self, other):
return type(self) == type(other)
def __hash__(self):
return hash(type(self))
def make_node(self, x):
_x = as_tensor(x)
y_dtype = _x.type.dtype
if y_dtype == 'complex64':
y_dtype = 'float32'
if y_dtype == 'complex128':
y_dtype = 'float64'
_y = Tensor(y_dtype, _x.type.broadcastable)()
return Apply(self, [_x], [_y])
def perform(self, node, (x,), (y,)):
if str(x.dtype).startswith('complex'):
y[0] = x.real
else:
y[0] = x
def grad(self, inputs, (g_y,)):
#TODO: waiting on a Complex(real=, imag=) op that can merge
#things back into a complex tensor
raise NotImplementedError()
_real = Real()
@constructor
def real(x):
"""Return the real part of real or complex-valued `x`
For real-valued `x`, `x` itself is returned.
"""
_x = as_tensor_variable(x)
if _x.type.dtype.startswith('complex'):
return _real(x)
else:
return _x
class Imag(Op):
"""Extract the imaginary elements of a complex ndarray"""
view_map = {0:[0]}
def __eq__(self, other):
return type(self) == type(other)
def __hash__(self):
return hash(type(self))
def make_node(self, x):
_x = as_tensor_variable(x)
if not _x.type.dtype.startswith('complex'):
raise TypeError('Imag(x) requires complex x', x)
if _x.type.dtype == 'complex64': y_dtype = 'float32'
elif _x.type.dtype == 'complex128': y_dtype = 'float64'
else:
raise NotImplementedError('what is this?', y_dtype)
_y = Tensor(y_dtype, _x.type.broadcastable)()
return Apply(self, [_x], [_y])
def perform(self, node, (x,), (y,)):
if str(x.dtype).startswith('complex'):
y[0] = x.imag
else:
y[0] = x * 0
def grad(self, inputs, (g_y,)):
# TODO: waiting on a complex(real=, imag=) op that can merge
# things back into a complex tensor
raise NotImplementedError()
_imag = Imag()
@constructor
def imag(x):
"""Return the imaginary part of real or complex-valued `x`
For real-valued 'x' this returns `zeros_like(x)`.
"""
_x = as_tensor_variable(x)
if _x.type.dtype.startswith('complex'):
return _imag(x)
else:
return zeros_like(x)
##########################
# Misc
......
import unittest
import theano
from theano.tensor import *
class TestRealImag(unittest.TestCase):
def test0(self):
x= zvector()
rng = numpy.random.RandomState(23)
xval = numpy.asarray(list(numpy.complex(rng.randn(), rng.randn()) for i in xrange(10)))
assert numpy.all( xval.real == theano.function([x], real(x))(xval))
assert numpy.all( xval.imag == theano.function([x], imag(x))(xval))
def test_on_real_input(self):
x= dvector()
rng = numpy.random.RandomState(23)
xval = rng.randn(10)
assert numpy.all( 0 == theano.function([x], imag(x))(xval))
assert numpy.all( xval == theano.function([x], real(x))(xval))
def test_cast(self):
x= zvector()
self.failUnlessRaises(TypeError, cast, x, 'int32')
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论