提交 4bbac540 authored 作者: Olivier Breuleux's avatar Olivier Breuleux

merge

...@@ -417,7 +417,7 @@ def stack_search(start, expand, mode='bfs', build_inv = False): ...@@ -417,7 +417,7 @@ def stack_search(start, expand, mode='bfs', build_inv = False):
raise ValueError('mode should be bfs or dfs', mode) raise ValueError('mode should be bfs or dfs', mode)
rval_set = set() rval_set = set()
rval_list = list() rval_list = list()
if mode is 'bfs': start_pop = start.popleft if mode == 'bfs': start_pop = start.popleft
else: start_pop = start.pop else: start_pop = start.pop
expand_inv = {} expand_inv = {}
while start: while start:
......
...@@ -106,7 +106,7 @@ def map( fn ...@@ -106,7 +106,7 @@ def map( fn
:param go_backwards: Boolean value that decides the direction of :param go_backwards: Boolean value that decides the direction of
iteration. True means that sequences are parsed iteration. True means that sequences are parsed
from the end towards the begining, while False from the end towards the beginning, while False
is the other way around. is the other way around.
:param mode: See ``scan``. :param mode: See ``scan``.
...@@ -301,7 +301,7 @@ def scan( fn ...@@ -301,7 +301,7 @@ def scan( fn
scan) scan)
The order of the sequences is the same as the one in the list The order of the sequences is the same as the one in the list
`sequences` given to scan. The order of the outputs is the sane `sequences` given to scan. The order of the outputs is the same
as the order of ``output_info``. For any sequence or output the as the order of ``output_info``. For any sequence or output the
order of the time slices is the same as the order of the time order of the time slices is the same as the order of the time
taps provided. For example if one writes the following : taps provided. For example if one writes the following :
...@@ -314,7 +314,7 @@ def scan( fn ...@@ -314,7 +314,7 @@ def scan( fn
, outputs_info = [ dict( Output1, taps = [-3,-5]) , outputs_info = [ dict( Output1, taps = [-3,-5])
, dict( Output2, taps = None) , dict( Output2, taps = None)
, Output3 ] , Output3 ]
, non_sequences = [ Argument1, Argument 2]) , non_sequences = [ Argument1, Argument2])
``fn`` should expect the following arguments in this given order: ``fn`` should expect the following arguments in this given order:
...@@ -341,7 +341,7 @@ def scan( fn ...@@ -341,7 +341,7 @@ def scan( fn
`fn` should return an update dictionary ( that tells how to `fn` should return an update dictionary ( that tells how to
update any shared variable after each iteration ste). The update any shared variable after each iteration ste). The
dictionary can optionally be given as a list of tuples. There is dictionary can optionally be given as a list of tuples. There is
no constraint on the order of these two list, ``fn`` can return no constraint on the order of these two lists, ``fn`` can return
either ``(outputs_list, update_dictionary)`` or ``(update_dictionary, either ``(outputs_list, update_dictionary)`` or ``(update_dictionary,
outputs_list)`` or just one of the two (in case the other is outputs_list)`` or just one of the two (in case the other is
empty). empty).
...@@ -369,7 +369,7 @@ def scan( fn ...@@ -369,7 +369,7 @@ def scan( fn
:param outputs_info: :param outputs_info:
``outputs_info`` is the list of Theano variables or dictionaries ``outputs_info`` is the list of Theano variables or dictionaries
describing the initial state of the outputs computed describing the initial state of the outputs computed
recurrently. When this initial states are given as dictionary recurrently. When this initial state is given as a dictionary,
optional information can be provided about the output corresponding optional information can be provided about the output corresponding
to these initial states. The dictionary should have the following to these initial states. The dictionary should have the following
keys: keys:
...@@ -388,11 +388,11 @@ def scan( fn ...@@ -388,11 +388,11 @@ def scan( fn
the initial state, which in this case should have the shape the initial state, which in this case should have the shape
(5,)+output.shape. If this variable containing the initial (5,)+output.shape. If this variable containing the initial
state is called ``init_y`` then ``init_y[0]`` *corresponds to* state is called ``init_y`` then ``init_y[0]`` *corresponds to*
``output[-5]``; ``init_y[1]`` *correponds to* ``output[-4]``; ``output[-5]``; ``init_y[1]`` *corresponds to* ``output[-4]``;
``init_y[2]`` corresponds to ``output[-3]``; ``init_y[3]`` ``init_y[2]`` corresponds to ``output[-3]``; ``init_y[3]``
coresponds to ``output[-2]``; ``init_y[4]`` corresponds to coresponds to ``output[-2]``; ``init_y[4]`` corresponds to
``output[-1]``. While this order might seem strange, it comes ``output[-1]``. While this order might seem strange, it comes
natural from splitting an array at a given point. Assume that naturally from splitting an array at a given point. Assume that
we have a array ``x``, and we choose ``k`` to be time step we have a array ``x``, and we choose ``k`` to be time step
``0``. Then our initial state would be ``x[:k]``, while the ``0``. Then our initial state would be ``x[:k]``, while the
output will be ``x[k:]``. Looking at this split, elements in output will be ``x[k:]``. Looking at this split, elements in
...@@ -401,17 +401,10 @@ def scan( fn ...@@ -401,17 +401,10 @@ def scan( fn
``fn``. They are provided as a list of *negative* integers, ``fn``. They are provided as a list of *negative* integers,
where a value ``k`` implies that at iteration step ``t`` scan will where a value ``k`` implies that at iteration step ``t`` scan will
pass to ``fn`` the slice ``t+k``. pass to ``fn`` the slice ``t+k``.
* ``inplace`` -- One of the Theano variables provided as * ``inplace`` -- DEPRECATED. Previously, one could specify with this
``sequences``. ``scan`` will try to compute this output *in option whether the output should overwrite some particular input,
place* of the provided input *iff* it respects the following but it is now inferred automatically. If you specify this option
constraints: it will be ignored.
* There is no other output that is denied to be computed in
place for whatever reason.
* ``fn`` is not using past taps of the input sequence that
will get overwritten by the output
* ``return_steps`` -- Integer representing the number of steps * ``return_steps`` -- Integer representing the number of steps
to return for the current steps. For example, if ``k`` is to return for the current steps. For example, if ``k`` is
provided, ``scan`` will return ``output[-k:]``. This is meant as a provided, ``scan`` will return ``output[-k:]``. This is meant as a
...@@ -422,7 +415,7 @@ def scan( fn ...@@ -422,7 +415,7 @@ def scan( fn
* ``store_steps`` -- Integer representing the number of * ``store_steps`` -- Integer representing the number of
intermediate steps ``scan`` should use for a given output. Use intermediate steps ``scan`` should use for a given output. Use
this key only if you really know what you are doing. In general this key only if you really know what you are doing. In general
it is recommended to let scan decide for you the ammount of memory it is recommended to let scan decide for you the amount of memory
it should use. it should use.
``scan`` will follow this logic if partial information is given: ``scan`` will follow this logic if partial information is given:
...@@ -437,12 +430,12 @@ def scan( fn ...@@ -437,12 +430,12 @@ def scan( fn
* If you wrap an output in a dictionary but you do not provide any * If you wrap an output in a dictionary but you do not provide any
initial state, it assumes that you are not using any form of initial state, it assumes that you are not using any form of
taps. taps.
* If you provide a ``None`` instead of a variable or a dictionary * If you provide ``None`` instead of a variable or a dictionary
``scan`` assumes that you will not use any taps for this output ``scan`` assumes that you will not use any taps for this output
(like for example in case of a map) (like for example in case of a map)
If ``outputs_info`` is an empty list or None, ``scan`` assumes If ``outputs_info`` is an empty list or None, ``scan`` assumes
that no tap is used for any of the otuputs. If information is that no tap is used for any of the outputs. If information is
provided just for a subset of the outputs an exception is provided just for a subset of the outputs an exception is
raised (because there is no convention on how scan should map raised (because there is no convention on how scan should map
the provided information to the outputs of ``fn``) the provided information to the outputs of ``fn``)
...@@ -450,8 +443,8 @@ def scan( fn ...@@ -450,8 +443,8 @@ def scan( fn
:param non_sequences: :param non_sequences:
``non_sequences`` is the list of arguments that are passed to ``non_sequences`` is the list of arguments that are passed to
``fn`` at each steps. One can opt to exclude shared variables ``fn`` at each step. It is not necessary to list shared variables
used in ``fn`` from this list. used in ``fn`` here, since they will be identified automatically.
:param n_steps: :param n_steps:
...@@ -469,9 +462,10 @@ def scan( fn ...@@ -469,9 +462,10 @@ def scan( fn
:param truncate_gradient: :param truncate_gradient:
``truncate_gradient`` is the number of steps to use in truncated ``truncate_gradient`` is the number of steps to use in truncated
BPTT. If you compute gradients through a scan op, they are BPTT (backpropagation through time). If you compute gradients
through a scan op, they are
computed using backpropagation through time. By providing a computed using backpropagation through time. By providing a
different value then -1, you choose to use truncated BPTT instead different value than -1, you choose to use truncated BPTT instead
of classical BPTT, where you go for only ``truncate_gradient`` of classical BPTT, where you go for only ``truncate_gradient``
number of steps back in time. number of steps back in time.
...@@ -512,33 +506,32 @@ def scan( fn ...@@ -512,33 +506,32 @@ def scan( fn
""" """
# General observation : this code is executed only once, at creation # General observation : this code is executed only once, at creation
# of the computational graph, so we don't yet need to be smart about # of the computational graph, so we don't yet need to be smart about
# anything ( to speed things up) # anything (to speed things up)
# check if inputs are just single variables instead of lists # check if inputs are just single variables instead of lists
if not (type(sequences) in (list, tuple)) and sequences != None: if sequences == None:
seqs = [sequences]
elif sequences == None:
seqs = [] seqs = []
elif not (type(sequences) in (list, tuple)):
seqs = [sequences]
else: else:
seqs = sequences seqs = sequences
if not (type(outputs_info) in (list,tuple)) and outputs_info != None: if outputs_info == None:
outs_info = [outputs_info]
elif outputs_info == None:
outs_info = [] outs_info = []
elif not (type(outputs_info) in (list,tuple)):
outs_info = [outputs_info]
else: else:
outs_info = outputs_info outs_info = outputs_info
if ( not (type(non_sequences) in (list,tuple)) if non_sequences == None:
and non_sequences != None):
non_seqs = [non_sequences]
elif non_sequences == None:
non_seqs = [] non_seqs = []
elif not (type(non_sequences) in (list,tuple)):
non_seqs = [non_sequences]
else: else:
non_seqs = non_sequences non_seqs = non_sequences
# If we provided a known number of steps ( before compilation) # If we provided a known number of steps (before compilation)
# and if that number is 1 or -1, then we can skip the Scan Op, # and if that number is 1 or -1, then we can skip the Scan Op,
# and just apply the inner function once # and just apply the inner function once
# To do that we check here to see the nature of n_steps # To do that we check here to see the nature of n_steps
...@@ -570,7 +563,7 @@ def scan( fn ...@@ -570,7 +563,7 @@ def scan( fn
sequences_taps = {} sequences_taps = {}
outputs_taps = {} outputs_taps = {}
# Assume that for any output we want to store everythin that it produces # Assume that for any output we want to store everything that it produces
store_steps = [] store_steps = []
return_steps = {} return_steps = {}
...@@ -591,8 +584,8 @@ def scan( fn ...@@ -591,8 +584,8 @@ def scan( fn
# See if the user actually provided the None value to taps, # See if the user actually provided the None value to taps,
# which would indicate that the sequence was provided but # which would indicate that the sequence was provided but
# not used by the internal function; Only if the user has # not used by the internal function; Only if the user has
# not provided anything add the defaul [0] # not provided anything add the default [0]
# Possible reason to provide a squence and not use it is # A possible reason to provide a sequence and not use it is
# if you want to compute the output # if you want to compute the output
# inplace of this input; it is a very unlikely behaviour but # inplace of this input; it is a very unlikely behaviour but
# we do want to cover it for completeness # we do want to cover it for completeness
...@@ -635,7 +628,7 @@ def scan( fn ...@@ -635,7 +628,7 @@ def scan( fn
raise ValueError('If you are using slices of an output you need to '\ raise ValueError('If you are using slices of an output you need to '\
'provide an initial state for it', outs_info[i]) 'provide an initial state for it', outs_info[i])
# if there is an intial state but no tap, we will add the default value # if there is an intial state but no tap, we will add the default value
# for taps, namely [-1] ( previous value); not that this will happen # for taps, namely [-1] ( previous value); note that this will happen
# even though you have provided for taps the value None, which is a bit # even though you have provided for taps the value None, which is a bit
# strange (why would one provide an initial state but tell scan not to # strange (why would one provide an initial state but tell scan not to
# use it ? ), just that in that case we will throw in a warning message # use it ? ), just that in that case we will throw in a warning message
...@@ -658,9 +651,14 @@ def scan( fn ...@@ -658,9 +651,14 @@ def scan( fn
if outs_info[i].get('taps', None): if outs_info[i].get('taps', None):
# Create a separate outputs_taps dictionary with all the outputs taps; This # Create a separate outputs_taps dictionary with all the outputs taps; This
# is how the Scan Op expects this information, separeted from the variables # is how the Scan Op expects this information, separated from the variables
outputs_taps[i] = outs_info[i]['taps'] outputs_taps[i] = outs_info[i]['taps']
if outs_info[i].get('inplace', None): if outs_info[i].get('inplace', None):
warning("DEPRECATED: you should not set the inplace parameter for an output in scan(...). "
"This can cause problems for the early stages of the optimizer "
"and there is a late optimization which automatically figures it out.")
# The same is true for the inplace info; it has to go into a separate # The same is true for the inplace info; it has to go into a separate
# dictionary based on index; Note that the input we're replacing should also # dictionary based on index; Note that the input we're replacing should also
# come as an index, therefore we have to look for it at this point # come as an index, therefore we have to look for it at this point
......
...@@ -1336,35 +1336,39 @@ def _redefine_asRoutine(real_symbol_value): ...@@ -1336,35 +1336,39 @@ def _redefine_asRoutine(real_symbol_value):
return real_symbol_value return real_symbol_value
return decorator return decorator
def _scal_elemwise(symbol): def _scal_elemwise_with_nfunc(nfunc, nin, nout):
"""Replace a symbol definition with an elementwise version of the corresponding scalar Op""" """Replace a symbol definition with an elementwise version of the corresponding scalar Op"""
symbolname = symbol.__name__ def construct(symbol):
inplace = symbolname.endswith('_inplace') symbolname = symbol.__name__
if inplace: inplace = symbolname.endswith('_inplace')
msg = "inplace" if inplace:
else: msg = "inplace"
msg = "no_inplace" else:
n="Elemwise{%s,%s}"%(symbolname,msg) msg = "no_inplace"
n="Elemwise{%s,%s}"%(symbolname,msg)
if inplace: if inplace:
scalar_op = getattr(scal, symbolname[:-len('_inplace')]) scalar_op = getattr(scal, symbolname[:-len('_inplace')])
inplace_scalar_op = scalar_op.__class__(scal.transfer_type(0)) inplace_scalar_op = scalar_op.__class__(scal.transfer_type(0))
rval = elemwise.Elemwise(inplace_scalar_op, {0: 0}, name=n) rval = elemwise.Elemwise(inplace_scalar_op, {0: 0}, name=n, nfunc_spec=((nfunc, nin, nout) if nfunc else None))
else: else:
scalar_op = getattr(scal, symbolname) scalar_op = getattr(scal, symbolname)
rval = elemwise.Elemwise(scalar_op, name=n) rval = elemwise.Elemwise(scalar_op, name=n, nfunc_spec=((nfunc, nin, nout) if nfunc else None))
if getattr(symbol, '__doc__', False): if getattr(symbol, '__doc__', False):
rval.__doc__ = symbol.__doc__ + '\n' + rval.__doc__ rval.__doc__ = symbol.__doc__ + '\n' + rval.__doc__
#for the meaning of this see the ./epydoc script #for the meaning of this see the ./epydoc script
# it makes epydoc display rval as if it were a function, not an object # it makes epydoc display rval as if it were a function, not an object
rval.__epydoc_asRoutine = symbol rval.__epydoc_asRoutine = symbol
rval.__module__ = 'tensor' rval.__module__ = 'tensor'
pprint.assign(rval, printing.FunctionPrinter(symbolname)) pprint.assign(rval, printing.FunctionPrinter(symbolname))
return rval return rval
return construct
_scal_elemwise = _scal_elemwise_with_nfunc(None, None, None)
######################### #########################
...@@ -1865,27 +1869,27 @@ def largest(*args): ...@@ -1865,27 +1869,27 @@ def largest(*args):
# Comparison # Comparison
########################## ##########################
@_scal_elemwise @_scal_elemwise_with_nfunc('less', 2, 1)
def lt(a, b): def lt(a, b):
"""a < b""" """a < b"""
@_scal_elemwise @_scal_elemwise_with_nfunc('greater', 2, 1)
def gt(a, b): def gt(a, b):
"""a > b""" """a > b"""
@_scal_elemwise @_scal_elemwise_with_nfunc('less_equal', 2, 1)
def le(a, b): def le(a, b):
"""a <= b""" """a <= b"""
@_scal_elemwise @_scal_elemwise_with_nfunc('greater_equal', 2, 1)
def ge(a, b): def ge(a, b):
"""a >= b""" """a >= b"""
@_scal_elemwise @_scal_elemwise_with_nfunc('equal', 2, 1)
def eq(a, b): def eq(a, b):
"""a == b""" """a == b"""
@_scal_elemwise @_scal_elemwise_with_nfunc('not_equal', 2, 1)
def neq(a, b): def neq(a, b):
"""a != b""" """a != b"""
...@@ -1903,19 +1907,19 @@ def switch(cond, ift, iff): ...@@ -1903,19 +1907,19 @@ def switch(cond, ift, iff):
# Bit-wise # Bit-wise
########################## ##########################
@_scal_elemwise @_scal_elemwise_with_nfunc('bitwise_and', 2, 1)
def and_(a,b): def and_(a,b):
"""bitwise a & b""" """bitwise a & b"""
@_scal_elemwise @_scal_elemwise_with_nfunc('bitwise_or', 2, 1)
def or_(a,b): def or_(a,b):
"""bitwise a | b""" """bitwise a | b"""
@_scal_elemwise @_scal_elemwise_with_nfunc('bitwise_xor', 2, 1)
def xor(a,b): def xor(a,b):
"""bitwise a ^ b""" """bitwise a ^ b"""
@_scal_elemwise @_scal_elemwise_with_nfunc('invert', 1, 1)
def invert(a): def invert(a):
"""bitwise ~a""" """bitwise ~a"""
...@@ -1923,7 +1927,7 @@ def invert(a): ...@@ -1923,7 +1927,7 @@ def invert(a):
# Math # Math
########################## ##########################
@_scal_elemwise @_scal_elemwise_with_nfunc('abs', 1, 1)
def abs_(a): def abs_(a):
"""|`a`| """|`a`|
...@@ -1934,43 +1938,43 @@ def abs_(a): ...@@ -1934,43 +1938,43 @@ def abs_(a):
pprint.assign(abs_, printing.PatternPrinter(('|%(0)s|', -1000))) pprint.assign(abs_, printing.PatternPrinter(('|%(0)s|', -1000)))
@_scal_elemwise @_scal_elemwise_with_nfunc('exp', 1, 1)
def exp(a): def exp(a):
"""e^`a`""" """e^`a`"""
@_scal_elemwise @_scal_elemwise_with_nfunc('negative', 1, 1)
def neg(a): def neg(a):
"""-a""" """-a"""
@_scal_elemwise @_scal_elemwise # numpy.reciprocal does integer division on integer inputs (which is not very interesting)
def inv(a): def inv(a):
"""1.0/a""" """1.0/a"""
@_scal_elemwise @_scal_elemwise_with_nfunc('log', 1, 1)
def log(a): def log(a):
"""base e logarithm of a""" """base e logarithm of a"""
@_scal_elemwise @_scal_elemwise_with_nfunc('log2', 1, 1)
def log2(a): def log2(a):
"""base 2 logarithm of a""" """base 2 logarithm of a"""
@_scal_elemwise @_scal_elemwise_with_nfunc('log10', 1, 1)
def log10(a): def log10(a):
"""base 10 logarithm of a""" """base 10 logarithm of a"""
@_scal_elemwise @_scal_elemwise_with_nfunc('log1p', 1, 1)
def log1p(a): def log1p(a):
"""log(1+a)""" """log(1+a)"""
@_scal_elemwise @_scal_elemwise_with_nfunc('sign', 1, 1)
def sgn(a): def sgn(a):
"""sign of a""" """sign of a"""
@_scal_elemwise @_scal_elemwise_with_nfunc('ceil', 1, 1)
def ceil(a): def ceil(a):
"""ceiling of a""" """ceiling of a"""
@_scal_elemwise @_scal_elemwise_with_nfunc('floor', 1, 1)
def floor(a): def floor(a):
"""floor of a""" """floor of a"""
...@@ -1989,7 +1993,10 @@ def round(a, mode="half_away_from_zero"): ...@@ -1989,7 +1993,10 @@ def round(a, mode="half_away_from_zero"):
else: else:
raise Exception("round mode %s is not implemented."%mode) raise Exception("round mode %s is not implemented."%mode)
@_scal_elemwise # def __round_half_to_even(a, dest):
# dest[:] = numpy.around(a)
@_scal_elemwise_with_nfunc('around', 1, 0)
def round_half_to_even(a): def round_half_to_even(a):
"""round_half_to_even(a)""" """round_half_to_even(a)"""
...@@ -1997,35 +2004,35 @@ def round_half_to_even(a): ...@@ -1997,35 +2004,35 @@ def round_half_to_even(a):
def round_half_away_from_zero(a): def round_half_away_from_zero(a):
"""round_half_away_from_zero(a)""" """round_half_away_from_zero(a)"""
@_scal_elemwise @_scal_elemwise_with_nfunc('square', 1, 1)
def sqr(a): def sqr(a):
"""square of a""" """square of a"""
@_scal_elemwise @_scal_elemwise_with_nfunc('sqrt', 1, 1)
def sqrt(a): def sqrt(a):
"""square root of a""" """square root of a"""
@_scal_elemwise @_scal_elemwise_with_nfunc('cos', 1, 1)
def cos(a): def cos(a):
"""cosine of a""" """cosine of a"""
@_scal_elemwise @_scal_elemwise_with_nfunc('sin', 1, 1)
def sin(a): def sin(a):
"""sine of a""" """sine of a"""
@_scal_elemwise @_scal_elemwise_with_nfunc('tan', 1, 1)
def tan(a): def tan(a):
"""tangent of a""" """tangent of a"""
@_scal_elemwise @_scal_elemwise_with_nfunc('cosh', 1, 1)
def cosh(a): def cosh(a):
"""hyperbolic cosine of a""" """hyperbolic cosine of a"""
@_scal_elemwise @_scal_elemwise_with_nfunc('sinh', 1, 1)
def sinh(a): def sinh(a):
"""hyperbolic sine of a""" """hyperbolic sine of a"""
@_scal_elemwise @_scal_elemwise_with_nfunc('tanh', 1, 1)
def tanh(a): def tanh(a):
"""hyperbolic tangent of a""" """hyperbolic tangent of a"""
...@@ -2037,19 +2044,19 @@ def erf(a): ...@@ -2037,19 +2044,19 @@ def erf(a):
def erfc(a): def erfc(a):
"""complementary error function""" """complementary error function"""
@_scal_elemwise @_scal_elemwise_with_nfunc('real', 1, 0)
def real(z): def real(z):
"""Return real component of complex-valued tensor `z`""" """Return real component of complex-valued tensor `z`"""
@_scal_elemwise @_scal_elemwise_with_nfunc('imag', 1, 0)
def imag(z): def imag(z):
"""Return imaginary component of complex-valued tensor `z`""" """Return imaginary component of complex-valued tensor `z`"""
@_scal_elemwise @_scal_elemwise_with_nfunc('angle', 1, 0)
def angle(z): def angle(z):
"""Return polar-coordinate angle of complex-valued tensor `z`""" """Return polar-coordinate angle of complex-valued tensor `z`"""
@_scal_elemwise @_scal_elemwise # numpy.complex cannot build tensors
def complex(real, imag): def complex(real, imag):
"""Return complex-valued tensor with `real` and `imag` components""" """Return complex-valued tensor with `real` and `imag` components"""
...@@ -2475,13 +2482,13 @@ setdefault = default # legacy ...@@ -2475,13 +2482,13 @@ setdefault = default # legacy
########################## ##########################
# Arithmetics # Arithmetics
########################## ##########################
@_scal_elemwise @_scal_elemwise_with_nfunc('maximum', 2, 1)
def maximum(x,y): def maximum(x,y):
"""elemwise maximum. See max for the maximum in one tensor """elemwise maximum. See max for the maximum in one tensor
""" """
# see decorator for function body # see decorator for function body
@_scal_elemwise @_scal_elemwise_with_nfunc('minimum', 2, 1)
def minimum(x,y): def minimum(x,y):
"""elemwise minimum. See min for the minimum in one tensor """elemwise minimum. See min for the minimum in one tensor
""" """
...@@ -2495,47 +2502,47 @@ def div_proxy(x, y): ...@@ -2495,47 +2502,47 @@ def div_proxy(x, y):
else: else:
return true_div(x, y) return true_div(x, y)
@_scal_elemwise @_scal_elemwise_with_nfunc('add', 2, 1)
def add(a, *other_terms): def add(a, *other_terms):
"""elementwise addition""" """elementwise addition"""
# see decorator for function body # see decorator for function body
@_scal_elemwise @_scal_elemwise_with_nfunc('subtract', 2, 1)
def sub(a, b): def sub(a, b):
"""elementwise subtraction""" """elementwise subtraction"""
# see decorator for function body # see decorator for function body
@_scal_elemwise @_scal_elemwise_with_nfunc('multiply', 2, 1)
def mul(a, *other_terms): def mul(a, *other_terms):
"""elementwise multiplication""" """elementwise multiplication"""
# see decorator for function body # see decorator for function body
@_scal_elemwise @_scal_elemwise_with_nfunc('true_divide', 2, 1)
def true_div(a, b): def true_div(a, b):
"""elementwise [true] division (inverse of multiplication)""" """elementwise [true] division (inverse of multiplication)"""
# see decorator for function body # see decorator for function body
@_scal_elemwise @_scal_elemwise_with_nfunc('floor_divide', 2, 1)
def floor_div(a, b): def floor_div(a, b):
"""elementwise [floor] division (inverse of multiplication)""" """elementwise [floor] division (inverse of multiplication)"""
# see decorator for function body # see decorator for function body
@_scal_elemwise @_scal_elemwise_with_nfunc('floor_divide', 2, 1) # not a c/p error, floor_div and int_div are the same thing
def int_div(a, b): def int_div(a, b):
"""elementwise integer-division""" """elementwise integer-division"""
# see decorator for function body # see decorator for function body
@_scal_elemwise @_scal_elemwise_with_nfunc('mod', 2, 1)
def mod(a, b): def mod(a, b):
"""elementwise modulo""" """elementwise modulo"""
# see decorator for function body # see decorator for function body
@_scal_elemwise @_scal_elemwise_with_nfunc('power', 2, 1)
def pow(a, b): def pow(a, b):
"""elementwise power""" """elementwise power"""
# see decorator for function body # see decorator for function body
@_scal_elemwise @_scal_elemwise_with_nfunc('clip', 3, 1)
def clip(x, min, max): def clip(x, min, max):
"""clip x to be between min and max""" """clip x to be between min and max"""
# see decorator for function body # see decorator for function body
......
...@@ -361,6 +361,21 @@ class DimShufflePrinter: ...@@ -361,6 +361,21 @@ class DimShufflePrinter:
pprint.assign(lambda pstate, r: r.owner and isinstance(r.owner.op, DimShuffle), DimShufflePrinter()) pprint.assign(lambda pstate, r: r.owner and isinstance(r.owner.op, DimShuffle), DimShufflePrinter())
def _make_nfunc(name, nin, nout):
f = getattr(numpy, name)
return f
# if name.endswith("*"):
# name = name[:-1]
# f = getattr(numpy, name)
# def fn(*args):
# args[-1][:] = f(*(args[:-1]))
# return fn
# else:
# f = getattr(numpy, name)
# return f
################ ################
### Elemwise ### ### Elemwise ###
################ ################
...@@ -392,7 +407,7 @@ class Elemwise(Op): ...@@ -392,7 +407,7 @@ class Elemwise(Op):
Elemwise(log)(rand(3, 4, 5)) Elemwise(log)(rand(3, 4, 5))
""" """
def __init__(self, scalar_op, inplace_pattern = {}, name = None): def __init__(self, scalar_op, inplace_pattern = {}, name = None, nfunc_spec = None):
""" """
Usage: Elemwise(scalar_op, inplace_pattern = {}) Usage: Elemwise(scalar_op, inplace_pattern = {})
...@@ -406,10 +421,14 @@ class Elemwise(Op): ...@@ -406,10 +421,14 @@ class Elemwise(Op):
self.scalar_op = scalar_op self.scalar_op = scalar_op
self.inplace_pattern = inplace_pattern self.inplace_pattern = inplace_pattern
self.destroy_map = dict((o, [i]) for o, i in inplace_pattern.items()) self.destroy_map = dict((o, [i]) for o, i in inplace_pattern.items())
if scalar_op.nin > 0:
self.ufunc = None
self.nfunc = None
self.nfunc_spec = nfunc_spec
if nfunc_spec:
self.nfunc = _make_nfunc(*nfunc_spec)
elif scalar_op.nin > 0:
self.ufunc = numpy.frompyfunc(scalar_op.impl, scalar_op.nin, scalar_op.nout) self.ufunc = numpy.frompyfunc(scalar_op.impl, scalar_op.nin, scalar_op.nout)
else:
self.ufunc = None
#precompute the hash of this node #precompute the hash of this node
self._rehash() self._rehash()
...@@ -417,16 +436,19 @@ class Elemwise(Op): ...@@ -417,16 +436,19 @@ class Elemwise(Op):
def __getstate__(self): def __getstate__(self):
d = copy(self.__dict__) d = copy(self.__dict__)
d.pop('ufunc') d.pop('ufunc')
d.pop('nfunc')
d.pop('__epydoc_asRoutine', None) d.pop('__epydoc_asRoutine', None)
d.pop('_hashval') d.pop('_hashval')
return d return d
def __setstate__(self, d): def __setstate__(self, d):
self.__dict__.update(d) self.__dict__.update(d)
if self.scalar_op.nin > 0: self.ufunc = None
self.nfunc = None
if getattr(self, 'nfunc_spec', None):
self.nfunc = _make_nfunc(*self.nfunc_spec)
elif self.scalar_op.nin > 0:
self.ufunc = numpy.frompyfunc(self.scalar_op.impl, self.scalar_op.nin, self.scalar_op.nout) self.ufunc = numpy.frompyfunc(self.scalar_op.impl, self.scalar_op.nin, self.scalar_op.nout)
else:
self.ufunc = None
self._rehash() self._rehash()
def make_node(self, *inputs): def make_node(self, *inputs):
...@@ -621,10 +643,16 @@ class Elemwise(Op): ...@@ -621,10 +643,16 @@ class Elemwise(Op):
else: else:
odat = numpy.ndarray(shape, dtype = output.type.dtype) odat = numpy.ndarray(shape, dtype = output.type.dtype)
storage[0] = odat storage[0] = odat
# the second calling form is used because in certain versions of numpy
# the first (faster) version leads to segfaults
ufunc_args = inputs # + output_storage ufunc_args = inputs # + output_storage
ufunc = self.ufunc or numpy.frompyfunc(self.scalar_op.impl, len(inputs), self.scalar_op.nout) if self.nfunc and len(inputs) == self.nfunc_spec[1]:
ufunc = self.nfunc
nout = 1
else:
# the second calling form is used because in certain versions of numpy
# the first (faster) version leads to segfaults
ufunc = self.ufunc or numpy.frompyfunc(self.scalar_op.impl, len(inputs), self.scalar_op.nout)
nout = ufunc.nout
try: try:
variables = ufunc(*ufunc_args) variables = ufunc(*ufunc_args)
...@@ -633,7 +661,7 @@ class Elemwise(Op): ...@@ -633,7 +661,7 @@ class Elemwise(Op):
'for params of shape', [arg.shape for arg in ufunc_args] 'for params of shape', [arg.shape for arg in ufunc_args]
e.args = e.args + errormsg e.args = e.args + errormsg
raise raise
if ufunc.nout == 1: variables = [variables] if nout == 1: variables = [variables]
for variable, storage in zip(variables, output_storage): for variable, storage in zip(variables, output_storage):
if hasattr(variable,'shape') and storage[0].shape != variable.shape: if hasattr(variable,'shape') and storage[0].shape != variable.shape:
storage[0].resize(variable.shape) storage[0].resize(variable.shape)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论