提交 5849db21 authored 作者: abergeron's avatar abergeron

Merge pull request #3532 from nouiz/ufunc_32

[CRASH] when Elemwise have more then 31 inputs
......@@ -132,6 +132,16 @@ class Apply(Node):
return self.op.get_context(self)
return NoContext
def __getstate__(self):
d = self.__dict__
# ufunc don't pickle/unpickle well
if hasattr(self.tag, 'ufunc'):
d = copy(self.__dict__)
t = d["tag"]
del t.ufunc
d["tag"] = t
return d
def default_output(self):
"""
Returns the default output for this node.
......
差异被折叠。
......@@ -1739,42 +1739,42 @@ def largest(*args):
# Comparison
##########################
@_scal_elemwise_with_nfunc('less', 2, 1)
@_scal_elemwise
def lt(a, b):
"""a < b"""
@_scal_elemwise_with_nfunc('greater', 2, 1)
@_scal_elemwise
def gt(a, b):
"""a > b"""
@_scal_elemwise_with_nfunc('less_equal', 2, 1)
@_scal_elemwise
def le(a, b):
"""a <= b"""
@_scal_elemwise_with_nfunc('greater_equal', 2, 1)
@_scal_elemwise
def ge(a, b):
"""a >= b"""
@_scal_elemwise_with_nfunc('equal', 2, 1)
@_scal_elemwise
def eq(a, b):
"""a == b"""
@_scal_elemwise_with_nfunc('not_equal', 2, 1)
@_scal_elemwise
def neq(a, b):
"""a != b"""
@_scal_elemwise_with_nfunc('isnan', 1, 1)
@_scal_elemwise
def isnan(a):
"""isnan(a)"""
@_scal_elemwise_with_nfunc('isinf', 1, 1)
@_scal_elemwise
def isinf(a):
"""isinf(a)"""
......@@ -1923,7 +1923,7 @@ def isclose(a, b, rtol=1.e-5, atol=1.e-8, equal_nan=False):
# Condition
##########################
@_scal_elemwise_with_nfunc('where', 3, 1)
@_scal_elemwise
def switch(cond, ift, iff):
"""if cond then ift else iff"""
......@@ -1933,25 +1933,25 @@ where = switch
##########################
@_scal_elemwise_with_nfunc('bitwise_and', 2, 1)
@_scal_elemwise
def and_(a, b):
"""bitwise a & b"""
bitwise_and = and_ # numpy name for it
@_scal_elemwise_with_nfunc('bitwise_or', 2, 1)
@_scal_elemwise
def or_(a, b):
"""bitwise a | b"""
bitwise_or = or_ # numpy name for it
@_scal_elemwise_with_nfunc('bitwise_xor', 2, 1)
@_scal_elemwise
def xor(a, b):
"""bitwise a ^ b"""
bitwise_xor = xor # numpy name for it
@_scal_elemwise_with_nfunc('invert', 1, 1)
@_scal_elemwise
def invert(a):
"""bitwise ~a"""
bitwise_not = invert # numpy alias for it
......@@ -1961,7 +1961,7 @@ bitwise_not = invert # numpy alias for it
# Math
##########################
@_scal_elemwise_with_nfunc('abs', 1, 1)
@_scal_elemwise
def abs_(a):
"""|`a`|
......@@ -1973,22 +1973,22 @@ def abs_(a):
pprint.assign(abs_, printing.PatternPrinter(('|%(0)s|', -1000)))
@_scal_elemwise_with_nfunc('exp', 1, 1)
@_scal_elemwise
def exp(a):
"""e^`a`"""
@_scal_elemwise_with_nfunc('exp2', 1, 1)
@_scal_elemwise
def exp2(a):
"""2^`a`"""
@_scal_elemwise_with_nfunc('expm1', 1, 1)
@_scal_elemwise
def expm1(a):
"""e^`a` - 1"""
@_scal_elemwise_with_nfunc('negative', 1, 1)
@_scal_elemwise
def neg(a):
"""-a"""
......@@ -2000,42 +2000,42 @@ def inv(a):
"""1.0/a"""
@_scal_elemwise_with_nfunc('log', 1, 1)
@_scal_elemwise
def log(a):
"""base e logarithm of a"""
@_scal_elemwise_with_nfunc('log2', 1, 1)
@_scal_elemwise
def log2(a):
"""base 2 logarithm of a"""
@_scal_elemwise_with_nfunc('log10', 1, 1)
@_scal_elemwise
def log10(a):
"""base 10 logarithm of a"""
@_scal_elemwise_with_nfunc('log1p', 1, 1)
@_scal_elemwise
def log1p(a):
"""log(1+a)"""
@_scal_elemwise_with_nfunc('sign', 1, 1)
@_scal_elemwise
def sgn(a):
"""sign of a"""
@_scal_elemwise_with_nfunc('ceil', 1, 1)
@_scal_elemwise
def ceil(a):
"""ceiling of a"""
@_scal_elemwise_with_nfunc('floor', 1, 1)
@_scal_elemwise
def floor(a):
"""floor of a"""
@_scal_elemwise_with_nfunc('trunc', 1, 1)
@_scal_elemwise
def trunc(a):
"""trunc of a"""
......@@ -2057,7 +2057,7 @@ def round(a, mode="half_away_from_zero"):
raise Exception("round mode %s is not implemented." % mode)
@_scal_elemwise_with_nfunc('around', 1, 1)
@_scal_elemwise
def round_half_to_even(a):
"""round_half_to_even(a)"""
......@@ -2067,7 +2067,7 @@ def round_half_away_from_zero(a):
"""round_half_away_from_zero(a)"""
@_scal_elemwise_with_nfunc('square', 1, 1)
@_scal_elemwise
def sqr(a):
"""square of a"""
......@@ -2076,82 +2076,82 @@ def sqr(a):
square = sqr
@_scal_elemwise_with_nfunc('sqrt', 1, 1)
@_scal_elemwise
def sqrt(a):
"""square root of a"""
@_scal_elemwise_with_nfunc('deg2rad', 1, 1)
@_scal_elemwise
def deg2rad(a):
"""convert degree a to radian"""
@_scal_elemwise_with_nfunc('rad2deg', 1, 1)
@_scal_elemwise
def rad2deg(a):
"""convert radian a to degree"""
@_scal_elemwise_with_nfunc('cos', 1, 1)
@_scal_elemwise
def cos(a):
"""cosine of a"""
@_scal_elemwise_with_nfunc('arccos', 1, 1)
@_scal_elemwise
def arccos(a):
"""arccosine of a"""
@_scal_elemwise_with_nfunc('sin', 1, 1)
@_scal_elemwise
def sin(a):
"""sine of a"""
@_scal_elemwise_with_nfunc('arcsin', 1, 1)
@_scal_elemwise
def arcsin(a):
"""arcsine of a"""
@_scal_elemwise_with_nfunc('tan', 1, 1)
@_scal_elemwise
def tan(a):
"""tangent of a"""
@_scal_elemwise_with_nfunc('arctan', 1, 1)
@_scal_elemwise
def arctan(a):
"""arctangent of a"""
@_scal_elemwise_with_nfunc('arctan2', 1, 1)
@_scal_elemwise
def arctan2(a, b):
"""arctangent of a / b"""
@_scal_elemwise_with_nfunc('cosh', 1, 1)
@_scal_elemwise
def cosh(a):
"""hyperbolic cosine of a"""
@_scal_elemwise_with_nfunc('arccosh', 1, 1)
@_scal_elemwise
def arccosh(a):
"""hyperbolic arc cosine of a"""
@_scal_elemwise_with_nfunc('sinh', 1, 1)
@_scal_elemwise
def sinh(a):
"""hyperbolic sine of a"""
@_scal_elemwise_with_nfunc('arcsinh', 1, 1)
@_scal_elemwise
def arcsinh(a):
"""hyperbolic arc sine of a"""
@_scal_elemwise_with_nfunc('tanh', 1, 1)
@_scal_elemwise
def tanh(a):
"""hyperbolic tangent of a"""
@_scal_elemwise_with_nfunc('arctanh', 1, 1)
@_scal_elemwise
def arctanh(a):
"""hyperbolic arc tangent of a"""
......@@ -2201,21 +2201,19 @@ def chi2sf(x, k):
"""chi squared survival function"""
# numpy.real(float32) return a view on the inputs.
# @_scal_elemwise_with_nfunc('real', 1, 1)
@_scal_elemwise
def real(z):
"""Return real component of complex-valued tensor `z`"""
_tensor_py_operators.real = property(real)
@_scal_elemwise_with_nfunc('imag', 1, 1)
@_scal_elemwise
def imag(z):
"""Return imaginary component of complex-valued tensor `z`"""
_tensor_py_operators.imag = property(imag)
@_scal_elemwise_with_nfunc('angle', 1, 1)
@_scal_elemwise
def angle(z):
"""Return polar-coordinate angle of complex-valued tensor `z`"""
......@@ -2225,7 +2223,7 @@ def complex(real, imag):
"""Return complex-valued tensor with `real` and `imag` components"""
@_scal_elemwise_with_nfunc('conj', 1, 1)
@_scal_elemwise
def conj(z):
"""Return the complex conjugate of `z`."""
......@@ -3202,13 +3200,13 @@ setdefault = default # legacy
##########################
# Arithmetics
##########################
@_scal_elemwise_with_nfunc('maximum', 2, 1)
@_scal_elemwise
def maximum(x, y):
"""elemwise maximum. See max for the maximum in one tensor"""
# see decorator for function body
@_scal_elemwise_with_nfunc('minimum', 2, 1)
@_scal_elemwise
def minimum(x, y):
"""elemwise minimum. See min for the minimum in one tensor"""
# see decorator for function body
......@@ -3227,31 +3225,31 @@ def divmod(x, y):
return floor_div(x, y), mod_check(x, y)
@_scal_elemwise_with_nfunc('add', 2, 1)
@_scal_elemwise
def add(a, *other_terms):
"""elementwise addition"""
# see decorator for function body
@_scal_elemwise_with_nfunc('subtract', 2, 1)
@_scal_elemwise
def sub(a, b):
"""elementwise subtraction"""
# see decorator for function body
@_scal_elemwise_with_nfunc('multiply', 2, 1)
@_scal_elemwise
def mul(a, *other_terms):
"""elementwise multiplication"""
# see decorator for function body
@_scal_elemwise_with_nfunc('true_divide', 2, 1)
@_scal_elemwise
def true_div(a, b):
"""elementwise [true] division (inverse of multiplication)"""
# see decorator for function body
@_scal_elemwise_with_nfunc('floor_divide', 2, 1)
@_scal_elemwise
def int_div(a, b):
"""elementwise [floor] division (inverse of multiplication)"""
# see decorator for function body
......@@ -3292,20 +3290,18 @@ def mod_check(x, y):
return mod(x, y)
@_scal_elemwise_with_nfunc('mod', 2, 1)
@_scal_elemwise
def mod(a, b):
"""elementwise modulo"""
# see decorator for function body
@_scal_elemwise_with_nfunc('power', 2, 1)
@_scal_elemwise
def pow(a, b):
"""elementwise power"""
# see decorator for function body
# The numpy.clip don't work correctly when the min is bigger then the max,
# So we do not use @scal_elemwise_with_nfunc('clip', 3, 1)
@_scal_elemwise
def clip(x, min, max):
"""
......
......@@ -7,6 +7,7 @@ import numpy
import theano
from theano import gof
from theano.compat import izip
from theano.compat import get_unbound_function
from six import iteritems
from six.moves import xrange
from theano.gof import Apply, Op, OpenMPOp
......@@ -502,12 +503,11 @@ class Elemwise(OpenMPOp):
self.ufunc = None
self.nfunc = None
if nfunc_spec is None:
nfunc_spec = getattr(scalar_op, 'nfunc_spec', None)
self.nfunc_spec = nfunc_spec
if nfunc_spec:
self.nfunc = getattr(numpy, nfunc_spec[0])
elif scalar_op.nin > 0:
self.ufunc = numpy.frompyfunc(scalar_op.impl, scalar_op.nin,
scalar_op.nout)
# precompute the hash of this node
self._rehash()
......@@ -527,7 +527,7 @@ class Elemwise(OpenMPOp):
self.nfunc = None
if getattr(self, 'nfunc_spec', None):
self.nfunc = getattr(numpy, self.nfunc_spec[0])
elif self.scalar_op.nin > 0:
elif self.scalar_op.nin > 0 and self.scalar_op.nin < 32:
self.ufunc = numpy.frompyfunc(self.scalar_op.impl,
self.scalar_op.nin,
self.scalar_op.nout)
......@@ -792,6 +792,28 @@ class Elemwise(OpenMPOp):
return ret
def make_thunk(self, node, storage_map, compute_map, no_recycling):
node_ = node
# Postpone the ufunc building to the last minutes
# NumPy ufunc support only up to 31 inputs.
# But our c code support more.
if (len(node.inputs) < 32 and
(self.nfunc is None or
self.scalar_op.nin != len(node.inputs)) and
self.ufunc is None):
ufunc = numpy.frompyfunc(self.scalar_op.impl,
len(node.inputs),
self.scalar_op.nout)
if self.scalar_op.nin > 0:
# We can reuse it for many nodes
self.ufunc = ufunc
else:
node.tag.ufunc = ufunc
return super(Elemwise, node_.op).make_thunk(node_, storage_map,
compute_map, no_recycling)
def perform(self, node, inputs, output_storage):
if len(node.inputs) >= 32:
# Some versions of NumPy will segfault, other will raise a
......@@ -859,9 +881,18 @@ class Elemwise(OpenMPOp):
else:
# the second calling form is used because in certain versions of
# numpy the first (faster) version leads to segfaults
ufunc = (self.ufunc or
numpy.frompyfunc(self.scalar_op.impl, len(inputs),
self.scalar_op.nout))
if self.ufunc:
ufunc = self.ufunc
else:
if not hasattr(node.tag, 'ufunc'):
# It happen that make_thunk isn't called, like in
# get_scalar_constant_value
node.tag.ufunc = numpy.frompyfunc(self.scalar_op.impl,
len(node.inputs),
self.scalar_op.nout)
ufunc = node.tag.ufunc
nout = ufunc.nout
variables = ufunc(*ufunc_args, **ufunc_kwargs)
......@@ -1234,6 +1265,9 @@ class Elemwise(OpenMPOp):
"""
return node.outputs[0].ndim == 0
theano.compile.debugmode.default_make_thunk.append(
get_unbound_function(Elemwise.make_thunk))
# def elemwise_to_scal(fgraph):
# TODO: why is this commented out? should it be removed?
# it has needed maintenance despite being commented
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论