提交 2c962433 authored 作者: Frederic Bastien's avatar Frederic Bastien

Add new scalar and elemwise op erf and erfc.

上级 fa5b4a12
from basic import * from basic import *
from basic_scipy import *
#definition theano.scalar op that have their python implementation taked from scipy
#as scipy is not always available, we put threat them separatly
from theano.scalar.basic import UnaryScalarOp,exp,sqrt,upgrade_to_float,complex_types,float_types
#import theano.tensor.elemwise as elemwise
import numpy
imported_scipy_special = False
try:
import scipy.special
imported_scipy_special = True
except ImportError:
pass
class Erf(UnaryScalarOp):
def impl(self, x):
if imported_scipy_special:
return scipy.special.erf(x)
else:
super(Erf,self).impl(x)
def grad(self, (x, ), (gz, )):
if x.type in complex_types:
raise NotImplementedError()
elif x.type in float_types:
return gz * 2. / numpy.sqrt(numpy.pi) * exp(-x*x),
else:
return None,
def c_code(self, node, name, (x, ), (z, ), sub):
if node.inputs[0].type in complex_types:
raise NotImplementedError('type not supported', type)
return "%(z)s = erf(%(x)s);" % locals()
erf = Erf(upgrade_to_float, name= 'erf')
class Erfc(UnaryScalarOp):
def impl(self, x):
if imported_scipy_special:
return scipy.special.erf(x)
else:
super(Erfc,self).impl(x)
def grad(self, (x, ), (gz, )):
if x.type in complex_types:
raise NotImplementedError()
elif x.type in float_types:
return - gz * 2. / numpy.sqrt(numpy.pi) * exp(-x*x),
else:
return None,
def c_code(self, node, name, (x, ), (z, ), sub):
if node.inputs[0].type in complex_types:
raise NotImplementedError('type not supported', type)
return "%(z)s = erfc(%(x)s);" % locals()
erfc = Erfc(upgrade_to_float, name = 'erfc')
...@@ -1659,6 +1659,14 @@ def sinh(a): ...@@ -1659,6 +1659,14 @@ def sinh(a):
def tanh(a): def tanh(a):
"""hyperbolic tangent of a""" """hyperbolic tangent of a"""
@_scal_elemwise
def erf(a):
"""error function"""
@_scal_elemwise
def erfc(a):
"""complementary error function"""
@_scal_elemwise @_scal_elemwise
def real(z): def real(z):
"""Return real component of complex-valued tensor `z`""" """Return real component of complex-valued tensor `z`"""
......
...@@ -164,6 +164,14 @@ def sinh_inplace(a): ...@@ -164,6 +164,14 @@ def sinh_inplace(a):
def tanh_inplace(a): def tanh_inplace(a):
"""hyperbolic tangent of `a` (inplace on `a`)""" """hyperbolic tangent of `a` (inplace on `a`)"""
@_scal_inplace
def erf_inplace(a):
"""error function"""
@_scal_inplace
def erfc_inplace(a):
"""complementary error function"""
@_scal_inplace @_scal_inplace
def second_inplace(a): def second_inplace(a):
"""Fill `a` with `b`""" """Fill `a` with `b`"""
......
...@@ -21,6 +21,13 @@ from theano.tests import unittest_tools as utt ...@@ -21,6 +21,13 @@ from theano.tests import unittest_tools as utt
from numpy.testing import dec from numpy.testing import dec
from numpy.testing.noseclasses import KnownFailureTest from numpy.testing.noseclasses import KnownFailureTest
imported_scipy_special = False
try:
import scipy.special
imported_scipy_special = True
except ImportError:
pass
### seed random number generator so that unittests are deterministic ### ### seed random number generator so that unittests are deterministic ###
utt.seed_rng() utt.seed_rng()
...@@ -88,8 +95,14 @@ def makeTester(name, op, expected, checks = {}, good = {}, bad_build = {}, bad_r ...@@ -88,8 +95,14 @@ def makeTester(name, op, expected, checks = {}, good = {}, bad_build = {}, bad_r
% (self.op, testname) % (self.op, testname)
exc_value.args = exc_value.args + (err_msg, ) exc_value.args = exc_value.args + (err_msg, )
raise type, exc_value, traceback raise type, exc_value, traceback
if isinstance(self.expected,dict) and testname in self.expected:
expecteds = self.expected[testname]
#with numpy version, when we print a number and read it back, we don't get exactly the same result
#So we accept rounding error in that case.
eps = 5e-9
else:
expecteds = self.expected(*inputs) expecteds = self.expected(*inputs)
eps = 1e-10
try: try:
variables = f(*inputs) variables = f(*inputs)
...@@ -104,7 +117,8 @@ def makeTester(name, op, expected, checks = {}, good = {}, bad_build = {}, bad_r ...@@ -104,7 +117,8 @@ def makeTester(name, op, expected, checks = {}, good = {}, bad_build = {}, bad_r
expecteds = (expecteds, ) expecteds = (expecteds, )
for i, (variable, expected) in enumerate(zip(variables, expecteds)): for i, (variable, expected) in enumerate(zip(variables, expecteds)):
if variable.dtype != expected.dtype or variable.shape != expected.shape or \ if variable.dtype != expected.dtype or variable.shape != expected.shape or \
numpy.any(numpy.abs(variable - expected) > 1e-10): numpy.any(numpy.abs(variable - expected) > eps):
import pdb;pdb.set_trace()
self.fail("Test %s::%s: Output %s gave the wrong value. With inputs %s, expected %s, got %s." self.fail("Test %s::%s: Output %s gave the wrong value. With inputs %s, expected %s, got %s."
% (self.op, testname, i, inputs, expected, variable)) % (self.op, testname, i, inputs, expected, variable))
...@@ -190,6 +204,7 @@ def makeBroadcastTester(op, expected, checks = {}, **kwargs): ...@@ -190,6 +204,7 @@ def makeBroadcastTester(op, expected, checks = {}, **kwargs):
if kwargs.has_key('inplace'): if kwargs.has_key('inplace'):
if kwargs['inplace']: if kwargs['inplace']:
_expected = expected _expected = expected
if not isinstance(_expected,dict):
expected = lambda *inputs: numpy.array(_expected(*inputs), dtype = inputs[0].dtype) expected = lambda *inputs: numpy.array(_expected(*inputs), dtype = inputs[0].dtype)
def inplace_check(inputs, outputs): def inplace_check(inputs, outputs):
# this used to be inputs[0] is output[0] # this used to be inputs[0] is output[0]
...@@ -612,6 +627,54 @@ TanhInplaceTester = makeBroadcastTester(op = inplace.tanh_inplace, ...@@ -612,6 +627,54 @@ TanhInplaceTester = makeBroadcastTester(op = inplace.tanh_inplace,
grad = _grad_broadcast_unary_normal, grad = _grad_broadcast_unary_normal,
inplace = True) inplace = True)
#inplace ops when the input is integer and the output is float*
# don't have a well defined behavior. We don't test that case.
_good_broadcast_unary_normal_no_int = _good_broadcast_unary_normal.copy()
del _good_broadcast_unary_normal_no_int['integers']
if imported_scipy_special:
expected = scipy.special.erf
else:
integers = numpy.asarray([[-1., 0.99532227, -0.99532227],
[-0.99532227, 1., 0.84270079]])
corner_case = numpy.asarray([-0.99959305, -0.99532227, -0.96610515, -0.84270079, -0.52049988, -0.52924362,
-0.51166826, 0., 0.51166826, 0.52049988, 0.79690821, 0.84270079,
0.96610515, 0.99532227, 0.99959305])
normal = numpy.array([[-1. , 0.99991358, 0.70314729],
[ 0.9977147 , -0.99999884, 0.33409098]])
expected = dict(integers=integers,corner_case=corner_case,normal = normal)
ErfTester = makeBroadcastTester(op = erf,
expected = expected,
good = _good_broadcast_unary_normal,
grad = _grad_broadcast_unary_normal)
ErfInplaceTester = makeBroadcastTester(op = inplace.erf_inplace,
expected = expected,
good = _good_broadcast_unary_normal_no_int,
grad = _grad_broadcast_unary_normal,
inplace = True)
if imported_scipy_special:
expected = scipy.special.erfc
else:
integers = numpy.array([[ 2.00000000e+00, 4.67773498e-03, 1.99532227e+00],
[ 1.99532227e+00, 1.53745979e-12, 1.57299207e-01]])
corner_case = numpy.array([ 1.99959305e+00, 1.99532227e+00, 1.96610515e+00,
1.84270079e+00, 1.52049988e+00, 1.52924362e+00,
1.51166826e+00, 1.00000000e+00, 4.88331739e-01,
4.79500122e-01, 2.03091788e-01, 1.57299207e-01,
3.38948535e-02, 4.67773498e-03, 4.06952017e-04])
normal = numpy.array([[ 2.00000000e+00, 8.64228449e-05, 2.96852710e-01],
[ 2.28530326e-03, 1.99999884e+00, 6.65909025e-01]])
expected = dict(integers=integers,corner_case=corner_case,normal = normal)
ErfcTester = makeBroadcastTester(op = erfc,
expected = expected,
good = _good_broadcast_unary_normal,
grad = _grad_broadcast_unary_normal)
ErfcInplaceTester = makeBroadcastTester(op = inplace.erfc_inplace,
expected = expected,
good = _good_broadcast_unary_normal_no_int,
grad = _grad_broadcast_unary_normal,
inplace = True)
DotTester = makeTester(name = 'DotTester', DotTester = makeTester(name = 'DotTester',
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论