提交 7ef44dfd authored 作者: Frederic Bastien's avatar Frederic Bastien

Move nfunc_spec to Scalar op, This allow to always have it even when we use Elemwise(a_scalar_op).

This happend frequently in the grad computation of elemwise. This could speed up DebugMode. This fix tests crashX
上级 ed337d4e
...@@ -1031,6 +1031,7 @@ class LT(LogicalComparison): ...@@ -1031,6 +1031,7 @@ class LT(LogicalComparison):
identity = False identity = False
commutative = False commutative = False
associative = False associative = False
nfunc_spec = ('less', 2, 1)
def impl(self, x, y): def impl(self, x, y):
# built-in < don't support complex # built-in < don't support complex
...@@ -1049,6 +1050,7 @@ class GT(LogicalComparison): ...@@ -1049,6 +1050,7 @@ class GT(LogicalComparison):
identity = False identity = False
commutative = False commutative = False
associative = False associative = False
nfunc_spec = ('greater', 2, 1)
def impl(self, x, y): def impl(self, x, y):
# built-in > don't support complex # built-in > don't support complex
...@@ -1067,6 +1069,7 @@ class LE(LogicalComparison): ...@@ -1067,6 +1069,7 @@ class LE(LogicalComparison):
identity = False identity = False
commutative = False commutative = False
associative = False associative = False
nfunc_spec = ('less_equal', 2, 1)
def impl(self, x, y): def impl(self, x, y):
# built-in <= don't support complex # built-in <= don't support complex
...@@ -1085,6 +1088,7 @@ class GE(LogicalComparison): ...@@ -1085,6 +1088,7 @@ class GE(LogicalComparison):
identity = False identity = False
commutative = False commutative = False
associative = False associative = False
nfunc_spec = ('greater_equal', 2, 1)
def impl(self, x, y): def impl(self, x, y):
# built-in >= don't support complex # built-in >= don't support complex
...@@ -1103,6 +1107,7 @@ class EQ(LogicalComparison): ...@@ -1103,6 +1107,7 @@ class EQ(LogicalComparison):
identity = False identity = False
commutative = True commutative = True
associative = False associative = False
nfunc_spec = ('equal', 2, 1)
def impl(self, x, y): def impl(self, x, y):
return x == y return x == y
...@@ -1118,6 +1123,7 @@ class NEQ(LogicalComparison): ...@@ -1118,6 +1123,7 @@ class NEQ(LogicalComparison):
identity = False identity = False
commutative = True commutative = True
associative = False associative = False
nfunc_spec = ('not_equal', 2, 1)
def impl(self, x, y): def impl(self, x, y):
return x != y return x != y
...@@ -1132,6 +1138,8 @@ neq = NEQ() ...@@ -1132,6 +1138,8 @@ neq = NEQ()
class IsNan(FixedLogicalComparison): class IsNan(FixedLogicalComparison):
nfunc_spec = ('isnan', 1, 1)
def impl(self, x): def impl(self, x):
return numpy.isnan(x) return numpy.isnan(x)
...@@ -1145,6 +1153,8 @@ isnan = IsNan() ...@@ -1145,6 +1153,8 @@ isnan = IsNan()
class IsInf(FixedLogicalComparison): class IsInf(FixedLogicalComparison):
nfunc_spec = ('isinf', 1, 1)
def impl(self, x): def impl(self, x):
return numpy.isinf(x) return numpy.isinf(x)
...@@ -1223,6 +1233,7 @@ inclosedrange = InRange(False, False) ...@@ -1223,6 +1233,7 @@ inclosedrange = InRange(False, False)
class Switch(ScalarOp): class Switch(ScalarOp):
nin = 3 nin = 3
nfunc_spec = ('where', 3, 1)
def impl(self, cond, ift, iff): def impl(self, cond, ift, iff):
if cond: if cond:
...@@ -1296,6 +1307,7 @@ class OR(BinaryBitOp): ...@@ -1296,6 +1307,7 @@ class OR(BinaryBitOp):
identity = 0 identity = 0
commutative = True commutative = True
associative = True associative = True
nfunc_spec = ('bitwise_or', 2, 1)
def impl(self, x, y): def impl(self, x, y):
return x | y return x | y
...@@ -1311,6 +1323,7 @@ class XOR(BinaryBitOp): ...@@ -1311,6 +1323,7 @@ class XOR(BinaryBitOp):
identity = 0 identity = 0
commutative = True commutative = True
associative = True associative = True
nfunc_spec = ('bitwise_xor', 2, 1)
def impl(self, x, y): def impl(self, x, y):
return x ^ y return x ^ y
...@@ -1326,6 +1339,7 @@ class AND(BinaryBitOp): ...@@ -1326,6 +1339,7 @@ class AND(BinaryBitOp):
identity = 1 identity = 1
commutative = True commutative = True
associative = True associative = True
nfunc_spec = ('bitwise_and', 2, 1)
def impl(self, x, y): def impl(self, x, y):
return x & y return x & y
...@@ -1338,6 +1352,8 @@ and_ = AND() ...@@ -1338,6 +1352,8 @@ and_ = AND()
class Invert(UnaryBitOp): class Invert(UnaryBitOp):
nfunc_spec = ('invert', 1, 1)
def impl(self, x): def impl(self, x):
return ~x return ~x
...@@ -1354,6 +1370,7 @@ invert = Invert() ...@@ -1354,6 +1370,7 @@ invert = Invert()
class Maximum(BinaryScalarOp): class Maximum(BinaryScalarOp):
commutative = True commutative = True
associative = True associative = True
nfunc_spec = ('maximum', 2, 1)
def impl(self, *inputs): def impl(self, *inputs):
# The built-in max function don't support complex type # The built-in max function don't support complex type
...@@ -1392,6 +1409,7 @@ maximum = Maximum(upcast_out, name='maximum') ...@@ -1392,6 +1409,7 @@ maximum = Maximum(upcast_out, name='maximum')
class Minimum(BinaryScalarOp): class Minimum(BinaryScalarOp):
commutative = True commutative = True
associative = True associative = True
nfunc_spec = ('minimum', 2, 1)
def impl(self, *inputs): def impl(self, *inputs):
# The built-in min function don't support complex type # The built-in min function don't support complex type
...@@ -1427,6 +1445,7 @@ class Add(ScalarOp): ...@@ -1427,6 +1445,7 @@ class Add(ScalarOp):
identity = 0 identity = 0
commutative = True commutative = True
associative = True associative = True
nfunc_spec = ('add', 2, 1)
def impl(self, *inputs): def impl(self, *inputs):
return sum(inputs) return sum(inputs)
...@@ -1465,6 +1484,7 @@ class Mul(ScalarOp): ...@@ -1465,6 +1484,7 @@ class Mul(ScalarOp):
identity = 1 identity = 1
commutative = True commutative = True
associative = True associative = True
nfunc_spec = ('multiply', 2, 1)
def impl(self, *inputs): def impl(self, *inputs):
return numpy.product(inputs) return numpy.product(inputs)
...@@ -1516,6 +1536,8 @@ mul = Mul(upcast_out, name='mul') ...@@ -1516,6 +1536,8 @@ mul = Mul(upcast_out, name='mul')
class Sub(BinaryScalarOp): class Sub(BinaryScalarOp):
nfunc_spec = ('subtract', 2, 1)
def impl(self, x, y): def impl(self, x, y):
return x - y return x - y
...@@ -1604,6 +1626,8 @@ def div_proxy(x, y): ...@@ -1604,6 +1626,8 @@ def div_proxy(x, y):
class TrueDiv(BinaryScalarOp): class TrueDiv(BinaryScalarOp):
nfunc_spec = ('true_divide', 2, 1)
def output_types(self, types): def output_types(self, types):
if all(t in discrete_types for t in types): if all(t in discrete_types for t in types):
return [get_scalar_type(config.floatX)] return [get_scalar_type(config.floatX)]
...@@ -1659,6 +1683,7 @@ true_div = TrueDiv(upcast_out, name='true_div') ...@@ -1659,6 +1683,7 @@ true_div = TrueDiv(upcast_out, name='true_div')
class IntDiv(BinaryScalarOp): class IntDiv(BinaryScalarOp):
nfunc_spec = ('floor_divide', 2, 1)
complex_error = ComplexError( complex_error = ComplexError(
"Theano does not support integer division (//) on " "Theano does not support integer division (//) on "
"complex numbers, since numpy deprecated it.") "complex numbers, since numpy deprecated it.")
...@@ -1744,6 +1769,7 @@ def mod_check(x, y): ...@@ -1744,6 +1769,7 @@ def mod_check(x, y):
class Mod(BinaryScalarOp): class Mod(BinaryScalarOp):
nfunc_spec = ('mod', 2, 1)
complex_error = ComplexError( complex_error = ComplexError(
"Theano does not support the mod operator (%) on " "Theano does not support the mod operator (%) on "
"complex numbers, since numpy deprecated it.") "complex numbers, since numpy deprecated it.")
...@@ -1828,6 +1854,8 @@ mod = Mod(upcast_out, name='mod') ...@@ -1828,6 +1854,8 @@ mod = Mod(upcast_out, name='mod')
class Pow(BinaryScalarOp): class Pow(BinaryScalarOp):
nfunc_spec = ('power', 2, 1)
def impl(self, x, y): def impl(self, x, y):
return x ** y return x ** y
...@@ -1903,6 +1931,8 @@ pow = Pow(upcast_out, name='pow') ...@@ -1903,6 +1931,8 @@ pow = Pow(upcast_out, name='pow')
class Clip(ScalarOp): class Clip(ScalarOp):
nin = 3 nin = 3
# The numpy.clip don't work correctly when the min is bigger then the max,
# So we do not use nfunc_spec = ('clip', 3, 1)
def impl(self, x, min, max): def impl(self, x, min, max):
if x < min: if x < min:
...@@ -2086,6 +2116,8 @@ def cast(x, dtype): ...@@ -2086,6 +2116,8 @@ def cast(x, dtype):
class Abs(UnaryScalarOp): class Abs(UnaryScalarOp):
nfunc_spec = ('abs', 1, 1)
def make_node(self, x): def make_node(self, x):
inputs = [as_scalar(input) for input in [x]] inputs = [as_scalar(input) for input in [x]]
if inputs[0].type == complex64: if inputs[0].type == complex64:
...@@ -2126,6 +2158,8 @@ abs_ = Abs(same_out) ...@@ -2126,6 +2158,8 @@ abs_ = Abs(same_out)
class Sgn(UnaryScalarOp): class Sgn(UnaryScalarOp):
nfunc_spec = ('sign', 1, 1)
def impl(self, x): def impl(self, x):
# casting to output type is handled by filter # casting to output type is handled by filter
return numpy.sign(x) return numpy.sign(x)
...@@ -2162,6 +2196,8 @@ sgn = Sgn(same_out_nocomplex, name='sgn') ...@@ -2162,6 +2196,8 @@ sgn = Sgn(same_out_nocomplex, name='sgn')
class Ceil(UnaryScalarOp): class Ceil(UnaryScalarOp):
nfunc_spec = ('ceil', 1, 1)
def impl(self, x): def impl(self, x):
return numpy.ceil(x) return numpy.ceil(x)
...@@ -2183,6 +2219,8 @@ ceil = Ceil(same_out_nocomplex, name='ceil') ...@@ -2183,6 +2219,8 @@ ceil = Ceil(same_out_nocomplex, name='ceil')
class Floor(UnaryScalarOp): class Floor(UnaryScalarOp):
nfunc_spec = ('floor', 1, 1)
def impl(self, x): def impl(self, x):
return numpy.floor(x) return numpy.floor(x)
...@@ -2204,6 +2242,8 @@ floor = Floor(same_out_nocomplex, name='floor') ...@@ -2204,6 +2242,8 @@ floor = Floor(same_out_nocomplex, name='floor')
class Trunc(UnaryScalarOp): class Trunc(UnaryScalarOp):
nfunc_spec = ('trunc', 1, 1)
def impl(self, x): def impl(self, x):
return numpy.trunc(x) return numpy.trunc(x)
...@@ -2227,6 +2267,8 @@ class RoundHalfToEven(UnaryScalarOp): ...@@ -2227,6 +2267,8 @@ class RoundHalfToEven(UnaryScalarOp):
See http://en.wikipedia.org/wiki/Rounding for more details. See http://en.wikipedia.org/wiki/Rounding for more details.
""" """
nfunc_spec = ('around', 1, 1)
def impl(self, x): def impl(self, x):
return numpy.round(x) return numpy.round(x)
...@@ -2348,6 +2390,8 @@ round_half_away_from_zero = RoundHalfAwayFromZero(same_out_float_only) ...@@ -2348,6 +2390,8 @@ round_half_away_from_zero = RoundHalfAwayFromZero(same_out_float_only)
class Neg(UnaryScalarOp): class Neg(UnaryScalarOp):
nfunc_spec = ('negative', 1, 1)
def impl(self, x): def impl(self, x):
return -x return -x
...@@ -2413,6 +2457,7 @@ class Log(UnaryScalarOp): ...@@ -2413,6 +2457,7 @@ class Log(UnaryScalarOp):
log base e. log base e.
""" """
nfunc_spec = ('log', 1, 1)
amd_float32 = "amd_vrsa_logf" amd_float32 = "amd_vrsa_logf"
amd_float64 = "amd_vrda_log" amd_float64 = "amd_vrda_log"
...@@ -2454,6 +2499,7 @@ class Log2(UnaryScalarOp): ...@@ -2454,6 +2499,7 @@ class Log2(UnaryScalarOp):
log base 2. log base 2.
""" """
nfunc_spec = ('log2', 1, 1)
amd_float32 = "amd_vrsa_log2f" amd_float32 = "amd_vrsa_log2f"
amd_float64 = "amd_vrda_log2" amd_float64 = "amd_vrda_log2"
...@@ -2492,6 +2538,7 @@ class Log10(UnaryScalarOp): ...@@ -2492,6 +2538,7 @@ class Log10(UnaryScalarOp):
log base 10. log base 10.
""" """
nfunc_spec = ('log10', 1, 1)
amd_float32 = "amd_vrsa_log10f" amd_float32 = "amd_vrsa_log10f"
amd_float64 = "amd_vrda_log10" amd_float64 = "amd_vrda_log10"
...@@ -2530,6 +2577,8 @@ class Log1p(UnaryScalarOp): ...@@ -2530,6 +2577,8 @@ class Log1p(UnaryScalarOp):
log(1+x). log(1+x).
""" """
nfunc_spec = ('log1p', 1, 1)
def impl(self, x): def impl(self, x):
# If x is an int8 or uint8, numpy.log1p will compute the result in # If x is an int8 or uint8, numpy.log1p will compute the result in
# half-precision (float16), where we want float32. # half-precision (float16), where we want float32.
...@@ -2561,6 +2610,7 @@ log1p = Log1p(upgrade_to_float, name='log1p') ...@@ -2561,6 +2610,7 @@ log1p = Log1p(upgrade_to_float, name='log1p')
class Exp(UnaryScalarOp): class Exp(UnaryScalarOp):
nfunc_spec = ('exp', 1, 1)
amd_float32 = "amd_vrsa_expf" amd_float32 = "amd_vrsa_expf"
amd_float64 = "amd_vrda_exp" amd_float64 = "amd_vrda_exp"
...@@ -2595,6 +2645,8 @@ exp = Exp(upgrade_to_float, name='exp') ...@@ -2595,6 +2645,8 @@ exp = Exp(upgrade_to_float, name='exp')
class Exp2(UnaryScalarOp): class Exp2(UnaryScalarOp):
nfunc_spec = ('exp2', 1, 1)
def impl(self, x): def impl(self, x):
# If x is an int8 or uint8, numpy.exp2 will compute the result in # If x is an int8 or uint8, numpy.exp2 will compute the result in
# half-precision (float16), where we want float32. # half-precision (float16), where we want float32.
...@@ -2626,6 +2678,8 @@ exp2 = Exp2(upgrade_to_float, name='exp2') ...@@ -2626,6 +2678,8 @@ exp2 = Exp2(upgrade_to_float, name='exp2')
class Expm1(UnaryScalarOp): class Expm1(UnaryScalarOp):
nfunc_spec = ('expm1', 1, 1)
def impl(self, x): def impl(self, x):
# If x is an int8 or uint8, numpy.expm1 will compute the result in # If x is an int8 or uint8, numpy.expm1 will compute the result in
# half-precision (float16), where we want float32. # half-precision (float16), where we want float32.
...@@ -2660,6 +2714,8 @@ expm1 = Expm1(upgrade_to_float, name='expm1') ...@@ -2660,6 +2714,8 @@ expm1 = Expm1(upgrade_to_float, name='expm1')
class Sqr(UnaryScalarOp): class Sqr(UnaryScalarOp):
nfunc_spec = ('square', 1, 1)
def impl(self, x): def impl(self, x):
return x * x return x * x
...@@ -2684,6 +2740,8 @@ sqr = Sqr(same_out, name='sqr') ...@@ -2684,6 +2740,8 @@ sqr = Sqr(same_out, name='sqr')
class Sqrt(UnaryScalarOp): class Sqrt(UnaryScalarOp):
nfunc_spec = ('sqrt', 1, 1)
def impl(self, x): def impl(self, x):
# If x is an int8 or uint8, numpy.sqrt will compute the result in # If x is an int8 or uint8, numpy.sqrt will compute the result in
# half-precision (float16), where we want float32. # half-precision (float16), where we want float32.
...@@ -2715,6 +2773,8 @@ sqrt = Sqrt(upgrade_to_float, name='sqrt') ...@@ -2715,6 +2773,8 @@ sqrt = Sqrt(upgrade_to_float, name='sqrt')
class Deg2Rad(UnaryScalarOp): class Deg2Rad(UnaryScalarOp):
nfunc_spec = ('deg2rad', 1, 1)
def impl(self, x): def impl(self, x):
# If x is an int8 or uint8, numpy.deg2rad will compute the result in # If x is an int8 or uint8, numpy.deg2rad will compute the result in
# half-precision (float16), where we want float32. # half-precision (float16), where we want float32.
...@@ -2746,6 +2806,8 @@ deg2rad = Deg2Rad(upgrade_to_float, name='deg2rad') ...@@ -2746,6 +2806,8 @@ deg2rad = Deg2Rad(upgrade_to_float, name='deg2rad')
class Rad2Deg(UnaryScalarOp): class Rad2Deg(UnaryScalarOp):
nfunc_spec = ('rad2deg', 1, 1)
def impl(self, x): def impl(self, x):
# If x is an int8 or uint8, numpy.rad2deg will compute the result in # If x is an int8 or uint8, numpy.rad2deg will compute the result in
# half-precision (float16), where we want float32. # half-precision (float16), where we want float32.
...@@ -2777,6 +2839,7 @@ rad2deg = Rad2Deg(upgrade_to_float, name='rad2deg') ...@@ -2777,6 +2839,7 @@ rad2deg = Rad2Deg(upgrade_to_float, name='rad2deg')
class Cos(UnaryScalarOp): class Cos(UnaryScalarOp):
nfunc_spec = ('cos', 1, 1)
amd_float32 = "amd_vrsa_cosf" amd_float32 = "amd_vrsa_cosf"
amd_float64 = "amd_vrda_cos" amd_float64 = "amd_vrda_cos"
...@@ -2811,6 +2874,8 @@ cos = Cos(upgrade_to_float, name='cos') ...@@ -2811,6 +2874,8 @@ cos = Cos(upgrade_to_float, name='cos')
class ArcCos(UnaryScalarOp): class ArcCos(UnaryScalarOp):
nfunc_spec = ('arccos', 1, 1)
def impl(self, x): def impl(self, x):
# If x is an int8 or uint8, numpy.arccos will compute the result in # If x is an int8 or uint8, numpy.arccos will compute the result in
# half-precision (float16), where we want float32. # half-precision (float16), where we want float32.
...@@ -2842,6 +2907,7 @@ arccos = ArcCos(upgrade_to_float, name='arccos') ...@@ -2842,6 +2907,7 @@ arccos = ArcCos(upgrade_to_float, name='arccos')
class Sin(UnaryScalarOp): class Sin(UnaryScalarOp):
nfunc_spec = ('sin', 1, 1)
amd_float32 = "amd_vrsa_sinf" amd_float32 = "amd_vrsa_sinf"
amd_float64 = "amd_vrda_sin" amd_float64 = "amd_vrda_sin"
...@@ -2876,6 +2942,8 @@ sin = Sin(upgrade_to_float, name='sin') ...@@ -2876,6 +2942,8 @@ sin = Sin(upgrade_to_float, name='sin')
class ArcSin(UnaryScalarOp): class ArcSin(UnaryScalarOp):
nfunc_spec = ('arcsin', 1, 1)
def impl(self, x): def impl(self, x):
# If x is an int8 or uint8, numpy.arcsin will compute the result in # If x is an int8 or uint8, numpy.arcsin will compute the result in
# half-precision (float16), where we want float32. # half-precision (float16), where we want float32.
...@@ -2907,6 +2975,8 @@ arcsin = ArcSin(upgrade_to_float, name='arcsin') ...@@ -2907,6 +2975,8 @@ arcsin = ArcSin(upgrade_to_float, name='arcsin')
class Tan(UnaryScalarOp): class Tan(UnaryScalarOp):
nfunc_spec = ('tan', 1, 1)
def impl(self, x): def impl(self, x):
# If x is an int8 or uint8, numpy.tan will compute the result in # If x is an int8 or uint8, numpy.tan will compute the result in
# half-precision (float16), where we want float32. # half-precision (float16), where we want float32.
...@@ -2938,6 +3008,8 @@ tan = Tan(upgrade_to_float, name='tan') ...@@ -2938,6 +3008,8 @@ tan = Tan(upgrade_to_float, name='tan')
class ArcTan(UnaryScalarOp): class ArcTan(UnaryScalarOp):
nfunc_spec = ('arctan', 1, 1)
def impl(self, x): def impl(self, x):
# If x is an int8 or uint8, numpy.arctan will compute the result in # If x is an int8 or uint8, numpy.arctan will compute the result in
# half-precision (float16), where we want float32. # half-precision (float16), where we want float32.
...@@ -2969,6 +3041,8 @@ arctan = ArcTan(upgrade_to_float, name='arctan') ...@@ -2969,6 +3041,8 @@ arctan = ArcTan(upgrade_to_float, name='arctan')
class ArcTan2(BinaryScalarOp): class ArcTan2(BinaryScalarOp):
nfunc_spec = ('arctan2', 1, 1)
def impl(self, y, x): def impl(self, y, x):
# If x and y are int8 or uint8, numpy.arctan2 will compute the result # If x and y are int8 or uint8, numpy.arctan2 will compute the result
# in half-precision (float16), where we want float32. # in half-precision (float16), where we want float32.
...@@ -3016,6 +3090,8 @@ class Cosh(UnaryScalarOp): ...@@ -3016,6 +3090,8 @@ class Cosh(UnaryScalarOp):
cosh(x) = (exp(x) + exp(-x)) / 2. cosh(x) = (exp(x) + exp(-x)) / 2.
""" """
nfunc_spec = ('cosh', 1, 1)
def impl(self, x): def impl(self, x):
# If x is an int8 or uint8, numpy.cosh will compute the result in # If x is an int8 or uint8, numpy.cosh will compute the result in
# half-precision (float16), where we want float32. # half-precision (float16), where we want float32.
...@@ -3047,6 +3123,8 @@ cosh = Cosh(upgrade_to_float, name='cosh') ...@@ -3047,6 +3123,8 @@ cosh = Cosh(upgrade_to_float, name='cosh')
class ArcCosh(UnaryScalarOp): class ArcCosh(UnaryScalarOp):
nfunc_spec = ('arccosh', 1, 1)
def impl(self, x): def impl(self, x):
# If x is an int8 or uint8, numpy.arccosh will compute the result in # If x is an int8 or uint8, numpy.arccosh will compute the result in
# half-precision (float16), where we want float32. # half-precision (float16), where we want float32.
...@@ -3082,6 +3160,8 @@ class Sinh(UnaryScalarOp): ...@@ -3082,6 +3160,8 @@ class Sinh(UnaryScalarOp):
sinh(x) = (exp(x) - exp(-x)) / 2. sinh(x) = (exp(x) - exp(-x)) / 2.
""" """
nfunc_spec = ('sinh', 1, 1)
def impl(self, x): def impl(self, x):
# If x is an int8 or uint8, numpy.sinh will compute the result in # If x is an int8 or uint8, numpy.sinh will compute the result in
# half-precision (float16), where we want float32. # half-precision (float16), where we want float32.
...@@ -3113,6 +3193,8 @@ sinh = Sinh(upgrade_to_float, name='sinh') ...@@ -3113,6 +3193,8 @@ sinh = Sinh(upgrade_to_float, name='sinh')
class ArcSinh(UnaryScalarOp): class ArcSinh(UnaryScalarOp):
nfunc_spec = ('arcsinh', 1, 1)
def impl(self, x): def impl(self, x):
# If x is an int8 or uint8, numpy.arcsinh will compute the result in # If x is an int8 or uint8, numpy.arcsinh will compute the result in
# half-precision (float16), where we want float32. # half-precision (float16), where we want float32.
...@@ -3149,6 +3231,8 @@ class Tanh(UnaryScalarOp): ...@@ -3149,6 +3231,8 @@ class Tanh(UnaryScalarOp):
= (exp(2*x) - 1) / (exp(2*x) + 1). = (exp(2*x) - 1) / (exp(2*x) + 1).
""" """
nfunc_spec = ('tanh', 1, 1)
def impl(self, x): def impl(self, x):
# If x is an int8 or uint8, numpy.tanh will compute the result in # If x is an int8 or uint8, numpy.tanh will compute the result in
# half-precision (float16), where we want float32. # half-precision (float16), where we want float32.
...@@ -3180,6 +3264,8 @@ tanh = Tanh(upgrade_to_float, name='tanh') ...@@ -3180,6 +3264,8 @@ tanh = Tanh(upgrade_to_float, name='tanh')
class ArcTanh(UnaryScalarOp): class ArcTanh(UnaryScalarOp):
nfunc_spec = ('arctanh', 1, 1)
def impl(self, x): def impl(self, x):
# If x is an int8 or uint8, numpy.arctanh will compute the result in # If x is an int8 or uint8, numpy.arctanh will compute the result in
# half-precision (float16), where we want float32. # half-precision (float16), where we want float32.
...@@ -3215,6 +3301,9 @@ class Real(UnaryScalarOp): ...@@ -3215,6 +3301,9 @@ class Real(UnaryScalarOp):
Extract the real coordinate of a complex number. Extract the real coordinate of a complex number.
""" """
# numpy.real(float32) return a view on the inputs.
# nfunc_spec = ('real', 1, 1)
def impl(self, x): def impl(self, x):
return numpy.real(x) return numpy.real(x)
...@@ -3227,6 +3316,8 @@ real = Real(real_out, name='real') ...@@ -3227,6 +3316,8 @@ real = Real(real_out, name='real')
class Imag(UnaryScalarOp): class Imag(UnaryScalarOp):
nfunc_spec = ('imag', 1, 1)
def impl(self, x): def impl(self, x):
return numpy.imag(x) return numpy.imag(x)
...@@ -3244,6 +3335,8 @@ imag = Imag(real_out, name='imag') ...@@ -3244,6 +3335,8 @@ imag = Imag(real_out, name='imag')
class Angle(UnaryScalarOp): class Angle(UnaryScalarOp):
nfunc_spec = ('angle', 1, 1)
def impl(self, x): def impl(self, x):
return numpy.angle(x) return numpy.angle(x)
...@@ -3303,6 +3396,8 @@ complex = Complex(name='complex') ...@@ -3303,6 +3396,8 @@ complex = Complex(name='complex')
class Conj(UnaryScalarOp): class Conj(UnaryScalarOp):
nfunc_spec = ('conj', 1, 1)
def impl(self, x): def impl(self, x):
return numpy.conj(x) return numpy.conj(x)
conj = Conj(same_out, name='conj') conj = Conj(same_out, name='conj')
......
...@@ -1738,42 +1738,42 @@ def largest(*args): ...@@ -1738,42 +1738,42 @@ def largest(*args):
# Comparison # Comparison
########################## ##########################
@_scal_elemwise_with_nfunc('less', 2, 1) @_scal_elemwise
def lt(a, b): def lt(a, b):
"""a < b""" """a < b"""
@_scal_elemwise_with_nfunc('greater', 2, 1) @_scal_elemwise
def gt(a, b): def gt(a, b):
"""a > b""" """a > b"""
@_scal_elemwise_with_nfunc('less_equal', 2, 1) @_scal_elemwise
def le(a, b): def le(a, b):
"""a <= b""" """a <= b"""
@_scal_elemwise_with_nfunc('greater_equal', 2, 1) @_scal_elemwise
def ge(a, b): def ge(a, b):
"""a >= b""" """a >= b"""
@_scal_elemwise_with_nfunc('equal', 2, 1) @_scal_elemwise
def eq(a, b): def eq(a, b):
"""a == b""" """a == b"""
@_scal_elemwise_with_nfunc('not_equal', 2, 1) @_scal_elemwise
def neq(a, b): def neq(a, b):
"""a != b""" """a != b"""
@_scal_elemwise_with_nfunc('isnan', 1, 1) @_scal_elemwise
def isnan(a): def isnan(a):
"""isnan(a)""" """isnan(a)"""
@_scal_elemwise_with_nfunc('isinf', 1, 1) @_scal_elemwise
def isinf(a): def isinf(a):
"""isinf(a)""" """isinf(a)"""
...@@ -1922,7 +1922,7 @@ def isclose(a, b, rtol=1.e-5, atol=1.e-8, equal_nan=False): ...@@ -1922,7 +1922,7 @@ def isclose(a, b, rtol=1.e-5, atol=1.e-8, equal_nan=False):
# Condition # Condition
########################## ##########################
@_scal_elemwise_with_nfunc('where', 3, 1) @_scal_elemwise
def switch(cond, ift, iff): def switch(cond, ift, iff):
"""if cond then ift else iff""" """if cond then ift else iff"""
...@@ -1932,25 +1932,25 @@ where = switch ...@@ -1932,25 +1932,25 @@ where = switch
########################## ##########################
@_scal_elemwise_with_nfunc('bitwise_and', 2, 1) @_scal_elemwise
def and_(a, b): def and_(a, b):
"""bitwise a & b""" """bitwise a & b"""
bitwise_and = and_ # numpy name for it bitwise_and = and_ # numpy name for it
@_scal_elemwise_with_nfunc('bitwise_or', 2, 1) @_scal_elemwise
def or_(a, b): def or_(a, b):
"""bitwise a | b""" """bitwise a | b"""
bitwise_or = or_ # numpy name for it bitwise_or = or_ # numpy name for it
@_scal_elemwise_with_nfunc('bitwise_xor', 2, 1) @_scal_elemwise
def xor(a, b): def xor(a, b):
"""bitwise a ^ b""" """bitwise a ^ b"""
bitwise_xor = xor # numpy name for it bitwise_xor = xor # numpy name for it
@_scal_elemwise_with_nfunc('invert', 1, 1) @_scal_elemwise
def invert(a): def invert(a):
"""bitwise ~a""" """bitwise ~a"""
bitwise_not = invert # numpy alias for it bitwise_not = invert # numpy alias for it
...@@ -1960,7 +1960,7 @@ bitwise_not = invert # numpy alias for it ...@@ -1960,7 +1960,7 @@ bitwise_not = invert # numpy alias for it
# Math # Math
########################## ##########################
@_scal_elemwise_with_nfunc('abs', 1, 1) @_scal_elemwise
def abs_(a): def abs_(a):
"""|`a`| """|`a`|
...@@ -1972,22 +1972,22 @@ def abs_(a): ...@@ -1972,22 +1972,22 @@ def abs_(a):
pprint.assign(abs_, printing.PatternPrinter(('|%(0)s|', -1000))) pprint.assign(abs_, printing.PatternPrinter(('|%(0)s|', -1000)))
@_scal_elemwise_with_nfunc('exp', 1, 1) @_scal_elemwise
def exp(a): def exp(a):
"""e^`a`""" """e^`a`"""
@_scal_elemwise_with_nfunc('exp2', 1, 1) @_scal_elemwise
def exp2(a): def exp2(a):
"""2^`a`""" """2^`a`"""
@_scal_elemwise_with_nfunc('expm1', 1, 1) @_scal_elemwise
def expm1(a): def expm1(a):
"""e^`a` - 1""" """e^`a` - 1"""
@_scal_elemwise_with_nfunc('negative', 1, 1) @_scal_elemwise
def neg(a): def neg(a):
"""-a""" """-a"""
...@@ -1999,42 +1999,42 @@ def inv(a): ...@@ -1999,42 +1999,42 @@ def inv(a):
"""1.0/a""" """1.0/a"""
@_scal_elemwise_with_nfunc('log', 1, 1) @_scal_elemwise
def log(a): def log(a):
"""base e logarithm of a""" """base e logarithm of a"""
@_scal_elemwise_with_nfunc('log2', 1, 1) @_scal_elemwise
def log2(a): def log2(a):
"""base 2 logarithm of a""" """base 2 logarithm of a"""
@_scal_elemwise_with_nfunc('log10', 1, 1) @_scal_elemwise
def log10(a): def log10(a):
"""base 10 logarithm of a""" """base 10 logarithm of a"""
@_scal_elemwise_with_nfunc('log1p', 1, 1) @_scal_elemwise
def log1p(a): def log1p(a):
"""log(1+a)""" """log(1+a)"""
@_scal_elemwise_with_nfunc('sign', 1, 1) @_scal_elemwise
def sgn(a): def sgn(a):
"""sign of a""" """sign of a"""
@_scal_elemwise_with_nfunc('ceil', 1, 1) @_scal_elemwise
def ceil(a): def ceil(a):
"""ceiling of a""" """ceiling of a"""
@_scal_elemwise_with_nfunc('floor', 1, 1) @_scal_elemwise
def floor(a): def floor(a):
"""floor of a""" """floor of a"""
@_scal_elemwise_with_nfunc('trunc', 1, 1) @_scal_elemwise
def trunc(a): def trunc(a):
"""trunc of a""" """trunc of a"""
...@@ -2056,7 +2056,7 @@ def round(a, mode="half_away_from_zero"): ...@@ -2056,7 +2056,7 @@ def round(a, mode="half_away_from_zero"):
raise Exception("round mode %s is not implemented." % mode) raise Exception("round mode %s is not implemented." % mode)
@_scal_elemwise_with_nfunc('around', 1, 1) @_scal_elemwise
def round_half_to_even(a): def round_half_to_even(a):
"""round_half_to_even(a)""" """round_half_to_even(a)"""
...@@ -2066,7 +2066,7 @@ def round_half_away_from_zero(a): ...@@ -2066,7 +2066,7 @@ def round_half_away_from_zero(a):
"""round_half_away_from_zero(a)""" """round_half_away_from_zero(a)"""
@_scal_elemwise_with_nfunc('square', 1, 1) @_scal_elemwise
def sqr(a): def sqr(a):
"""square of a""" """square of a"""
...@@ -2075,82 +2075,82 @@ def sqr(a): ...@@ -2075,82 +2075,82 @@ def sqr(a):
square = sqr square = sqr
@_scal_elemwise_with_nfunc('sqrt', 1, 1) @_scal_elemwise
def sqrt(a): def sqrt(a):
"""square root of a""" """square root of a"""
@_scal_elemwise_with_nfunc('deg2rad', 1, 1) @_scal_elemwise
def deg2rad(a): def deg2rad(a):
"""convert degree a to radian""" """convert degree a to radian"""
@_scal_elemwise_with_nfunc('rad2deg', 1, 1) @_scal_elemwise
def rad2deg(a): def rad2deg(a):
"""convert radian a to degree""" """convert radian a to degree"""
@_scal_elemwise_with_nfunc('cos', 1, 1) @_scal_elemwise
def cos(a): def cos(a):
"""cosine of a""" """cosine of a"""
@_scal_elemwise_with_nfunc('arccos', 1, 1) @_scal_elemwise
def arccos(a): def arccos(a):
"""arccosine of a""" """arccosine of a"""
@_scal_elemwise_with_nfunc('sin', 1, 1) @_scal_elemwise
def sin(a): def sin(a):
"""sine of a""" """sine of a"""
@_scal_elemwise_with_nfunc('arcsin', 1, 1) @_scal_elemwise
def arcsin(a): def arcsin(a):
"""arcsine of a""" """arcsine of a"""
@_scal_elemwise_with_nfunc('tan', 1, 1) @_scal_elemwise
def tan(a): def tan(a):
"""tangent of a""" """tangent of a"""
@_scal_elemwise_with_nfunc('arctan', 1, 1) @_scal_elemwise
def arctan(a): def arctan(a):
"""arctangent of a""" """arctangent of a"""
@_scal_elemwise_with_nfunc('arctan2', 1, 1) @_scal_elemwise
def arctan2(a, b): def arctan2(a, b):
"""arctangent of a / b""" """arctangent of a / b"""
@_scal_elemwise_with_nfunc('cosh', 1, 1) @_scal_elemwise
def cosh(a): def cosh(a):
"""hyperbolic cosine of a""" """hyperbolic cosine of a"""
@_scal_elemwise_with_nfunc('arccosh', 1, 1) @_scal_elemwise
def arccosh(a): def arccosh(a):
"""hyperbolic arc cosine of a""" """hyperbolic arc cosine of a"""
@_scal_elemwise_with_nfunc('sinh', 1, 1) @_scal_elemwise
def sinh(a): def sinh(a):
"""hyperbolic sine of a""" """hyperbolic sine of a"""
@_scal_elemwise_with_nfunc('arcsinh', 1, 1) @_scal_elemwise
def arcsinh(a): def arcsinh(a):
"""hyperbolic arc sine of a""" """hyperbolic arc sine of a"""
@_scal_elemwise_with_nfunc('tanh', 1, 1) @_scal_elemwise
def tanh(a): def tanh(a):
"""hyperbolic tangent of a""" """hyperbolic tangent of a"""
@_scal_elemwise_with_nfunc('arctanh', 1, 1) @_scal_elemwise
def arctanh(a): def arctanh(a):
"""hyperbolic arc tangent of a""" """hyperbolic arc tangent of a"""
...@@ -2200,21 +2200,19 @@ def chi2sf(x, k): ...@@ -2200,21 +2200,19 @@ def chi2sf(x, k):
"""chi squared survival function""" """chi squared survival function"""
# numpy.real(float32) return a view on the inputs.
# @_scal_elemwise_with_nfunc('real', 1, 1)
@_scal_elemwise @_scal_elemwise
def real(z): def real(z):
"""Return real component of complex-valued tensor `z`""" """Return real component of complex-valued tensor `z`"""
_tensor_py_operators.real = property(real) _tensor_py_operators.real = property(real)
@_scal_elemwise_with_nfunc('imag', 1, 1) @_scal_elemwise
def imag(z): def imag(z):
"""Return imaginary component of complex-valued tensor `z`""" """Return imaginary component of complex-valued tensor `z`"""
_tensor_py_operators.imag = property(imag) _tensor_py_operators.imag = property(imag)
@_scal_elemwise_with_nfunc('angle', 1, 1) @_scal_elemwise
def angle(z): def angle(z):
"""Return polar-coordinate angle of complex-valued tensor `z`""" """Return polar-coordinate angle of complex-valued tensor `z`"""
...@@ -2224,7 +2222,7 @@ def complex(real, imag): ...@@ -2224,7 +2222,7 @@ def complex(real, imag):
"""Return complex-valued tensor with `real` and `imag` components""" """Return complex-valued tensor with `real` and `imag` components"""
@_scal_elemwise_with_nfunc('conj', 1, 1) @_scal_elemwise
def conj(z): def conj(z):
"""Return the complex conjugate of `z`.""" """Return the complex conjugate of `z`."""
...@@ -3166,13 +3164,13 @@ setdefault = default # legacy ...@@ -3166,13 +3164,13 @@ setdefault = default # legacy
########################## ##########################
# Arithmetics # Arithmetics
########################## ##########################
@_scal_elemwise_with_nfunc('maximum', 2, 1) @_scal_elemwise
def maximum(x, y): def maximum(x, y):
"""elemwise maximum. See max for the maximum in one tensor""" """elemwise maximum. See max for the maximum in one tensor"""
# see decorator for function body # see decorator for function body
@_scal_elemwise_with_nfunc('minimum', 2, 1) @_scal_elemwise
def minimum(x, y): def minimum(x, y):
"""elemwise minimum. See min for the minimum in one tensor""" """elemwise minimum. See min for the minimum in one tensor"""
# see decorator for function body # see decorator for function body
...@@ -3191,31 +3189,31 @@ def divmod(x, y): ...@@ -3191,31 +3189,31 @@ def divmod(x, y):
return floor_div(x, y), mod_check(x, y) return floor_div(x, y), mod_check(x, y)
@_scal_elemwise_with_nfunc('add', 2, 1) @_scal_elemwise
def add(a, *other_terms): def add(a, *other_terms):
"""elementwise addition""" """elementwise addition"""
# see decorator for function body # see decorator for function body
@_scal_elemwise_with_nfunc('subtract', 2, 1) @_scal_elemwise
def sub(a, b): def sub(a, b):
"""elementwise subtraction""" """elementwise subtraction"""
# see decorator for function body # see decorator for function body
@_scal_elemwise_with_nfunc('multiply', 2, 1) @_scal_elemwise
def mul(a, *other_terms): def mul(a, *other_terms):
"""elementwise multiplication""" """elementwise multiplication"""
# see decorator for function body # see decorator for function body
@_scal_elemwise_with_nfunc('true_divide', 2, 1) @_scal_elemwise
def true_div(a, b): def true_div(a, b):
"""elementwise [true] division (inverse of multiplication)""" """elementwise [true] division (inverse of multiplication)"""
# see decorator for function body # see decorator for function body
@_scal_elemwise_with_nfunc('floor_divide', 2, 1) @_scal_elemwise
def int_div(a, b): def int_div(a, b):
"""elementwise [floor] division (inverse of multiplication)""" """elementwise [floor] division (inverse of multiplication)"""
# see decorator for function body # see decorator for function body
...@@ -3256,20 +3254,18 @@ def mod_check(x, y): ...@@ -3256,20 +3254,18 @@ def mod_check(x, y):
return mod(x, y) return mod(x, y)
@_scal_elemwise_with_nfunc('mod', 2, 1) @_scal_elemwise
def mod(a, b): def mod(a, b):
"""elementwise modulo""" """elementwise modulo"""
# see decorator for function body # see decorator for function body
@_scal_elemwise_with_nfunc('power', 2, 1) @_scal_elemwise
def pow(a, b): def pow(a, b):
"""elementwise power""" """elementwise power"""
# see decorator for function body # see decorator for function body
# The numpy.clip don't work correctly when the min is bigger then the max,
# So we do not use @scal_elemwise_with_nfunc('clip', 3, 1)
@_scal_elemwise @_scal_elemwise
def clip(x, min, max): def clip(x, min, max):
""" """
......
...@@ -503,6 +503,8 @@ class Elemwise(OpenMPOp): ...@@ -503,6 +503,8 @@ class Elemwise(OpenMPOp):
self.ufunc = None self.ufunc = None
self.nfunc = None self.nfunc = None
if nfunc_spec is None:
nfunc_spec = getattr(scalar_op, 'nfunc_spec', None)
self.nfunc_spec = nfunc_spec self.nfunc_spec = nfunc_spec
if nfunc_spec: if nfunc_spec:
self.nfunc = getattr(numpy, nfunc_spec[0]) self.nfunc = getattr(numpy, nfunc_spec[0])
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论