提交 7ef44dfd authored 作者: Frederic Bastien's avatar Frederic Bastien

Move nfunc_spec to Scalar op, This allow to always have it even when we use Elemwise(a_scalar_op).

This happend frequently in the grad computation of elemwise. This could speed up DebugMode. This fix tests crashX
上级 ed337d4e
......@@ -1031,6 +1031,7 @@ class LT(LogicalComparison):
identity = False
commutative = False
associative = False
nfunc_spec = ('less', 2, 1)
def impl(self, x, y):
# built-in < don't support complex
......@@ -1049,6 +1050,7 @@ class GT(LogicalComparison):
identity = False
commutative = False
associative = False
nfunc_spec = ('greater', 2, 1)
def impl(self, x, y):
# built-in > don't support complex
......@@ -1067,6 +1069,7 @@ class LE(LogicalComparison):
identity = False
commutative = False
associative = False
nfunc_spec = ('less_equal', 2, 1)
def impl(self, x, y):
# built-in <= don't support complex
......@@ -1085,6 +1088,7 @@ class GE(LogicalComparison):
identity = False
commutative = False
associative = False
nfunc_spec = ('greater_equal', 2, 1)
def impl(self, x, y):
# built-in >= don't support complex
......@@ -1103,6 +1107,7 @@ class EQ(LogicalComparison):
identity = False
commutative = True
associative = False
nfunc_spec = ('equal', 2, 1)
def impl(self, x, y):
return x == y
......@@ -1118,6 +1123,7 @@ class NEQ(LogicalComparison):
identity = False
commutative = True
associative = False
nfunc_spec = ('not_equal', 2, 1)
def impl(self, x, y):
return x != y
......@@ -1132,6 +1138,8 @@ neq = NEQ()
class IsNan(FixedLogicalComparison):
nfunc_spec = ('isnan', 1, 1)
def impl(self, x):
return numpy.isnan(x)
......@@ -1145,6 +1153,8 @@ isnan = IsNan()
class IsInf(FixedLogicalComparison):
nfunc_spec = ('isinf', 1, 1)
def impl(self, x):
return numpy.isinf(x)
......@@ -1223,6 +1233,7 @@ inclosedrange = InRange(False, False)
class Switch(ScalarOp):
nin = 3
nfunc_spec = ('where', 3, 1)
def impl(self, cond, ift, iff):
if cond:
......@@ -1296,6 +1307,7 @@ class OR(BinaryBitOp):
identity = 0
commutative = True
associative = True
nfunc_spec = ('bitwise_or', 2, 1)
def impl(self, x, y):
return x | y
......@@ -1311,6 +1323,7 @@ class XOR(BinaryBitOp):
identity = 0
commutative = True
associative = True
nfunc_spec = ('bitwise_xor', 2, 1)
def impl(self, x, y):
return x ^ y
......@@ -1326,6 +1339,7 @@ class AND(BinaryBitOp):
identity = 1
commutative = True
associative = True
nfunc_spec = ('bitwise_and', 2, 1)
def impl(self, x, y):
return x & y
......@@ -1338,6 +1352,8 @@ and_ = AND()
class Invert(UnaryBitOp):
nfunc_spec = ('invert', 1, 1)
def impl(self, x):
return ~x
......@@ -1354,6 +1370,7 @@ invert = Invert()
class Maximum(BinaryScalarOp):
commutative = True
associative = True
nfunc_spec = ('maximum', 2, 1)
def impl(self, *inputs):
# The built-in max function don't support complex type
......@@ -1392,6 +1409,7 @@ maximum = Maximum(upcast_out, name='maximum')
class Minimum(BinaryScalarOp):
commutative = True
associative = True
nfunc_spec = ('minimum', 2, 1)
def impl(self, *inputs):
# The built-in min function don't support complex type
......@@ -1427,6 +1445,7 @@ class Add(ScalarOp):
identity = 0
commutative = True
associative = True
nfunc_spec = ('add', 2, 1)
def impl(self, *inputs):
return sum(inputs)
......@@ -1465,6 +1484,7 @@ class Mul(ScalarOp):
identity = 1
commutative = True
associative = True
nfunc_spec = ('multiply', 2, 1)
def impl(self, *inputs):
return numpy.product(inputs)
......@@ -1516,6 +1536,8 @@ mul = Mul(upcast_out, name='mul')
class Sub(BinaryScalarOp):
nfunc_spec = ('subtract', 2, 1)
def impl(self, x, y):
return x - y
......@@ -1604,6 +1626,8 @@ def div_proxy(x, y):
class TrueDiv(BinaryScalarOp):
nfunc_spec = ('true_divide', 2, 1)
def output_types(self, types):
if all(t in discrete_types for t in types):
return [get_scalar_type(config.floatX)]
......@@ -1659,6 +1683,7 @@ true_div = TrueDiv(upcast_out, name='true_div')
class IntDiv(BinaryScalarOp):
nfunc_spec = ('floor_divide', 2, 1)
complex_error = ComplexError(
"Theano does not support integer division (//) on "
"complex numbers, since numpy deprecated it.")
......@@ -1744,6 +1769,7 @@ def mod_check(x, y):
class Mod(BinaryScalarOp):
nfunc_spec = ('mod', 2, 1)
complex_error = ComplexError(
"Theano does not support the mod operator (%) on "
"complex numbers, since numpy deprecated it.")
......@@ -1828,6 +1854,8 @@ mod = Mod(upcast_out, name='mod')
class Pow(BinaryScalarOp):
nfunc_spec = ('power', 2, 1)
def impl(self, x, y):
return x ** y
......@@ -1903,6 +1931,8 @@ pow = Pow(upcast_out, name='pow')
class Clip(ScalarOp):
nin = 3
# The numpy.clip don't work correctly when the min is bigger then the max,
# So we do not use nfunc_spec = ('clip', 3, 1)
def impl(self, x, min, max):
if x < min:
......@@ -2086,6 +2116,8 @@ def cast(x, dtype):
class Abs(UnaryScalarOp):
nfunc_spec = ('abs', 1, 1)
def make_node(self, x):
inputs = [as_scalar(input) for input in [x]]
if inputs[0].type == complex64:
......@@ -2126,6 +2158,8 @@ abs_ = Abs(same_out)
class Sgn(UnaryScalarOp):
nfunc_spec = ('sign', 1, 1)
def impl(self, x):
# casting to output type is handled by filter
return numpy.sign(x)
......@@ -2162,6 +2196,8 @@ sgn = Sgn(same_out_nocomplex, name='sgn')
class Ceil(UnaryScalarOp):
nfunc_spec = ('ceil', 1, 1)
def impl(self, x):
return numpy.ceil(x)
......@@ -2183,6 +2219,8 @@ ceil = Ceil(same_out_nocomplex, name='ceil')
class Floor(UnaryScalarOp):
nfunc_spec = ('floor', 1, 1)
def impl(self, x):
return numpy.floor(x)
......@@ -2204,6 +2242,8 @@ floor = Floor(same_out_nocomplex, name='floor')
class Trunc(UnaryScalarOp):
nfunc_spec = ('trunc', 1, 1)
def impl(self, x):
return numpy.trunc(x)
......@@ -2227,6 +2267,8 @@ class RoundHalfToEven(UnaryScalarOp):
See http://en.wikipedia.org/wiki/Rounding for more details.
"""
nfunc_spec = ('around', 1, 1)
def impl(self, x):
return numpy.round(x)
......@@ -2348,6 +2390,8 @@ round_half_away_from_zero = RoundHalfAwayFromZero(same_out_float_only)
class Neg(UnaryScalarOp):
nfunc_spec = ('negative', 1, 1)
def impl(self, x):
return -x
......@@ -2413,6 +2457,7 @@ class Log(UnaryScalarOp):
log base e.
"""
nfunc_spec = ('log', 1, 1)
amd_float32 = "amd_vrsa_logf"
amd_float64 = "amd_vrda_log"
......@@ -2454,6 +2499,7 @@ class Log2(UnaryScalarOp):
log base 2.
"""
nfunc_spec = ('log2', 1, 1)
amd_float32 = "amd_vrsa_log2f"
amd_float64 = "amd_vrda_log2"
......@@ -2492,6 +2538,7 @@ class Log10(UnaryScalarOp):
log base 10.
"""
nfunc_spec = ('log10', 1, 1)
amd_float32 = "amd_vrsa_log10f"
amd_float64 = "amd_vrda_log10"
......@@ -2530,6 +2577,8 @@ class Log1p(UnaryScalarOp):
log(1+x).
"""
nfunc_spec = ('log1p', 1, 1)
def impl(self, x):
# If x is an int8 or uint8, numpy.log1p will compute the result in
# half-precision (float16), where we want float32.
......@@ -2561,6 +2610,7 @@ log1p = Log1p(upgrade_to_float, name='log1p')
class Exp(UnaryScalarOp):
nfunc_spec = ('exp', 1, 1)
amd_float32 = "amd_vrsa_expf"
amd_float64 = "amd_vrda_exp"
......@@ -2595,6 +2645,8 @@ exp = Exp(upgrade_to_float, name='exp')
class Exp2(UnaryScalarOp):
nfunc_spec = ('exp2', 1, 1)
def impl(self, x):
# If x is an int8 or uint8, numpy.exp2 will compute the result in
# half-precision (float16), where we want float32.
......@@ -2626,6 +2678,8 @@ exp2 = Exp2(upgrade_to_float, name='exp2')
class Expm1(UnaryScalarOp):
nfunc_spec = ('expm1', 1, 1)
def impl(self, x):
# If x is an int8 or uint8, numpy.expm1 will compute the result in
# half-precision (float16), where we want float32.
......@@ -2660,6 +2714,8 @@ expm1 = Expm1(upgrade_to_float, name='expm1')
class Sqr(UnaryScalarOp):
nfunc_spec = ('square', 1, 1)
def impl(self, x):
return x * x
......@@ -2684,6 +2740,8 @@ sqr = Sqr(same_out, name='sqr')
class Sqrt(UnaryScalarOp):
nfunc_spec = ('sqrt', 1, 1)
def impl(self, x):
# If x is an int8 or uint8, numpy.sqrt will compute the result in
# half-precision (float16), where we want float32.
......@@ -2715,6 +2773,8 @@ sqrt = Sqrt(upgrade_to_float, name='sqrt')
class Deg2Rad(UnaryScalarOp):
nfunc_spec = ('deg2rad', 1, 1)
def impl(self, x):
# If x is an int8 or uint8, numpy.deg2rad will compute the result in
# half-precision (float16), where we want float32.
......@@ -2746,6 +2806,8 @@ deg2rad = Deg2Rad(upgrade_to_float, name='deg2rad')
class Rad2Deg(UnaryScalarOp):
nfunc_spec = ('rad2deg', 1, 1)
def impl(self, x):
# If x is an int8 or uint8, numpy.rad2deg will compute the result in
# half-precision (float16), where we want float32.
......@@ -2777,6 +2839,7 @@ rad2deg = Rad2Deg(upgrade_to_float, name='rad2deg')
class Cos(UnaryScalarOp):
nfunc_spec = ('cos', 1, 1)
amd_float32 = "amd_vrsa_cosf"
amd_float64 = "amd_vrda_cos"
......@@ -2811,6 +2874,8 @@ cos = Cos(upgrade_to_float, name='cos')
class ArcCos(UnaryScalarOp):
nfunc_spec = ('arccos', 1, 1)
def impl(self, x):
# If x is an int8 or uint8, numpy.arccos will compute the result in
# half-precision (float16), where we want float32.
......@@ -2842,6 +2907,7 @@ arccos = ArcCos(upgrade_to_float, name='arccos')
class Sin(UnaryScalarOp):
nfunc_spec = ('sin', 1, 1)
amd_float32 = "amd_vrsa_sinf"
amd_float64 = "amd_vrda_sin"
......@@ -2876,6 +2942,8 @@ sin = Sin(upgrade_to_float, name='sin')
class ArcSin(UnaryScalarOp):
nfunc_spec = ('arcsin', 1, 1)
def impl(self, x):
# If x is an int8 or uint8, numpy.arcsin will compute the result in
# half-precision (float16), where we want float32.
......@@ -2907,6 +2975,8 @@ arcsin = ArcSin(upgrade_to_float, name='arcsin')
class Tan(UnaryScalarOp):
nfunc_spec = ('tan', 1, 1)
def impl(self, x):
# If x is an int8 or uint8, numpy.tan will compute the result in
# half-precision (float16), where we want float32.
......@@ -2938,6 +3008,8 @@ tan = Tan(upgrade_to_float, name='tan')
class ArcTan(UnaryScalarOp):
nfunc_spec = ('arctan', 1, 1)
def impl(self, x):
# If x is an int8 or uint8, numpy.arctan will compute the result in
# half-precision (float16), where we want float32.
......@@ -2969,6 +3041,8 @@ arctan = ArcTan(upgrade_to_float, name='arctan')
class ArcTan2(BinaryScalarOp):
nfunc_spec = ('arctan2', 1, 1)
def impl(self, y, x):
# If x and y are int8 or uint8, numpy.arctan2 will compute the result
# in half-precision (float16), where we want float32.
......@@ -3016,6 +3090,8 @@ class Cosh(UnaryScalarOp):
cosh(x) = (exp(x) + exp(-x)) / 2.
"""
nfunc_spec = ('cosh', 1, 1)
def impl(self, x):
# If x is an int8 or uint8, numpy.cosh will compute the result in
# half-precision (float16), where we want float32.
......@@ -3047,6 +3123,8 @@ cosh = Cosh(upgrade_to_float, name='cosh')
class ArcCosh(UnaryScalarOp):
nfunc_spec = ('arccosh', 1, 1)
def impl(self, x):
# If x is an int8 or uint8, numpy.arccosh will compute the result in
# half-precision (float16), where we want float32.
......@@ -3082,6 +3160,8 @@ class Sinh(UnaryScalarOp):
sinh(x) = (exp(x) - exp(-x)) / 2.
"""
nfunc_spec = ('sinh', 1, 1)
def impl(self, x):
# If x is an int8 or uint8, numpy.sinh will compute the result in
# half-precision (float16), where we want float32.
......@@ -3113,6 +3193,8 @@ sinh = Sinh(upgrade_to_float, name='sinh')
class ArcSinh(UnaryScalarOp):
nfunc_spec = ('arcsinh', 1, 1)
def impl(self, x):
# If x is an int8 or uint8, numpy.arcsinh will compute the result in
# half-precision (float16), where we want float32.
......@@ -3149,6 +3231,8 @@ class Tanh(UnaryScalarOp):
= (exp(2*x) - 1) / (exp(2*x) + 1).
"""
nfunc_spec = ('tanh', 1, 1)
def impl(self, x):
# If x is an int8 or uint8, numpy.tanh will compute the result in
# half-precision (float16), where we want float32.
......@@ -3180,6 +3264,8 @@ tanh = Tanh(upgrade_to_float, name='tanh')
class ArcTanh(UnaryScalarOp):
nfunc_spec = ('arctanh', 1, 1)
def impl(self, x):
# If x is an int8 or uint8, numpy.arctanh will compute the result in
# half-precision (float16), where we want float32.
......@@ -3215,6 +3301,9 @@ class Real(UnaryScalarOp):
Extract the real coordinate of a complex number.
"""
# numpy.real(float32) return a view on the inputs.
# nfunc_spec = ('real', 1, 1)
def impl(self, x):
return numpy.real(x)
......@@ -3227,6 +3316,8 @@ real = Real(real_out, name='real')
class Imag(UnaryScalarOp):
nfunc_spec = ('imag', 1, 1)
def impl(self, x):
return numpy.imag(x)
......@@ -3244,6 +3335,8 @@ imag = Imag(real_out, name='imag')
class Angle(UnaryScalarOp):
nfunc_spec = ('angle', 1, 1)
def impl(self, x):
return numpy.angle(x)
......@@ -3303,6 +3396,8 @@ complex = Complex(name='complex')
class Conj(UnaryScalarOp):
nfunc_spec = ('conj', 1, 1)
def impl(self, x):
return numpy.conj(x)
conj = Conj(same_out, name='conj')
......
......@@ -1738,42 +1738,42 @@ def largest(*args):
# Comparison
##########################
@_scal_elemwise_with_nfunc('less', 2, 1)
@_scal_elemwise
def lt(a, b):
"""a < b"""
@_scal_elemwise_with_nfunc('greater', 2, 1)
@_scal_elemwise
def gt(a, b):
"""a > b"""
@_scal_elemwise_with_nfunc('less_equal', 2, 1)
@_scal_elemwise
def le(a, b):
"""a <= b"""
@_scal_elemwise_with_nfunc('greater_equal', 2, 1)
@_scal_elemwise
def ge(a, b):
"""a >= b"""
@_scal_elemwise_with_nfunc('equal', 2, 1)
@_scal_elemwise
def eq(a, b):
"""a == b"""
@_scal_elemwise_with_nfunc('not_equal', 2, 1)
@_scal_elemwise
def neq(a, b):
"""a != b"""
@_scal_elemwise_with_nfunc('isnan', 1, 1)
@_scal_elemwise
def isnan(a):
"""isnan(a)"""
@_scal_elemwise_with_nfunc('isinf', 1, 1)
@_scal_elemwise
def isinf(a):
"""isinf(a)"""
......@@ -1922,7 +1922,7 @@ def isclose(a, b, rtol=1.e-5, atol=1.e-8, equal_nan=False):
# Condition
##########################
@_scal_elemwise_with_nfunc('where', 3, 1)
@_scal_elemwise
def switch(cond, ift, iff):
"""if cond then ift else iff"""
......@@ -1932,25 +1932,25 @@ where = switch
##########################
@_scal_elemwise_with_nfunc('bitwise_and', 2, 1)
@_scal_elemwise
def and_(a, b):
"""bitwise a & b"""
bitwise_and = and_ # numpy name for it
@_scal_elemwise_with_nfunc('bitwise_or', 2, 1)
@_scal_elemwise
def or_(a, b):
"""bitwise a | b"""
bitwise_or = or_ # numpy name for it
@_scal_elemwise_with_nfunc('bitwise_xor', 2, 1)
@_scal_elemwise
def xor(a, b):
"""bitwise a ^ b"""
bitwise_xor = xor # numpy name for it
@_scal_elemwise_with_nfunc('invert', 1, 1)
@_scal_elemwise
def invert(a):
"""bitwise ~a"""
bitwise_not = invert # numpy alias for it
......@@ -1960,7 +1960,7 @@ bitwise_not = invert # numpy alias for it
# Math
##########################
@_scal_elemwise_with_nfunc('abs', 1, 1)
@_scal_elemwise
def abs_(a):
"""|`a`|
......@@ -1972,22 +1972,22 @@ def abs_(a):
pprint.assign(abs_, printing.PatternPrinter(('|%(0)s|', -1000)))
@_scal_elemwise_with_nfunc('exp', 1, 1)
@_scal_elemwise
def exp(a):
"""e^`a`"""
@_scal_elemwise_with_nfunc('exp2', 1, 1)
@_scal_elemwise
def exp2(a):
"""2^`a`"""
@_scal_elemwise_with_nfunc('expm1', 1, 1)
@_scal_elemwise
def expm1(a):
"""e^`a` - 1"""
@_scal_elemwise_with_nfunc('negative', 1, 1)
@_scal_elemwise
def neg(a):
"""-a"""
......@@ -1999,42 +1999,42 @@ def inv(a):
"""1.0/a"""
@_scal_elemwise_with_nfunc('log', 1, 1)
@_scal_elemwise
def log(a):
"""base e logarithm of a"""
@_scal_elemwise_with_nfunc('log2', 1, 1)
@_scal_elemwise
def log2(a):
"""base 2 logarithm of a"""
@_scal_elemwise_with_nfunc('log10', 1, 1)
@_scal_elemwise
def log10(a):
"""base 10 logarithm of a"""
@_scal_elemwise_with_nfunc('log1p', 1, 1)
@_scal_elemwise
def log1p(a):
"""log(1+a)"""
@_scal_elemwise_with_nfunc('sign', 1, 1)
@_scal_elemwise
def sgn(a):
"""sign of a"""
@_scal_elemwise_with_nfunc('ceil', 1, 1)
@_scal_elemwise
def ceil(a):
"""ceiling of a"""
@_scal_elemwise_with_nfunc('floor', 1, 1)
@_scal_elemwise
def floor(a):
"""floor of a"""
@_scal_elemwise_with_nfunc('trunc', 1, 1)
@_scal_elemwise
def trunc(a):
"""trunc of a"""
......@@ -2056,7 +2056,7 @@ def round(a, mode="half_away_from_zero"):
raise Exception("round mode %s is not implemented." % mode)
@_scal_elemwise_with_nfunc('around', 1, 1)
@_scal_elemwise
def round_half_to_even(a):
"""round_half_to_even(a)"""
......@@ -2066,7 +2066,7 @@ def round_half_away_from_zero(a):
"""round_half_away_from_zero(a)"""
@_scal_elemwise_with_nfunc('square', 1, 1)
@_scal_elemwise
def sqr(a):
"""square of a"""
......@@ -2075,82 +2075,82 @@ def sqr(a):
square = sqr
@_scal_elemwise_with_nfunc('sqrt', 1, 1)
@_scal_elemwise
def sqrt(a):
"""square root of a"""
@_scal_elemwise_with_nfunc('deg2rad', 1, 1)
@_scal_elemwise
def deg2rad(a):
"""convert degree a to radian"""
@_scal_elemwise_with_nfunc('rad2deg', 1, 1)
@_scal_elemwise
def rad2deg(a):
"""convert radian a to degree"""
@_scal_elemwise_with_nfunc('cos', 1, 1)
@_scal_elemwise
def cos(a):
"""cosine of a"""
@_scal_elemwise_with_nfunc('arccos', 1, 1)
@_scal_elemwise
def arccos(a):
"""arccosine of a"""
@_scal_elemwise_with_nfunc('sin', 1, 1)
@_scal_elemwise
def sin(a):
"""sine of a"""
@_scal_elemwise_with_nfunc('arcsin', 1, 1)
@_scal_elemwise
def arcsin(a):
"""arcsine of a"""
@_scal_elemwise_with_nfunc('tan', 1, 1)
@_scal_elemwise
def tan(a):
"""tangent of a"""
@_scal_elemwise_with_nfunc('arctan', 1, 1)
@_scal_elemwise
def arctan(a):
"""arctangent of a"""
@_scal_elemwise_with_nfunc('arctan2', 1, 1)
@_scal_elemwise
def arctan2(a, b):
"""arctangent of a / b"""
@_scal_elemwise_with_nfunc('cosh', 1, 1)
@_scal_elemwise
def cosh(a):
"""hyperbolic cosine of a"""
@_scal_elemwise_with_nfunc('arccosh', 1, 1)
@_scal_elemwise
def arccosh(a):
"""hyperbolic arc cosine of a"""
@_scal_elemwise_with_nfunc('sinh', 1, 1)
@_scal_elemwise
def sinh(a):
"""hyperbolic sine of a"""
@_scal_elemwise_with_nfunc('arcsinh', 1, 1)
@_scal_elemwise
def arcsinh(a):
"""hyperbolic arc sine of a"""
@_scal_elemwise_with_nfunc('tanh', 1, 1)
@_scal_elemwise
def tanh(a):
"""hyperbolic tangent of a"""
@_scal_elemwise_with_nfunc('arctanh', 1, 1)
@_scal_elemwise
def arctanh(a):
"""hyperbolic arc tangent of a"""
......@@ -2200,21 +2200,19 @@ def chi2sf(x, k):
"""chi squared survival function"""
# numpy.real(float32) return a view on the inputs.
# @_scal_elemwise_with_nfunc('real', 1, 1)
@_scal_elemwise
def real(z):
"""Return real component of complex-valued tensor `z`"""
_tensor_py_operators.real = property(real)
@_scal_elemwise_with_nfunc('imag', 1, 1)
@_scal_elemwise
def imag(z):
"""Return imaginary component of complex-valued tensor `z`"""
_tensor_py_operators.imag = property(imag)
@_scal_elemwise_with_nfunc('angle', 1, 1)
@_scal_elemwise
def angle(z):
"""Return polar-coordinate angle of complex-valued tensor `z`"""
......@@ -2224,7 +2222,7 @@ def complex(real, imag):
"""Return complex-valued tensor with `real` and `imag` components"""
@_scal_elemwise_with_nfunc('conj', 1, 1)
@_scal_elemwise
def conj(z):
"""Return the complex conjugate of `z`."""
......@@ -3166,13 +3164,13 @@ setdefault = default # legacy
##########################
# Arithmetics
##########################
@_scal_elemwise_with_nfunc('maximum', 2, 1)
@_scal_elemwise
def maximum(x, y):
"""elemwise maximum. See max for the maximum in one tensor"""
# see decorator for function body
@_scal_elemwise_with_nfunc('minimum', 2, 1)
@_scal_elemwise
def minimum(x, y):
"""elemwise minimum. See min for the minimum in one tensor"""
# see decorator for function body
......@@ -3191,31 +3189,31 @@ def divmod(x, y):
return floor_div(x, y), mod_check(x, y)
@_scal_elemwise_with_nfunc('add', 2, 1)
@_scal_elemwise
def add(a, *other_terms):
"""elementwise addition"""
# see decorator for function body
@_scal_elemwise_with_nfunc('subtract', 2, 1)
@_scal_elemwise
def sub(a, b):
"""elementwise subtraction"""
# see decorator for function body
@_scal_elemwise_with_nfunc('multiply', 2, 1)
@_scal_elemwise
def mul(a, *other_terms):
"""elementwise multiplication"""
# see decorator for function body
@_scal_elemwise_with_nfunc('true_divide', 2, 1)
@_scal_elemwise
def true_div(a, b):
"""elementwise [true] division (inverse of multiplication)"""
# see decorator for function body
@_scal_elemwise_with_nfunc('floor_divide', 2, 1)
@_scal_elemwise
def int_div(a, b):
"""elementwise [floor] division (inverse of multiplication)"""
# see decorator for function body
......@@ -3256,20 +3254,18 @@ def mod_check(x, y):
return mod(x, y)
@_scal_elemwise_with_nfunc('mod', 2, 1)
@_scal_elemwise
def mod(a, b):
"""elementwise modulo"""
# see decorator for function body
@_scal_elemwise_with_nfunc('power', 2, 1)
@_scal_elemwise
def pow(a, b):
"""elementwise power"""
# see decorator for function body
# The numpy.clip don't work correctly when the min is bigger then the max,
# So we do not use @scal_elemwise_with_nfunc('clip', 3, 1)
@_scal_elemwise
def clip(x, min, max):
"""
......
......@@ -503,6 +503,8 @@ class Elemwise(OpenMPOp):
self.ufunc = None
self.nfunc = None
if nfunc_spec is None:
nfunc_spec = getattr(scalar_op, 'nfunc_spec', None)
self.nfunc_spec = nfunc_spec
if nfunc_spec:
self.nfunc = getattr(numpy, nfunc_spec[0])
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论