提交 4bbac540 authored 作者: Olivier Breuleux's avatar Olivier Breuleux

merge

......@@ -417,7 +417,7 @@ def stack_search(start, expand, mode='bfs', build_inv = False):
raise ValueError('mode should be bfs or dfs', mode)
rval_set = set()
rval_list = list()
if mode is 'bfs': start_pop = start.popleft
if mode == 'bfs': start_pop = start.popleft
else: start_pop = start.pop
expand_inv = {}
while start:
......
差异被折叠。
......@@ -1336,35 +1336,39 @@ def _redefine_asRoutine(real_symbol_value):
return real_symbol_value
return decorator
def _scal_elemwise(symbol):
def _scal_elemwise_with_nfunc(nfunc, nin, nout):
"""Replace a symbol definition with an elementwise version of the corresponding scalar Op"""
symbolname = symbol.__name__
inplace = symbolname.endswith('_inplace')
if inplace:
msg = "inplace"
else:
msg = "no_inplace"
n="Elemwise{%s,%s}"%(symbolname,msg)
def construct(symbol):
symbolname = symbol.__name__
inplace = symbolname.endswith('_inplace')
if inplace:
msg = "inplace"
else:
msg = "no_inplace"
n="Elemwise{%s,%s}"%(symbolname,msg)
if inplace:
scalar_op = getattr(scal, symbolname[:-len('_inplace')])
inplace_scalar_op = scalar_op.__class__(scal.transfer_type(0))
rval = elemwise.Elemwise(inplace_scalar_op, {0: 0}, name=n)
else:
scalar_op = getattr(scal, symbolname)
rval = elemwise.Elemwise(scalar_op, name=n)
if inplace:
scalar_op = getattr(scal, symbolname[:-len('_inplace')])
inplace_scalar_op = scalar_op.__class__(scal.transfer_type(0))
rval = elemwise.Elemwise(inplace_scalar_op, {0: 0}, name=n, nfunc_spec=((nfunc, nin, nout) if nfunc else None))
else:
scalar_op = getattr(scal, symbolname)
rval = elemwise.Elemwise(scalar_op, name=n, nfunc_spec=((nfunc, nin, nout) if nfunc else None))
if getattr(symbol, '__doc__', False):
rval.__doc__ = symbol.__doc__ + '\n' + rval.__doc__
if getattr(symbol, '__doc__', False):
rval.__doc__ = symbol.__doc__ + '\n' + rval.__doc__
#for the meaning of this see the ./epydoc script
# it makes epydoc display rval as if it were a function, not an object
rval.__epydoc_asRoutine = symbol
rval.__module__ = 'tensor'
#for the meaning of this see the ./epydoc script
# it makes epydoc display rval as if it were a function, not an object
rval.__epydoc_asRoutine = symbol
rval.__module__ = 'tensor'
pprint.assign(rval, printing.FunctionPrinter(symbolname))
pprint.assign(rval, printing.FunctionPrinter(symbolname))
return rval
return rval
return construct
_scal_elemwise = _scal_elemwise_with_nfunc(None, None, None)
#########################
......@@ -1865,27 +1869,27 @@ def largest(*args):
# Comparison
##########################
@_scal_elemwise
@_scal_elemwise_with_nfunc('less', 2, 1)
def lt(a, b):
"""a < b"""
@_scal_elemwise
@_scal_elemwise_with_nfunc('greater', 2, 1)
def gt(a, b):
"""a > b"""
@_scal_elemwise
@_scal_elemwise_with_nfunc('less_equal', 2, 1)
def le(a, b):
"""a <= b"""
@_scal_elemwise
@_scal_elemwise_with_nfunc('greater_equal', 2, 1)
def ge(a, b):
"""a >= b"""
@_scal_elemwise
@_scal_elemwise_with_nfunc('equal', 2, 1)
def eq(a, b):
"""a == b"""
@_scal_elemwise
@_scal_elemwise_with_nfunc('not_equal', 2, 1)
def neq(a, b):
"""a != b"""
......@@ -1903,19 +1907,19 @@ def switch(cond, ift, iff):
# Bit-wise
##########################
@_scal_elemwise
@_scal_elemwise_with_nfunc('bitwise_and', 2, 1)
def and_(a,b):
"""bitwise a & b"""
@_scal_elemwise
@_scal_elemwise_with_nfunc('bitwise_or', 2, 1)
def or_(a,b):
"""bitwise a | b"""
@_scal_elemwise
@_scal_elemwise_with_nfunc('bitwise_xor', 2, 1)
def xor(a,b):
"""bitwise a ^ b"""
@_scal_elemwise
@_scal_elemwise_with_nfunc('invert', 1, 1)
def invert(a):
"""bitwise ~a"""
......@@ -1923,7 +1927,7 @@ def invert(a):
# Math
##########################
@_scal_elemwise
@_scal_elemwise_with_nfunc('abs', 1, 1)
def abs_(a):
"""|`a`|
......@@ -1934,43 +1938,43 @@ def abs_(a):
pprint.assign(abs_, printing.PatternPrinter(('|%(0)s|', -1000)))
@_scal_elemwise
@_scal_elemwise_with_nfunc('exp', 1, 1)
def exp(a):
"""e^`a`"""
@_scal_elemwise
@_scal_elemwise_with_nfunc('negative', 1, 1)
def neg(a):
"""-a"""
@_scal_elemwise
@_scal_elemwise # numpy.reciprocal does integer division on integer inputs (which is not very interesting)
def inv(a):
"""1.0/a"""
@_scal_elemwise
@_scal_elemwise_with_nfunc('log', 1, 1)
def log(a):
"""base e logarithm of a"""
@_scal_elemwise
@_scal_elemwise_with_nfunc('log2', 1, 1)
def log2(a):
"""base 2 logarithm of a"""
@_scal_elemwise
@_scal_elemwise_with_nfunc('log10', 1, 1)
def log10(a):
"""base 10 logarithm of a"""
@_scal_elemwise
@_scal_elemwise_with_nfunc('log1p', 1, 1)
def log1p(a):
"""log(1+a)"""
@_scal_elemwise
@_scal_elemwise_with_nfunc('sign', 1, 1)
def sgn(a):
"""sign of a"""
@_scal_elemwise
@_scal_elemwise_with_nfunc('ceil', 1, 1)
def ceil(a):
"""ceiling of a"""
@_scal_elemwise
@_scal_elemwise_with_nfunc('floor', 1, 1)
def floor(a):
"""floor of a"""
......@@ -1989,7 +1993,10 @@ def round(a, mode="half_away_from_zero"):
else:
raise Exception("round mode %s is not implemented."%mode)
@_scal_elemwise
# def __round_half_to_even(a, dest):
# dest[:] = numpy.around(a)
@_scal_elemwise_with_nfunc('around', 1, 0)
def round_half_to_even(a):
"""round_half_to_even(a)"""
......@@ -1997,35 +2004,35 @@ def round_half_to_even(a):
def round_half_away_from_zero(a):
"""round_half_away_from_zero(a)"""
@_scal_elemwise
@_scal_elemwise_with_nfunc('square', 1, 1)
def sqr(a):
"""square of a"""
@_scal_elemwise
@_scal_elemwise_with_nfunc('sqrt', 1, 1)
def sqrt(a):
"""square root of a"""
@_scal_elemwise
@_scal_elemwise_with_nfunc('cos', 1, 1)
def cos(a):
"""cosine of a"""
@_scal_elemwise
@_scal_elemwise_with_nfunc('sin', 1, 1)
def sin(a):
"""sine of a"""
@_scal_elemwise
@_scal_elemwise_with_nfunc('tan', 1, 1)
def tan(a):
"""tangent of a"""
@_scal_elemwise
@_scal_elemwise_with_nfunc('cosh', 1, 1)
def cosh(a):
"""hyperbolic cosine of a"""
@_scal_elemwise
@_scal_elemwise_with_nfunc('sinh', 1, 1)
def sinh(a):
"""hyperbolic sine of a"""
@_scal_elemwise
@_scal_elemwise_with_nfunc('tanh', 1, 1)
def tanh(a):
"""hyperbolic tangent of a"""
......@@ -2037,19 +2044,19 @@ def erf(a):
def erfc(a):
"""complementary error function"""
@_scal_elemwise
@_scal_elemwise_with_nfunc('real', 1, 0)
def real(z):
"""Return real component of complex-valued tensor `z`"""
@_scal_elemwise
@_scal_elemwise_with_nfunc('imag', 1, 0)
def imag(z):
"""Return imaginary component of complex-valued tensor `z`"""
@_scal_elemwise
@_scal_elemwise_with_nfunc('angle', 1, 0)
def angle(z):
"""Return polar-coordinate angle of complex-valued tensor `z`"""
@_scal_elemwise
@_scal_elemwise # numpy.complex cannot build tensors
def complex(real, imag):
"""Return complex-valued tensor with `real` and `imag` components"""
......@@ -2475,13 +2482,13 @@ setdefault = default # legacy
##########################
# Arithmetics
##########################
@_scal_elemwise
@_scal_elemwise_with_nfunc('maximum', 2, 1)
def maximum(x,y):
"""elemwise maximum. See max for the maximum in one tensor
"""
# see decorator for function body
@_scal_elemwise
@_scal_elemwise_with_nfunc('minimum', 2, 1)
def minimum(x,y):
"""elemwise minimum. See min for the minimum in one tensor
"""
......@@ -2495,47 +2502,47 @@ def div_proxy(x, y):
else:
return true_div(x, y)
@_scal_elemwise
@_scal_elemwise_with_nfunc('add', 2, 1)
def add(a, *other_terms):
"""elementwise addition"""
# see decorator for function body
@_scal_elemwise
@_scal_elemwise_with_nfunc('subtract', 2, 1)
def sub(a, b):
"""elementwise subtraction"""
# see decorator for function body
@_scal_elemwise
@_scal_elemwise_with_nfunc('multiply', 2, 1)
def mul(a, *other_terms):
"""elementwise multiplication"""
# see decorator for function body
@_scal_elemwise
@_scal_elemwise_with_nfunc('true_divide', 2, 1)
def true_div(a, b):
"""elementwise [true] division (inverse of multiplication)"""
# see decorator for function body
@_scal_elemwise
@_scal_elemwise_with_nfunc('floor_divide', 2, 1)
def floor_div(a, b):
"""elementwise [floor] division (inverse of multiplication)"""
# see decorator for function body
@_scal_elemwise
@_scal_elemwise_with_nfunc('floor_divide', 2, 1) # not a c/p error, floor_div and int_div are the same thing
def int_div(a, b):
"""elementwise integer-division"""
# see decorator for function body
@_scal_elemwise
@_scal_elemwise_with_nfunc('mod', 2, 1)
def mod(a, b):
"""elementwise modulo"""
# see decorator for function body
@_scal_elemwise
@_scal_elemwise_with_nfunc('power', 2, 1)
def pow(a, b):
"""elementwise power"""
# see decorator for function body
@_scal_elemwise
@_scal_elemwise_with_nfunc('clip', 3, 1)
def clip(x, min, max):
"""clip x to be between min and max"""
# see decorator for function body
......
......@@ -361,6 +361,21 @@ class DimShufflePrinter:
pprint.assign(lambda pstate, r: r.owner and isinstance(r.owner.op, DimShuffle), DimShufflePrinter())
def _make_nfunc(name, nin, nout):
f = getattr(numpy, name)
return f
# if name.endswith("*"):
# name = name[:-1]
# f = getattr(numpy, name)
# def fn(*args):
# args[-1][:] = f(*(args[:-1]))
# return fn
# else:
# f = getattr(numpy, name)
# return f
################
### Elemwise ###
################
......@@ -392,7 +407,7 @@ class Elemwise(Op):
Elemwise(log)(rand(3, 4, 5))
"""
def __init__(self, scalar_op, inplace_pattern = {}, name = None):
def __init__(self, scalar_op, inplace_pattern = {}, name = None, nfunc_spec = None):
"""
Usage: Elemwise(scalar_op, inplace_pattern = {})
......@@ -406,10 +421,14 @@ class Elemwise(Op):
self.scalar_op = scalar_op
self.inplace_pattern = inplace_pattern
self.destroy_map = dict((o, [i]) for o, i in inplace_pattern.items())
if scalar_op.nin > 0:
self.ufunc = None
self.nfunc = None
self.nfunc_spec = nfunc_spec
if nfunc_spec:
self.nfunc = _make_nfunc(*nfunc_spec)
elif scalar_op.nin > 0:
self.ufunc = numpy.frompyfunc(scalar_op.impl, scalar_op.nin, scalar_op.nout)
else:
self.ufunc = None
#precompute the hash of this node
self._rehash()
......@@ -417,16 +436,19 @@ class Elemwise(Op):
def __getstate__(self):
d = copy(self.__dict__)
d.pop('ufunc')
d.pop('nfunc')
d.pop('__epydoc_asRoutine', None)
d.pop('_hashval')
return d
def __setstate__(self, d):
self.__dict__.update(d)
if self.scalar_op.nin > 0:
self.ufunc = None
self.nfunc = None
if getattr(self, 'nfunc_spec', None):
self.nfunc = _make_nfunc(*self.nfunc_spec)
elif self.scalar_op.nin > 0:
self.ufunc = numpy.frompyfunc(self.scalar_op.impl, self.scalar_op.nin, self.scalar_op.nout)
else:
self.ufunc = None
self._rehash()
def make_node(self, *inputs):
......@@ -621,10 +643,16 @@ class Elemwise(Op):
else:
odat = numpy.ndarray(shape, dtype = output.type.dtype)
storage[0] = odat
# the second calling form is used because in certain versions of numpy
# the first (faster) version leads to segfaults
ufunc_args = inputs # + output_storage
ufunc = self.ufunc or numpy.frompyfunc(self.scalar_op.impl, len(inputs), self.scalar_op.nout)
if self.nfunc and len(inputs) == self.nfunc_spec[1]:
ufunc = self.nfunc
nout = 1
else:
# the second calling form is used because in certain versions of numpy
# the first (faster) version leads to segfaults
ufunc = self.ufunc or numpy.frompyfunc(self.scalar_op.impl, len(inputs), self.scalar_op.nout)
nout = ufunc.nout
try:
variables = ufunc(*ufunc_args)
......@@ -633,7 +661,7 @@ class Elemwise(Op):
'for params of shape', [arg.shape for arg in ufunc_args]
e.args = e.args + errormsg
raise
if ufunc.nout == 1: variables = [variables]
if nout == 1: variables = [variables]
for variable, storage in zip(variables, output_storage):
if hasattr(variable,'shape') and storage[0].shape != variable.shape:
storage[0].resize(variable.shape)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论