提交 7ef44dfd authored 作者: Frederic Bastien's avatar Frederic Bastien

Move nfunc_spec to Scalar op, This allow to always have it even when we use Elemwise(a_scalar_op).

This happend frequently in the grad computation of elemwise. This could speed up DebugMode. This fix tests crashX
上级 ed337d4e
差异被折叠。
...@@ -1738,42 +1738,42 @@ def largest(*args): ...@@ -1738,42 +1738,42 @@ def largest(*args):
# Comparison # Comparison
########################## ##########################
@_scal_elemwise_with_nfunc('less', 2, 1) @_scal_elemwise
def lt(a, b): def lt(a, b):
"""a < b""" """a < b"""
@_scal_elemwise_with_nfunc('greater', 2, 1) @_scal_elemwise
def gt(a, b): def gt(a, b):
"""a > b""" """a > b"""
@_scal_elemwise_with_nfunc('less_equal', 2, 1) @_scal_elemwise
def le(a, b): def le(a, b):
"""a <= b""" """a <= b"""
@_scal_elemwise_with_nfunc('greater_equal', 2, 1) @_scal_elemwise
def ge(a, b): def ge(a, b):
"""a >= b""" """a >= b"""
@_scal_elemwise_with_nfunc('equal', 2, 1) @_scal_elemwise
def eq(a, b): def eq(a, b):
"""a == b""" """a == b"""
@_scal_elemwise_with_nfunc('not_equal', 2, 1) @_scal_elemwise
def neq(a, b): def neq(a, b):
"""a != b""" """a != b"""
@_scal_elemwise_with_nfunc('isnan', 1, 1) @_scal_elemwise
def isnan(a): def isnan(a):
"""isnan(a)""" """isnan(a)"""
@_scal_elemwise_with_nfunc('isinf', 1, 1) @_scal_elemwise
def isinf(a): def isinf(a):
"""isinf(a)""" """isinf(a)"""
...@@ -1922,7 +1922,7 @@ def isclose(a, b, rtol=1.e-5, atol=1.e-8, equal_nan=False): ...@@ -1922,7 +1922,7 @@ def isclose(a, b, rtol=1.e-5, atol=1.e-8, equal_nan=False):
# Condition # Condition
########################## ##########################
@_scal_elemwise_with_nfunc('where', 3, 1) @_scal_elemwise
def switch(cond, ift, iff): def switch(cond, ift, iff):
"""if cond then ift else iff""" """if cond then ift else iff"""
...@@ -1932,25 +1932,25 @@ where = switch ...@@ -1932,25 +1932,25 @@ where = switch
########################## ##########################
@_scal_elemwise_with_nfunc('bitwise_and', 2, 1) @_scal_elemwise
def and_(a, b): def and_(a, b):
"""bitwise a & b""" """bitwise a & b"""
bitwise_and = and_ # numpy name for it bitwise_and = and_ # numpy name for it
@_scal_elemwise_with_nfunc('bitwise_or', 2, 1) @_scal_elemwise
def or_(a, b): def or_(a, b):
"""bitwise a | b""" """bitwise a | b"""
bitwise_or = or_ # numpy name for it bitwise_or = or_ # numpy name for it
@_scal_elemwise_with_nfunc('bitwise_xor', 2, 1) @_scal_elemwise
def xor(a, b): def xor(a, b):
"""bitwise a ^ b""" """bitwise a ^ b"""
bitwise_xor = xor # numpy name for it bitwise_xor = xor # numpy name for it
@_scal_elemwise_with_nfunc('invert', 1, 1) @_scal_elemwise
def invert(a): def invert(a):
"""bitwise ~a""" """bitwise ~a"""
bitwise_not = invert # numpy alias for it bitwise_not = invert # numpy alias for it
...@@ -1960,7 +1960,7 @@ bitwise_not = invert # numpy alias for it ...@@ -1960,7 +1960,7 @@ bitwise_not = invert # numpy alias for it
# Math # Math
########################## ##########################
@_scal_elemwise_with_nfunc('abs', 1, 1) @_scal_elemwise
def abs_(a): def abs_(a):
"""|`a`| """|`a`|
...@@ -1972,22 +1972,22 @@ def abs_(a): ...@@ -1972,22 +1972,22 @@ def abs_(a):
pprint.assign(abs_, printing.PatternPrinter(('|%(0)s|', -1000))) pprint.assign(abs_, printing.PatternPrinter(('|%(0)s|', -1000)))
@_scal_elemwise_with_nfunc('exp', 1, 1) @_scal_elemwise
def exp(a): def exp(a):
"""e^`a`""" """e^`a`"""
@_scal_elemwise_with_nfunc('exp2', 1, 1) @_scal_elemwise
def exp2(a): def exp2(a):
"""2^`a`""" """2^`a`"""
@_scal_elemwise_with_nfunc('expm1', 1, 1) @_scal_elemwise
def expm1(a): def expm1(a):
"""e^`a` - 1""" """e^`a` - 1"""
@_scal_elemwise_with_nfunc('negative', 1, 1) @_scal_elemwise
def neg(a): def neg(a):
"""-a""" """-a"""
...@@ -1999,42 +1999,42 @@ def inv(a): ...@@ -1999,42 +1999,42 @@ def inv(a):
"""1.0/a""" """1.0/a"""
@_scal_elemwise_with_nfunc('log', 1, 1) @_scal_elemwise
def log(a): def log(a):
"""base e logarithm of a""" """base e logarithm of a"""
@_scal_elemwise_with_nfunc('log2', 1, 1) @_scal_elemwise
def log2(a): def log2(a):
"""base 2 logarithm of a""" """base 2 logarithm of a"""
@_scal_elemwise_with_nfunc('log10', 1, 1) @_scal_elemwise
def log10(a): def log10(a):
"""base 10 logarithm of a""" """base 10 logarithm of a"""
@_scal_elemwise_with_nfunc('log1p', 1, 1) @_scal_elemwise
def log1p(a): def log1p(a):
"""log(1+a)""" """log(1+a)"""
@_scal_elemwise_with_nfunc('sign', 1, 1) @_scal_elemwise
def sgn(a): def sgn(a):
"""sign of a""" """sign of a"""
@_scal_elemwise_with_nfunc('ceil', 1, 1) @_scal_elemwise
def ceil(a): def ceil(a):
"""ceiling of a""" """ceiling of a"""
@_scal_elemwise_with_nfunc('floor', 1, 1) @_scal_elemwise
def floor(a): def floor(a):
"""floor of a""" """floor of a"""
@_scal_elemwise_with_nfunc('trunc', 1, 1) @_scal_elemwise
def trunc(a): def trunc(a):
"""trunc of a""" """trunc of a"""
...@@ -2056,7 +2056,7 @@ def round(a, mode="half_away_from_zero"): ...@@ -2056,7 +2056,7 @@ def round(a, mode="half_away_from_zero"):
raise Exception("round mode %s is not implemented." % mode) raise Exception("round mode %s is not implemented." % mode)
@_scal_elemwise_with_nfunc('around', 1, 1) @_scal_elemwise
def round_half_to_even(a): def round_half_to_even(a):
"""round_half_to_even(a)""" """round_half_to_even(a)"""
...@@ -2066,7 +2066,7 @@ def round_half_away_from_zero(a): ...@@ -2066,7 +2066,7 @@ def round_half_away_from_zero(a):
"""round_half_away_from_zero(a)""" """round_half_away_from_zero(a)"""
@_scal_elemwise_with_nfunc('square', 1, 1) @_scal_elemwise
def sqr(a): def sqr(a):
"""square of a""" """square of a"""
...@@ -2075,82 +2075,82 @@ def sqr(a): ...@@ -2075,82 +2075,82 @@ def sqr(a):
square = sqr square = sqr
@_scal_elemwise_with_nfunc('sqrt', 1, 1) @_scal_elemwise
def sqrt(a): def sqrt(a):
"""square root of a""" """square root of a"""
@_scal_elemwise_with_nfunc('deg2rad', 1, 1) @_scal_elemwise
def deg2rad(a): def deg2rad(a):
"""convert degree a to radian""" """convert degree a to radian"""
@_scal_elemwise_with_nfunc('rad2deg', 1, 1) @_scal_elemwise
def rad2deg(a): def rad2deg(a):
"""convert radian a to degree""" """convert radian a to degree"""
@_scal_elemwise_with_nfunc('cos', 1, 1) @_scal_elemwise
def cos(a): def cos(a):
"""cosine of a""" """cosine of a"""
@_scal_elemwise_with_nfunc('arccos', 1, 1) @_scal_elemwise
def arccos(a): def arccos(a):
"""arccosine of a""" """arccosine of a"""
@_scal_elemwise_with_nfunc('sin', 1, 1) @_scal_elemwise
def sin(a): def sin(a):
"""sine of a""" """sine of a"""
@_scal_elemwise_with_nfunc('arcsin', 1, 1) @_scal_elemwise
def arcsin(a): def arcsin(a):
"""arcsine of a""" """arcsine of a"""
@_scal_elemwise_with_nfunc('tan', 1, 1) @_scal_elemwise
def tan(a): def tan(a):
"""tangent of a""" """tangent of a"""
@_scal_elemwise_with_nfunc('arctan', 1, 1) @_scal_elemwise
def arctan(a): def arctan(a):
"""arctangent of a""" """arctangent of a"""
@_scal_elemwise_with_nfunc('arctan2', 1, 1) @_scal_elemwise
def arctan2(a, b): def arctan2(a, b):
"""arctangent of a / b""" """arctangent of a / b"""
@_scal_elemwise_with_nfunc('cosh', 1, 1) @_scal_elemwise
def cosh(a): def cosh(a):
"""hyperbolic cosine of a""" """hyperbolic cosine of a"""
@_scal_elemwise_with_nfunc('arccosh', 1, 1) @_scal_elemwise
def arccosh(a): def arccosh(a):
"""hyperbolic arc cosine of a""" """hyperbolic arc cosine of a"""
@_scal_elemwise_with_nfunc('sinh', 1, 1) @_scal_elemwise
def sinh(a): def sinh(a):
"""hyperbolic sine of a""" """hyperbolic sine of a"""
@_scal_elemwise_with_nfunc('arcsinh', 1, 1) @_scal_elemwise
def arcsinh(a): def arcsinh(a):
"""hyperbolic arc sine of a""" """hyperbolic arc sine of a"""
@_scal_elemwise_with_nfunc('tanh', 1, 1) @_scal_elemwise
def tanh(a): def tanh(a):
"""hyperbolic tangent of a""" """hyperbolic tangent of a"""
@_scal_elemwise_with_nfunc('arctanh', 1, 1) @_scal_elemwise
def arctanh(a): def arctanh(a):
"""hyperbolic arc tangent of a""" """hyperbolic arc tangent of a"""
...@@ -2200,21 +2200,19 @@ def chi2sf(x, k): ...@@ -2200,21 +2200,19 @@ def chi2sf(x, k):
"""chi squared survival function""" """chi squared survival function"""
# numpy.real(float32) return a view on the inputs.
# @_scal_elemwise_with_nfunc('real', 1, 1)
@_scal_elemwise @_scal_elemwise
def real(z): def real(z):
"""Return real component of complex-valued tensor `z`""" """Return real component of complex-valued tensor `z`"""
_tensor_py_operators.real = property(real) _tensor_py_operators.real = property(real)
@_scal_elemwise_with_nfunc('imag', 1, 1) @_scal_elemwise
def imag(z): def imag(z):
"""Return imaginary component of complex-valued tensor `z`""" """Return imaginary component of complex-valued tensor `z`"""
_tensor_py_operators.imag = property(imag) _tensor_py_operators.imag = property(imag)
@_scal_elemwise_with_nfunc('angle', 1, 1) @_scal_elemwise
def angle(z): def angle(z):
"""Return polar-coordinate angle of complex-valued tensor `z`""" """Return polar-coordinate angle of complex-valued tensor `z`"""
...@@ -2224,7 +2222,7 @@ def complex(real, imag): ...@@ -2224,7 +2222,7 @@ def complex(real, imag):
"""Return complex-valued tensor with `real` and `imag` components""" """Return complex-valued tensor with `real` and `imag` components"""
@_scal_elemwise_with_nfunc('conj', 1, 1) @_scal_elemwise
def conj(z): def conj(z):
"""Return the complex conjugate of `z`.""" """Return the complex conjugate of `z`."""
...@@ -3166,13 +3164,13 @@ setdefault = default # legacy ...@@ -3166,13 +3164,13 @@ setdefault = default # legacy
########################## ##########################
# Arithmetics # Arithmetics
########################## ##########################
@_scal_elemwise_with_nfunc('maximum', 2, 1) @_scal_elemwise
def maximum(x, y): def maximum(x, y):
"""elemwise maximum. See max for the maximum in one tensor""" """elemwise maximum. See max for the maximum in one tensor"""
# see decorator for function body # see decorator for function body
@_scal_elemwise_with_nfunc('minimum', 2, 1) @_scal_elemwise
def minimum(x, y): def minimum(x, y):
"""elemwise minimum. See min for the minimum in one tensor""" """elemwise minimum. See min for the minimum in one tensor"""
# see decorator for function body # see decorator for function body
...@@ -3191,31 +3189,31 @@ def divmod(x, y): ...@@ -3191,31 +3189,31 @@ def divmod(x, y):
return floor_div(x, y), mod_check(x, y) return floor_div(x, y), mod_check(x, y)
@_scal_elemwise_with_nfunc('add', 2, 1) @_scal_elemwise
def add(a, *other_terms): def add(a, *other_terms):
"""elementwise addition""" """elementwise addition"""
# see decorator for function body # see decorator for function body
@_scal_elemwise_with_nfunc('subtract', 2, 1) @_scal_elemwise
def sub(a, b): def sub(a, b):
"""elementwise subtraction""" """elementwise subtraction"""
# see decorator for function body # see decorator for function body
@_scal_elemwise_with_nfunc('multiply', 2, 1) @_scal_elemwise
def mul(a, *other_terms): def mul(a, *other_terms):
"""elementwise multiplication""" """elementwise multiplication"""
# see decorator for function body # see decorator for function body
@_scal_elemwise_with_nfunc('true_divide', 2, 1) @_scal_elemwise
def true_div(a, b): def true_div(a, b):
"""elementwise [true] division (inverse of multiplication)""" """elementwise [true] division (inverse of multiplication)"""
# see decorator for function body # see decorator for function body
@_scal_elemwise_with_nfunc('floor_divide', 2, 1) @_scal_elemwise
def int_div(a, b): def int_div(a, b):
"""elementwise [floor] division (inverse of multiplication)""" """elementwise [floor] division (inverse of multiplication)"""
# see decorator for function body # see decorator for function body
...@@ -3256,20 +3254,18 @@ def mod_check(x, y): ...@@ -3256,20 +3254,18 @@ def mod_check(x, y):
return mod(x, y) return mod(x, y)
@_scal_elemwise_with_nfunc('mod', 2, 1) @_scal_elemwise
def mod(a, b): def mod(a, b):
"""elementwise modulo""" """elementwise modulo"""
# see decorator for function body # see decorator for function body
@_scal_elemwise_with_nfunc('power', 2, 1) @_scal_elemwise
def pow(a, b): def pow(a, b):
"""elementwise power""" """elementwise power"""
# see decorator for function body # see decorator for function body
# The numpy.clip don't work correctly when the min is bigger then the max,
# So we do not use @scal_elemwise_with_nfunc('clip', 3, 1)
@_scal_elemwise @_scal_elemwise
def clip(x, min, max): def clip(x, min, max):
""" """
......
...@@ -503,6 +503,8 @@ class Elemwise(OpenMPOp): ...@@ -503,6 +503,8 @@ class Elemwise(OpenMPOp):
self.ufunc = None self.ufunc = None
self.nfunc = None self.nfunc = None
if nfunc_spec is None:
nfunc_spec = getattr(scalar_op, 'nfunc_spec', None)
self.nfunc_spec = nfunc_spec self.nfunc_spec = nfunc_spec
if nfunc_spec: if nfunc_spec:
self.nfunc = getattr(numpy, nfunc_spec[0]) self.nfunc = getattr(numpy, nfunc_spec[0])
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论