提交 a8e7bab7 authored 作者: Nicolas Bouchard's avatar Nicolas Bouchard

This fix gh-774.

上级 555af254
......@@ -2353,9 +2353,6 @@ complex = Complex(name='complex')
class Conj(UnaryScalarOp):
def impl(self, x):
return numpy.conj(x)
def grad(self, (x, ), (gz, )):
return [conj(gz)]
conj = Conj(same_out, name='conj')
......@@ -2375,10 +2372,9 @@ class ComplexFromPolar(BinaryScalarOp):
return numpy.complex128(numpy.complex(x, y))
def grad(self, (r, theta), (gz,)):
gr = cos(theta) * real(gz) + sin(theta) * imag(gz)
gtheta = -real(gz) * r * sin(theta) + imag(gz) * r * cos(theta)
return [cast(gr, r.type.dtype),
cast(gtheta, theta.type.dtype)]
gr = gz * complex_from_polar(1, theta)
gtheta = gz * complex_from_polar(r, -theta)
return [gr, gtheta]
complex_from_polar = ComplexFromPolar(name='complex_from_polar')
......
......@@ -75,6 +75,25 @@ class Erfc(UnaryScalarOp):
erfc = Erfc(upgrade_to_float_no_complex, name='erfc')
class Gamma(UnaryScalarOp):
@staticmethod
def st_impl(x):
return scipy.special.gamma(x)
def impl(self, x):
return Gamma.st_impl(x)
def grad(self, (x, ), (gz, )):
return gz * gamma(x) * psi(x),
def __eq__(self, other):
return type(self) == type(other)
def __hash__(self):
return hash(type(self))
gamma = Gamma(upgrade_to_float, name='gamma')
class GammaLn(UnaryScalarOp):
"""
Log gamma function.
......
......@@ -2741,6 +2741,11 @@ def erfc(a):
"""complementary error function"""
@_scal_elemwise
def gamma(a):
"""gamma function"""
@_scal_elemwise
def gammaln(a):
"""log gamma function"""
......@@ -2771,10 +2776,14 @@ def complex(real, imag):
"""Return complex-valued tensor with `real` and `imag` components"""
@_scal_elemwise_with_nfunc('conj', 1, -1)
def conj(z):
"""Return the complex conjugate of `z`."""
@_scal_elemwise
def complex_from_polar(abs, angle):
"""Return complex-valued tensor from polar coordinate specification"""
"""Return complex-valued tensor from polar coordinate specification."""
##########################
# Misc
......
......@@ -203,6 +203,10 @@ def erf_inplace(a):
def erfc_inplace(a):
"""complementary error function"""
@_scal_inplace
def gamma_inplace(a):
"""gamma function"""
@_scal_inplace
def gammaln_inplace(a):
"""log gamma function"""
......
......@@ -1291,6 +1291,22 @@ _good_broadcast_unary_gammaln = dict(
_grad_broadcast_unary_gammaln = dict(
normal=(rand_ranged(1e-8, 10, (2, 3)),),)
GammaTester = makeBroadcastTester(
op=tensor.gamma,
expected=scipy.special.gamma,
good=_good_broadcast_unary_gammaln,
grad=_grad_broadcast_unary_gammaln,
mode=mode_no_scipy,
skip=skip_scipy)
GammaInplaceTester = makeBroadcastTester(
op=inplace.gamma_inplace,
expected=scipy.special.gamma,
good=_good_broadcast_unary_gammaln,
grad=_grad_broadcast_unary_gammaln,
mode=mode_no_scipy,
inplace=True,
skip=skip_scipy)
GammaLnTester = makeBroadcastTester(
op=tensor.gammaln,
expected=expected_gammaln,
......@@ -1343,6 +1359,31 @@ OnesLikeTester = makeBroadcastTester(
grad=_grad_broadcast_unary_normal,
name='OnesLike')
# Complex operations
_good_complex_from_polar = dict(
same_shapes=(abs(rand(2, 3)), rand(2, 3)),
not_same_dimensions=(abs(rand(2, 2)), rand(2)),
scalar=(abs(rand(2, 3)), rand(1, 1)),
row=(abs(rand(2, 3)), rand(1, 3)),
column=(abs(rand(2, 3)), rand(2, 1)),
integers=(abs(randint(2, 3)), randint(2, 3)),
empty=(numpy.asarray([]), numpy.asarray([1])),)
_grad_complex_from_polar = dict(
same_shapes=(abs(rand(2, 3)), rand(2, 3)),
scalar=(abs(rand(2, 3)), rand(1, 1)),
row=(abs(rand(2, 3)), rand(1, 3)),
column=(abs(rand(2, 3)), rand(2, 1)))
ComplexFromPolarTester = makeBroadcastTester(
op=tensor.complex_from_polar,
expected=lambda r, theta: r * numpy.cos(theta) + 1j * r * numpy.sin(theta),
good=_good_complex_from_polar)
ConjTester = makeBroadcastTester(
op=tensor.conj,
expected=numpy.conj,
good=_good_broadcast_unary_normal)
DotTester = makeTester(name = 'DotTester',
op = dot,
expected = lambda x, y: numpy.dot(x, y),
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论