提交 d66fda21 authored 作者: nouiz's avatar nouiz

Merge pull request #826 from bouchnic/complex

Complex
...@@ -2425,9 +2425,6 @@ complex = Complex(name='complex') ...@@ -2425,9 +2425,6 @@ complex = Complex(name='complex')
class Conj(UnaryScalarOp): class Conj(UnaryScalarOp):
def impl(self, x): def impl(self, x):
return numpy.conj(x) return numpy.conj(x)
def grad(self, (x, ), (gz, )):
return [conj(gz)]
conj = Conj(same_out, name='conj') conj = Conj(same_out, name='conj')
...@@ -2447,10 +2444,9 @@ class ComplexFromPolar(BinaryScalarOp): ...@@ -2447,10 +2444,9 @@ class ComplexFromPolar(BinaryScalarOp):
return numpy.complex128(numpy.complex(x, y)) return numpy.complex128(numpy.complex(x, y))
def grad(self, (r, theta), (gz,)): def grad(self, (r, theta), (gz,)):
gr = cos(theta) * real(gz) + sin(theta) * imag(gz) gr = gz * complex_from_polar(1, theta)
gtheta = -real(gz) * r * sin(theta) + imag(gz) * r * cos(theta) gtheta = gz * complex_from_polar(r, -theta)
return [cast(gr, r.type.dtype), return [gr, gtheta]
cast(gtheta, theta.type.dtype)]
complex_from_polar = ComplexFromPolar(name='complex_from_polar') complex_from_polar = ComplexFromPolar(name='complex_from_polar')
......
...@@ -75,6 +75,30 @@ class Erfc(UnaryScalarOp): ...@@ -75,6 +75,30 @@ class Erfc(UnaryScalarOp):
erfc = Erfc(upgrade_to_float_no_complex, name='erfc') erfc = Erfc(upgrade_to_float_no_complex, name='erfc')
class Gamma(UnaryScalarOp):
@staticmethod
def st_impl(x):
return scipy.special.gamma(x)
def impl(self, x):
return Gamma.st_impl(x)
def grad(self, (x, ), (gz, )):
return gz * gamma(x) * psi(x),
def c_code(self, node, name, (x, ), (z, ), sub):
if node.inputs[0].type in float_types:
return """%(z)s = tgamma(%(x)s);""" % locals()
raise NotImplementedError('only floating point is implemented')
def __eq__(self, other):
return type(self) == type(other)
def __hash__(self):
return hash(type(self))
gamma = Gamma(upgrade_to_float, name='gamma')
class GammaLn(UnaryScalarOp): class GammaLn(UnaryScalarOp):
""" """
Log gamma function. Log gamma function.
......
...@@ -2768,6 +2768,11 @@ def erfc(a): ...@@ -2768,6 +2768,11 @@ def erfc(a):
"""complementary error function""" """complementary error function"""
@_scal_elemwise
def gamma(a):
"""gamma function"""
@_scal_elemwise @_scal_elemwise
def gammaln(a): def gammaln(a):
"""log gamma function""" """log gamma function"""
...@@ -2798,10 +2803,14 @@ def complex(real, imag): ...@@ -2798,10 +2803,14 @@ def complex(real, imag):
"""Return complex-valued tensor with `real` and `imag` components""" """Return complex-valued tensor with `real` and `imag` components"""
@_scal_elemwise_with_nfunc('conj', 1, -1)
def conj(z):
"""Return the complex conjugate of `z`."""
@_scal_elemwise @_scal_elemwise
def complex_from_polar(abs, angle): def complex_from_polar(abs, angle):
"""Return complex-valued tensor from polar coordinate specification""" """Return complex-valued tensor from polar coordinate specification."""
########################## ##########################
# Misc # Misc
......
...@@ -219,6 +219,10 @@ def erf_inplace(a): ...@@ -219,6 +219,10 @@ def erf_inplace(a):
def erfc_inplace(a): def erfc_inplace(a):
"""complementary error function""" """complementary error function"""
@_scal_inplace
def gamma_inplace(a):
"""gamma function"""
@_scal_inplace @_scal_inplace
def gammaln_inplace(a): def gammaln_inplace(a):
"""log gamma function""" """log gamma function"""
...@@ -271,6 +275,10 @@ def mod_inplace(a, b): ...@@ -271,6 +275,10 @@ def mod_inplace(a, b):
def pow_inplace(a, b): def pow_inplace(a, b):
"""elementwise power (inplace on `a`)""" """elementwise power (inplace on `a`)"""
@_scal_inplace
def conj_inplace(a):
"""elementwise conjugate (inplace on `a`)"""
pprint.assign(add_inplace, printing.OperatorPrinter('+=', -2, 'either')) pprint.assign(add_inplace, printing.OperatorPrinter('+=', -2, 'either'))
pprint.assign(mul_inplace, printing.OperatorPrinter('*=', -1, 'either')) pprint.assign(mul_inplace, printing.OperatorPrinter('*=', -1, 'either'))
pprint.assign(sub_inplace, printing.OperatorPrinter('-=', -2, 'left')) pprint.assign(sub_inplace, printing.OperatorPrinter('-=', -2, 'left'))
......
...@@ -1350,6 +1350,29 @@ _good_broadcast_unary_gammaln = dict( ...@@ -1350,6 +1350,29 @@ _good_broadcast_unary_gammaln = dict(
_grad_broadcast_unary_gammaln = dict( _grad_broadcast_unary_gammaln = dict(
normal=(rand_ranged(1e-8, 10, (2, 3)),),) normal=(rand_ranged(1e-8, 10, (2, 3)),),)
if theano.config.floatX == 'float32':
gamma_eps = 2e-4
else:
gamma_eps = 2e-10
GammaTester = makeBroadcastTester(
op=tensor.gamma,
expected=scipy.special.gamma,
good=_good_broadcast_unary_gammaln,
grad=_grad_broadcast_unary_gammaln,
mode=mode_no_scipy,
eps=gamma_eps,
skip=skip_scipy)
GammaInplaceTester = makeBroadcastTester(
op=inplace.gamma_inplace,
expected=scipy.special.gamma,
good=_good_broadcast_unary_gammaln,
grad=_grad_broadcast_unary_gammaln,
mode=mode_no_scipy,
eps=gamma_eps,
inplace=True,
skip=skip_scipy)
GammaLnTester = makeBroadcastTester( GammaLnTester = makeBroadcastTester(
op=tensor.gammaln, op=tensor.gammaln,
expected=expected_gammaln, expected=expected_gammaln,
...@@ -1402,6 +1425,37 @@ OnesLikeTester = makeBroadcastTester( ...@@ -1402,6 +1425,37 @@ OnesLikeTester = makeBroadcastTester(
grad=_grad_broadcast_unary_normal, grad=_grad_broadcast_unary_normal,
name='OnesLike') name='OnesLike')
# Complex operations
_good_complex_from_polar = dict(
same_shapes=(abs(rand(2, 3)), rand(2, 3)),
not_same_dimensions=(abs(rand(2, 2)), rand(2)),
scalar=(abs(rand(2, 3)), rand(1, 1)),
row=(abs(rand(2, 3)), rand(1, 3)),
column=(abs(rand(2, 3)), rand(2, 1)),
integers=(abs(randint(2, 3)), randint(2, 3)),
empty=(numpy.asarray([]), numpy.asarray([1])),)
_grad_complex_from_polar = dict(
same_shapes=(abs(rand(2, 3)), rand(2, 3)),
scalar=(abs(rand(2, 3)), rand(1, 1)),
row=(abs(rand(2, 3)), rand(1, 3)),
column=(abs(rand(2, 3)), rand(2, 1)))
ComplexFromPolarTester = makeBroadcastTester(
op=tensor.complex_from_polar,
expected=lambda r, theta: r * numpy.cos(theta) + 1j * r * numpy.sin(theta),
good=_good_complex_from_polar)
ConjTester = makeBroadcastTester(
op=tensor.conj,
expected=numpy.conj,
good=_good_broadcast_unary_normal)
ConjInplaceTester = makeBroadcastTester(
op=tensor.conj,
expected=numpy.conj,
good=_good_broadcast_unary_normal,
inplace=True)
DotTester = makeTester(name = 'DotTester', DotTester = makeTester(name = 'DotTester',
op = dot, op = dot,
expected = lambda x, y: numpy.dot(x, y), expected = lambda x, y: numpy.dot(x, y),
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论