提交 7ef44dfd authored 作者: Frederic Bastien's avatar Frederic Bastien

Move nfunc_spec to Scalar op, This allow to always have it even when we use Elemwise(a_scalar_op).

This happend frequently in the grad computation of elemwise. This could speed up DebugMode. This fix tests crashX
上级 ed337d4e
差异被折叠。
......@@ -1738,42 +1738,42 @@ def largest(*args):
# Comparison
##########################
@_scal_elemwise_with_nfunc('less', 2, 1)
@_scal_elemwise
def lt(a, b):
"""a < b"""
@_scal_elemwise_with_nfunc('greater', 2, 1)
@_scal_elemwise
def gt(a, b):
"""a > b"""
@_scal_elemwise_with_nfunc('less_equal', 2, 1)
@_scal_elemwise
def le(a, b):
"""a <= b"""
@_scal_elemwise_with_nfunc('greater_equal', 2, 1)
@_scal_elemwise
def ge(a, b):
"""a >= b"""
@_scal_elemwise_with_nfunc('equal', 2, 1)
@_scal_elemwise
def eq(a, b):
"""a == b"""
@_scal_elemwise_with_nfunc('not_equal', 2, 1)
@_scal_elemwise
def neq(a, b):
"""a != b"""
@_scal_elemwise_with_nfunc('isnan', 1, 1)
@_scal_elemwise
def isnan(a):
"""isnan(a)"""
@_scal_elemwise_with_nfunc('isinf', 1, 1)
@_scal_elemwise
def isinf(a):
"""isinf(a)"""
......@@ -1922,7 +1922,7 @@ def isclose(a, b, rtol=1.e-5, atol=1.e-8, equal_nan=False):
# Condition
##########################
@_scal_elemwise_with_nfunc('where', 3, 1)
@_scal_elemwise
def switch(cond, ift, iff):
"""if cond then ift else iff"""
......@@ -1932,25 +1932,25 @@ where = switch
##########################
@_scal_elemwise_with_nfunc('bitwise_and', 2, 1)
@_scal_elemwise
def and_(a, b):
"""bitwise a & b"""
bitwise_and = and_ # numpy name for it
@_scal_elemwise_with_nfunc('bitwise_or', 2, 1)
@_scal_elemwise
def or_(a, b):
"""bitwise a | b"""
bitwise_or = or_ # numpy name for it
@_scal_elemwise_with_nfunc('bitwise_xor', 2, 1)
@_scal_elemwise
def xor(a, b):
"""bitwise a ^ b"""
bitwise_xor = xor # numpy name for it
@_scal_elemwise_with_nfunc('invert', 1, 1)
@_scal_elemwise
def invert(a):
"""bitwise ~a"""
bitwise_not = invert # numpy alias for it
......@@ -1960,7 +1960,7 @@ bitwise_not = invert # numpy alias for it
# Math
##########################
@_scal_elemwise_with_nfunc('abs', 1, 1)
@_scal_elemwise
def abs_(a):
"""|`a`|
......@@ -1972,22 +1972,22 @@ def abs_(a):
pprint.assign(abs_, printing.PatternPrinter(('|%(0)s|', -1000)))
@_scal_elemwise_with_nfunc('exp', 1, 1)
@_scal_elemwise
def exp(a):
"""e^`a`"""
@_scal_elemwise_with_nfunc('exp2', 1, 1)
@_scal_elemwise
def exp2(a):
"""2^`a`"""
@_scal_elemwise_with_nfunc('expm1', 1, 1)
@_scal_elemwise
def expm1(a):
"""e^`a` - 1"""
@_scal_elemwise_with_nfunc('negative', 1, 1)
@_scal_elemwise
def neg(a):
"""-a"""
......@@ -1999,42 +1999,42 @@ def inv(a):
"""1.0/a"""
@_scal_elemwise_with_nfunc('log', 1, 1)
@_scal_elemwise
def log(a):
"""base e logarithm of a"""
@_scal_elemwise_with_nfunc('log2', 1, 1)
@_scal_elemwise
def log2(a):
"""base 2 logarithm of a"""
@_scal_elemwise_with_nfunc('log10', 1, 1)
@_scal_elemwise
def log10(a):
"""base 10 logarithm of a"""
@_scal_elemwise_with_nfunc('log1p', 1, 1)
@_scal_elemwise
def log1p(a):
"""log(1+a)"""
@_scal_elemwise_with_nfunc('sign', 1, 1)
@_scal_elemwise
def sgn(a):
"""sign of a"""
@_scal_elemwise_with_nfunc('ceil', 1, 1)
@_scal_elemwise
def ceil(a):
"""ceiling of a"""
@_scal_elemwise_with_nfunc('floor', 1, 1)
@_scal_elemwise
def floor(a):
"""floor of a"""
@_scal_elemwise_with_nfunc('trunc', 1, 1)
@_scal_elemwise
def trunc(a):
"""trunc of a"""
......@@ -2056,7 +2056,7 @@ def round(a, mode="half_away_from_zero"):
raise Exception("round mode %s is not implemented." % mode)
@_scal_elemwise_with_nfunc('around', 1, 1)
@_scal_elemwise
def round_half_to_even(a):
"""round_half_to_even(a)"""
......@@ -2066,7 +2066,7 @@ def round_half_away_from_zero(a):
"""round_half_away_from_zero(a)"""
@_scal_elemwise_with_nfunc('square', 1, 1)
@_scal_elemwise
def sqr(a):
"""square of a"""
......@@ -2075,82 +2075,82 @@ def sqr(a):
square = sqr
@_scal_elemwise_with_nfunc('sqrt', 1, 1)
@_scal_elemwise
def sqrt(a):
"""square root of a"""
@_scal_elemwise_with_nfunc('deg2rad', 1, 1)
@_scal_elemwise
def deg2rad(a):
"""convert degree a to radian"""
@_scal_elemwise_with_nfunc('rad2deg', 1, 1)
@_scal_elemwise
def rad2deg(a):
"""convert radian a to degree"""
@_scal_elemwise_with_nfunc('cos', 1, 1)
@_scal_elemwise
def cos(a):
"""cosine of a"""
@_scal_elemwise_with_nfunc('arccos', 1, 1)
@_scal_elemwise
def arccos(a):
"""arccosine of a"""
@_scal_elemwise_with_nfunc('sin', 1, 1)
@_scal_elemwise
def sin(a):
"""sine of a"""
@_scal_elemwise_with_nfunc('arcsin', 1, 1)
@_scal_elemwise
def arcsin(a):
"""arcsine of a"""
@_scal_elemwise_with_nfunc('tan', 1, 1)
@_scal_elemwise
def tan(a):
"""tangent of a"""
@_scal_elemwise_with_nfunc('arctan', 1, 1)
@_scal_elemwise
def arctan(a):
"""arctangent of a"""
@_scal_elemwise_with_nfunc('arctan2', 1, 1)
@_scal_elemwise
def arctan2(a, b):
"""arctangent of a / b"""
@_scal_elemwise_with_nfunc('cosh', 1, 1)
@_scal_elemwise
def cosh(a):
"""hyperbolic cosine of a"""
@_scal_elemwise_with_nfunc('arccosh', 1, 1)
@_scal_elemwise
def arccosh(a):
"""hyperbolic arc cosine of a"""
@_scal_elemwise_with_nfunc('sinh', 1, 1)
@_scal_elemwise
def sinh(a):
"""hyperbolic sine of a"""
@_scal_elemwise_with_nfunc('arcsinh', 1, 1)
@_scal_elemwise
def arcsinh(a):
"""hyperbolic arc sine of a"""
@_scal_elemwise_with_nfunc('tanh', 1, 1)
@_scal_elemwise
def tanh(a):
"""hyperbolic tangent of a"""
@_scal_elemwise_with_nfunc('arctanh', 1, 1)
@_scal_elemwise
def arctanh(a):
"""hyperbolic arc tangent of a"""
......@@ -2200,21 +2200,19 @@ def chi2sf(x, k):
"""chi squared survival function"""
# numpy.real(float32) return a view on the inputs.
# @_scal_elemwise_with_nfunc('real', 1, 1)
@_scal_elemwise
def real(z):
"""Return real component of complex-valued tensor `z`"""
_tensor_py_operators.real = property(real)
@_scal_elemwise_with_nfunc('imag', 1, 1)
@_scal_elemwise
def imag(z):
"""Return imaginary component of complex-valued tensor `z`"""
_tensor_py_operators.imag = property(imag)
@_scal_elemwise_with_nfunc('angle', 1, 1)
@_scal_elemwise
def angle(z):
"""Return polar-coordinate angle of complex-valued tensor `z`"""
......@@ -2224,7 +2222,7 @@ def complex(real, imag):
"""Return complex-valued tensor with `real` and `imag` components"""
@_scal_elemwise_with_nfunc('conj', 1, 1)
@_scal_elemwise
def conj(z):
"""Return the complex conjugate of `z`."""
......@@ -3166,13 +3164,13 @@ setdefault = default # legacy
##########################
# Arithmetics
##########################
@_scal_elemwise_with_nfunc('maximum', 2, 1)
@_scal_elemwise
def maximum(x, y):
"""elemwise maximum. See max for the maximum in one tensor"""
# see decorator for function body
@_scal_elemwise_with_nfunc('minimum', 2, 1)
@_scal_elemwise
def minimum(x, y):
"""elemwise minimum. See min for the minimum in one tensor"""
# see decorator for function body
......@@ -3191,31 +3189,31 @@ def divmod(x, y):
return floor_div(x, y), mod_check(x, y)
@_scal_elemwise_with_nfunc('add', 2, 1)
@_scal_elemwise
def add(a, *other_terms):
"""elementwise addition"""
# see decorator for function body
@_scal_elemwise_with_nfunc('subtract', 2, 1)
@_scal_elemwise
def sub(a, b):
"""elementwise subtraction"""
# see decorator for function body
@_scal_elemwise_with_nfunc('multiply', 2, 1)
@_scal_elemwise
def mul(a, *other_terms):
"""elementwise multiplication"""
# see decorator for function body
@_scal_elemwise_with_nfunc('true_divide', 2, 1)
@_scal_elemwise
def true_div(a, b):
"""elementwise [true] division (inverse of multiplication)"""
# see decorator for function body
@_scal_elemwise_with_nfunc('floor_divide', 2, 1)
@_scal_elemwise
def int_div(a, b):
"""elementwise [floor] division (inverse of multiplication)"""
# see decorator for function body
......@@ -3256,20 +3254,18 @@ def mod_check(x, y):
return mod(x, y)
@_scal_elemwise_with_nfunc('mod', 2, 1)
@_scal_elemwise
def mod(a, b):
"""elementwise modulo"""
# see decorator for function body
@_scal_elemwise_with_nfunc('power', 2, 1)
@_scal_elemwise
def pow(a, b):
"""elementwise power"""
# see decorator for function body
# The numpy.clip don't work correctly when the min is bigger then the max,
# So we do not use @scal_elemwise_with_nfunc('clip', 3, 1)
@_scal_elemwise
def clip(x, min, max):
"""
......
......@@ -503,6 +503,8 @@ class Elemwise(OpenMPOp):
self.ufunc = None
self.nfunc = None
if nfunc_spec is None:
nfunc_spec = getattr(scalar_op, 'nfunc_spec', None)
self.nfunc_spec = nfunc_spec
if nfunc_spec:
self.nfunc = getattr(numpy, nfunc_spec[0])
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论