提交 e8ecd0fc authored 作者: abergeron's avatar abergeron

Merge pull request #3296 from harlouci/numpydoc_typedList_scalar

Numpydoc typed list scalar
""" """
@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@ .. warning::
WARNING
@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@
This directory is for the internal of Theano. This directory is for the internal of Theano.
...@@ -42,12 +40,18 @@ builtin_float = float ...@@ -42,12 +40,18 @@ builtin_float = float
class ComplexError(Exception): class ComplexError(Exception):
"""Raised if complex numbers are used in an unsupported operation.""" """
Raised if complex numbers are used in an unsupported operation.
"""
pass pass
class IntegerDivisionError(Exception): class IntegerDivisionError(Exception):
"""Raised if someone tries to divide integers with '/' instead of '//'.""" """
Raised if someone tries to divide integers with '/' instead of '//'.
"""
pass pass
...@@ -87,6 +91,7 @@ def get_scalar_type(dtype): ...@@ -87,6 +91,7 @@ def get_scalar_type(dtype):
Return a Scalar(dtype) object. Return a Scalar(dtype) object.
This caches objects to save allocation and run time. This caches objects to save allocation and run time.
""" """
if dtype not in get_scalar_type.cache: if dtype not in get_scalar_type.cache:
get_scalar_type.cache[dtype] = Scalar(dtype=dtype) get_scalar_type.cache[dtype] = Scalar(dtype=dtype)
...@@ -147,13 +152,16 @@ def constant(x): ...@@ -147,13 +152,16 @@ def constant(x):
class Scalar(Type): class Scalar(Type):
""" """
Internal class, should not be used by clients Internal class, should not be used by clients.
Primarily used by tensor.elemwise and tensor.reduce
Analogous to TensorType, but for zero-dimensional objects Primarily used by tensor.elemwise and tensor.reduce.
Maps directly to C primitives Analogous to TensorType, but for zero-dimensional objects.
Maps directly to C primitives.
TODO: refactor to be named ScalarType for consistency with TensorType.
TODO: refactor to be named ScalarType for consistency with TensorType
""" """
ndim = 0 ndim = 0
def __init__(self, dtype): def __init__(self, dtype):
...@@ -533,7 +541,7 @@ class _scalar_py_operators: ...@@ -533,7 +541,7 @@ class _scalar_py_operators:
ndim = 0 ndim = 0
dtype = property(lambda self: self.type.dtype) dtype = property(lambda self: self.type.dtype)
""" The dtype of this scalar. """ """The dtype of this scalar."""
# UNARY # UNARY
def __abs__(self): def __abs__(self):
...@@ -683,6 +691,7 @@ class upgrade_to_float(object): ...@@ -683,6 +691,7 @@ class upgrade_to_float(object):
def __new__(self, *types): def __new__(self, *types):
""" """
Upgrade any int types to float32 or float64 to avoid losing precision. Upgrade any int types to float32 or float64 to avoid losing precision.
""" """
conv = {int8: float32, conv = {int8: float32,
int16: float32, int16: float32,
...@@ -763,7 +772,8 @@ def float_out(*types): ...@@ -763,7 +772,8 @@ def float_out(*types):
def upgrade_to_float_no_complex(*types): def upgrade_to_float_no_complex(*types):
""" """
don't accept complex, otherwise call upgrade_to_float(). Don't accept complex, otherwise call upgrade_to_float().
""" """
for type in types: for type in types:
if type in complex_types: if type in complex_types:
...@@ -793,12 +803,13 @@ def float_out_nocomplex(*types): ...@@ -793,12 +803,13 @@ def float_out_nocomplex(*types):
class unary_out_lookup(gof.utils.object2): class unary_out_lookup(gof.utils.object2):
""" """
get a output_types_preference object by passing a dictionary: Get a output_types_preference object by passing a dictionary:
unary_out_lookup({int8:int32, float32:complex128}) unary_out_lookup({int8:int32, float32:complex128})
The result is an op that maps in8 to int32 and float32 to The result is an op that maps in8 to int32 and float32 to
complex128 and other input types lead to a TypeError. complex128 and other input types lead to a TypeError.
""" """
def __init__(self, type_table): def __init__(self, type_table):
self.tbl = type_table self.tbl = type_table
...@@ -917,9 +928,9 @@ class ScalarOp(Op): ...@@ -917,9 +928,9 @@ class ScalarOp(Op):
return (4,) return (4,)
def c_code_contiguous(self, node, name, inp, out, sub): def c_code_contiguous(self, node, name, inp, out, sub):
"""This function is called by Elemwise when all inputs and """
outputs are c_contiguous. This allows to use the SIMD version This function is called by Elemwise when all inputs and outputs are
of this op. c_contiguous. This allows to use the SIMD version of this op.
The inputs are the same as c_code except that: The inputs are the same as c_code except that:
...@@ -1002,6 +1013,7 @@ class LogicalComparison(BinaryScalarOp): ...@@ -1002,6 +1013,7 @@ class LogicalComparison(BinaryScalarOp):
class FixedLogicalComparison(UnaryScalarOp): class FixedLogicalComparison(UnaryScalarOp):
""" """
Comparison to a fixed value. Comparison to a fixed value.
""" """
def output_types(self, *input_dtypes): def output_types(self, *input_dtypes):
return [int8] return [int8]
...@@ -1531,17 +1543,29 @@ def int_or_true_div(x_discrete, y_discrete): ...@@ -1531,17 +1543,29 @@ def int_or_true_div(x_discrete, y_discrete):
""" """
Return 'int' or 'true' depending on the type of division used for x / y. Return 'int' or 'true' depending on the type of division used for x / y.
:param x_discrete: True if `x` is discrete ([unsigned] integer). Parameters
----------
x_discrete : bool
True if `x` is discrete ([unsigned] integer).
y_discrete : bool
True if `y` is discrete ([unsigned] integer).
Returns
-------
str
'int' if `x / y` should be an integer division, or `true` if it
should be a true division.
Raises
------
IntegerDivisionError
If both `x_discrete` and `y_discrete` are True and `config.int_division`
is set to 'raise'.
Notes
-----
This function is used by both scalar/basic.py and tensor/basic.py.
:param y_discrete: True if `x` is discrete ([unsigned] integer).
:returns: 'int' if `x / y` should be an integer division, or `true` if it
should be a true division.
Raises an IntegerDivisionError if both `x_discrete` and `y_discrete` are
True and `config.int_division` is set to 'raise'.
This function is used by both scalar/basic.py and tensor.basic/py.
""" """
if (x_discrete and y_discrete): if (x_discrete and y_discrete):
if config.int_division == 'raise': if config.int_division == 'raise':
...@@ -1568,7 +1592,10 @@ def int_or_true_div(x_discrete, y_discrete): ...@@ -1568,7 +1592,10 @@ def int_or_true_div(x_discrete, y_discrete):
def div_proxy(x, y): def div_proxy(x, y):
"""Proxy for either true_div or int_div, depending on types of x, y.""" """
Proxy for either true_div or int_div, depending on types of x, y.
"""
f = eval('%s_div' % int_or_true_div(as_scalar(x).type in discrete_types, f = eval('%s_div' % int_or_true_div(as_scalar(x).type in discrete_types,
as_scalar(y).type in discrete_types)) as_scalar(y).type in discrete_types))
return f(x, y) return f(x, y)
...@@ -1735,8 +1762,9 @@ class Mod(BinaryScalarOp): ...@@ -1735,8 +1762,9 @@ class Mod(BinaryScalarOp):
def c_code(self, node, name, inputs, outputs, sub): def c_code(self, node, name, inputs, outputs, sub):
""" """
We want the result to have the same sign as python, not the other We want the result to have the same sign as Python, not the other
implementation of mod. implementation of mod.
""" """
(x, y) = inputs (x, y) = inputs
(z,) = outputs (z,) = outputs
...@@ -2027,7 +2055,10 @@ _cast_mapping = { ...@@ -2027,7 +2055,10 @@ _cast_mapping = {
def cast(x, dtype): def cast(x, dtype):
"""Symbolically cast `x` to a Scalar of given `dtype`.""" """
Symbolically cast `x` to a Scalar of given `dtype`.
"""
if dtype == 'floatX': if dtype == 'floatX':
dtype = config.floatX dtype = config.floatX
...@@ -2176,10 +2207,11 @@ trunc = Trunc(same_out_nocomplex, name='trunc') ...@@ -2176,10 +2207,11 @@ trunc = Trunc(same_out_nocomplex, name='trunc')
class RoundHalfToEven(UnaryScalarOp): class RoundHalfToEven(UnaryScalarOp):
""" """
This function implement the same rounding than numpy: Round half to even This function implement the same rounding than numpy: Round half to even.
c/c++ round fct IS DIFFERENT! c/c++ round fct IS DIFFERENT!
See http://en.wikipedia.org/wiki/Rounding for more detail See http://en.wikipedia.org/wiki/Rounding for more details.
""" """
def impl(self, x): def impl(self, x):
return numpy.round(x) return numpy.round(x)
...@@ -2273,9 +2305,10 @@ def round_half_away_from_zero_vec(a): ...@@ -2273,9 +2305,10 @@ def round_half_away_from_zero_vec(a):
class RoundHalfAwayFromZero(UnaryScalarOp): class RoundHalfAwayFromZero(UnaryScalarOp):
""" """
Implement the same rounding algo as c round() fct. Implement the same rounding algo as c round() fct.
numpy.round fct IS DIFFERENT! numpy.round fct IS DIFFERENT!
See http://en.wikipedia.org/wiki/Rounding for more details.
See http://en.wikipedia.org/wiki/Rounding for more detail
""" """
def impl(self, x): def impl(self, x):
return round_half_away_from_zero_vec(x) return round_half_away_from_zero_vec(x)
...@@ -2332,7 +2365,10 @@ pprint.assign(mod, printing.OperatorPrinter('%', -1, 'left')) ...@@ -2332,7 +2365,10 @@ pprint.assign(mod, printing.OperatorPrinter('%', -1, 'left'))
class Inv(UnaryScalarOp): class Inv(UnaryScalarOp):
""" multiplicative inverse. Also called reciprocal""" """
Multiplicative inverse. Also called reciprocal.
"""
def impl(self, x): def impl(self, x):
return numpy.float32(1.0) / x return numpy.float32(1.0) / x
...@@ -2359,7 +2395,10 @@ inv = Inv(upgrade_to_float, name='inv') ...@@ -2359,7 +2395,10 @@ inv = Inv(upgrade_to_float, name='inv')
class Log(UnaryScalarOp): class Log(UnaryScalarOp):
""" log base e """ """
log base e.
"""
amd_float32 = "amd_vrsa_logf" amd_float32 = "amd_vrsa_logf"
amd_float64 = "amd_vrda_log" amd_float64 = "amd_vrda_log"
...@@ -2397,7 +2436,10 @@ log = Log(upgrade_to_float, name='log') ...@@ -2397,7 +2436,10 @@ log = Log(upgrade_to_float, name='log')
class Log2(UnaryScalarOp): class Log2(UnaryScalarOp):
""" log base 2 """ """
log base 2.
"""
amd_float32 = "amd_vrsa_log2f" amd_float32 = "amd_vrsa_log2f"
amd_float64 = "amd_vrda_log2" amd_float64 = "amd_vrda_log2"
...@@ -2432,7 +2474,10 @@ log2 = Log2(upgrade_to_float, name='log2') ...@@ -2432,7 +2474,10 @@ log2 = Log2(upgrade_to_float, name='log2')
class Log10(UnaryScalarOp): class Log10(UnaryScalarOp):
""" log base 10 """ """
log base 10.
"""
amd_float32 = "amd_vrsa_log10f" amd_float32 = "amd_vrsa_log10f"
amd_float64 = "amd_vrda_log10" amd_float64 = "amd_vrda_log10"
...@@ -2467,7 +2512,10 @@ log10 = Log10(upgrade_to_float, name='log10') ...@@ -2467,7 +2512,10 @@ log10 = Log10(upgrade_to_float, name='log10')
class Log1p(UnaryScalarOp): class Log1p(UnaryScalarOp):
""" log(1+x) """ """
log(1+x).
"""
def impl(self, x): def impl(self, x):
# If x is an int8 or uint8, numpy.log1p will compute the result in # If x is an int8 or uint8, numpy.log1p will compute the result in
# half-precision (float16), where we want float32. # half-precision (float16), where we want float32.
...@@ -2951,7 +2999,8 @@ arctan2 = ArcTan2(upgrade_to_float, name='arctan2') ...@@ -2951,7 +2999,8 @@ arctan2 = ArcTan2(upgrade_to_float, name='arctan2')
class Cosh(UnaryScalarOp): class Cosh(UnaryScalarOp):
""" """
cosh(x) = (exp(x) + exp(-x)) / 2 cosh(x) = (exp(x) + exp(-x)) / 2.
""" """
def impl(self, x): def impl(self, x):
# If x is an int8 or uint8, numpy.cosh will compute the result in # If x is an int8 or uint8, numpy.cosh will compute the result in
...@@ -3016,7 +3065,8 @@ arccosh = ArcCosh(upgrade_to_float, name='arccosh') ...@@ -3016,7 +3065,8 @@ arccosh = ArcCosh(upgrade_to_float, name='arccosh')
class Sinh(UnaryScalarOp): class Sinh(UnaryScalarOp):
""" """
sinh(x) = (exp(x) - exp(-x)) / 2 sinh(x) = (exp(x) - exp(-x)) / 2.
""" """
def impl(self, x): def impl(self, x):
# If x is an int8 or uint8, numpy.sinh will compute the result in # If x is an int8 or uint8, numpy.sinh will compute the result in
...@@ -3082,7 +3132,8 @@ arcsinh = ArcSinh(upgrade_to_float, name='arcsinh') ...@@ -3082,7 +3132,8 @@ arcsinh = ArcSinh(upgrade_to_float, name='arcsinh')
class Tanh(UnaryScalarOp): class Tanh(UnaryScalarOp):
""" """
tanh(x) = sinh(x) / cosh(x) tanh(x) = sinh(x) / cosh(x)
= (exp(2*x) - 1) / (exp(2*x) + 1) = (exp(2*x) - 1) / (exp(2*x) + 1).
""" """
def impl(self, x): def impl(self, x):
# If x is an int8 or uint8, numpy.tanh will compute the result in # If x is an int8 or uint8, numpy.tanh will compute the result in
...@@ -3146,7 +3197,10 @@ arctanh = ArcTanh(upgrade_to_float, name='arctanh') ...@@ -3146,7 +3197,10 @@ arctanh = ArcTanh(upgrade_to_float, name='arctanh')
class Real(UnaryScalarOp): class Real(UnaryScalarOp):
"""Extract the real coordinate of a complex number. """ """
Extract the real coordinate of a complex number.
"""
def impl(self, x): def impl(self, x):
return numpy.real(x) return numpy.real(x)
...@@ -3271,6 +3325,7 @@ class Composite(ScalarOp): ...@@ -3271,6 +3325,7 @@ class Composite(ScalarOp):
fusion. fusion.
Composite depends on all the Ops in its graph having C code. Composite depends on all the Ops in its graph having C code.
""" """
def __str__(self): def __str__(self):
return self.name return self.name
...@@ -3280,6 +3335,7 @@ class Composite(ScalarOp): ...@@ -3280,6 +3335,7 @@ class Composite(ScalarOp):
This op.__init__ fct don't have the same parameter as other scalar op. This op.__init__ fct don't have the same parameter as other scalar op.
This break the insert_inplace_optimizer optimization. This break the insert_inplace_optimizer optimization.
This fct allow fix patch this. This fct allow fix patch this.
""" """
out = self.__class__(self.inputs, self.outputs) out = self.__class__(self.inputs, self.outputs)
if name: if name:
...@@ -3290,7 +3346,10 @@ class Composite(ScalarOp): ...@@ -3290,7 +3346,10 @@ class Composite(ScalarOp):
return out return out
def init_c_code(self): def init_c_code(self):
"""Return the C code for this Composite Op. """ """
Return the C code for this Composite Op.
"""
subd = dict(chain( subd = dict(chain(
((e, "%%(i%i)s" % i) for i, e in enumerate(self.fgraph.inputs)), ((e, "%%(i%i)s" % i) for i, e in enumerate(self.fgraph.inputs)),
((e, "%%(o%i)s" % i) for i, e in enumerate(self.fgraph.outputs)))) ((e, "%%(o%i)s" % i) for i, e in enumerate(self.fgraph.outputs))))
...@@ -3335,7 +3394,9 @@ class Composite(ScalarOp): ...@@ -3335,7 +3394,9 @@ class Composite(ScalarOp):
self._c_code = _c_code self._c_code = _c_code
def init_py_impls(self): def init_py_impls(self):
"""Return a list of functions that compute each output of self """
Return a list of functions that compute each output of self.
""" """
def compose_impl(r): def compose_impl(r):
# this is not optimal at all eg in add(*1 -> mul(x, y), *1) # this is not optimal at all eg in add(*1 -> mul(x, y), *1)
...@@ -3353,7 +3414,9 @@ class Composite(ScalarOp): ...@@ -3353,7 +3414,9 @@ class Composite(ScalarOp):
self._impls = [compose_impl(r) for r in self.fgraph.outputs] self._impls = [compose_impl(r) for r in self.fgraph.outputs]
def init_name(self): def init_name(self):
"""Return a readable string representation of self.fgraph """
Return a readable string representation of self.fgraph.
""" """
try: try:
rval = self.name rval = self.name
......
...@@ -87,14 +87,18 @@ erfc = Erfc(upgrade_to_float_no_complex, name='erfc') ...@@ -87,14 +87,18 @@ erfc = Erfc(upgrade_to_float_no_complex, name='erfc')
class Erfcx(UnaryScalarOp): class Erfcx(UnaryScalarOp):
""" """
Implements the scaled complementary error function exp(x**2)*erfc(x) in a numerically stable way for large x. This Implements the scaled complementary error function exp(x**2)*erfc(x) in a
is useful for calculating things like log(erfc(x)) = log(erfcx(x)) - x ** 2 without causing underflow. Should only numerically stable way for large x. This is useful for calculating things
be used if x is known to be large and positive, as using erfcx(x) for large negative x may instead introduce like log(erfc(x)) = log(erfcx(x)) - x ** 2 without causing underflow.
overflow problems. Should only be used if x is known to be large and positive, as using
erfcx(x) for large negative x may instead introduce overflow problems.
Note: This op can still be executed on GPU, despite not having c_code. When
Notes
-----
This op can still be executed on GPU, despite not having c_code. When
running on GPU, sandbox.cuda.opt.local_gpu_elemwise_[0,1] replaces this op running on GPU, sandbox.cuda.opt.local_gpu_elemwise_[0,1] replaces this op
with sandbox.cuda.elemwise.ErfcxGPU. with sandbox.cuda.elemwise.ErfcxGPU.
""" """
def impl(self, x): def impl(self, x):
if imported_scipy_special: if imported_scipy_special:
...@@ -124,7 +128,9 @@ class Erfinv(UnaryScalarOp): ...@@ -124,7 +128,9 @@ class Erfinv(UnaryScalarOp):
""" """
Implements the inverse error function. Implements the inverse error function.
Note: This op can still be executed on GPU, despite not having c_code. When Notes
-----
This op can still be executed on GPU, despite not having c_code. When
running on GPU, sandbox.cuda.opt.local_gpu_elemwise_[0,1] replaces this op running on GPU, sandbox.cuda.opt.local_gpu_elemwise_[0,1] replaces this op
with sandbox.cuda.elemwise.ErfinvGPU. with sandbox.cuda.elemwise.ErfinvGPU.
...@@ -237,6 +243,7 @@ gamma = Gamma(upgrade_to_float, name='gamma') ...@@ -237,6 +243,7 @@ gamma = Gamma(upgrade_to_float, name='gamma')
class GammaLn(UnaryScalarOp): class GammaLn(UnaryScalarOp):
""" """
Log gamma function. Log gamma function.
""" """
@staticmethod @staticmethod
def st_impl(x): def st_impl(x):
...@@ -280,6 +287,7 @@ gammaln = GammaLn(upgrade_to_float, name='gammaln') ...@@ -280,6 +287,7 @@ gammaln = GammaLn(upgrade_to_float, name='gammaln')
class Psi(UnaryScalarOp): class Psi(UnaryScalarOp):
""" """
Derivative of log gamma function. Derivative of log gamma function.
""" """
@staticmethod @staticmethod
def st_impl(x): def st_impl(x):
...@@ -360,13 +368,13 @@ psi = Psi(upgrade_to_float, name='psi') ...@@ -360,13 +368,13 @@ psi = Psi(upgrade_to_float, name='psi')
class Chi2SF(BinaryScalarOp): class Chi2SF(BinaryScalarOp):
""" """
Compute (1 - chi2_cdf(x)) Compute (1 - chi2_cdf(x)) ie. chi2 pvalue (chi2 'survival function').
ie. chi2 pvalue (chi2 'survival function')
C code is provided in the Theano_lgpl repository. C code is provided in the Theano_lgpl repository.
This make it faster. This make it faster.
https://github.com/Theano/Theano_lgpl.git https://github.com/Theano/Theano_lgpl.git
""" """
@staticmethod @staticmethod
......
...@@ -28,8 +28,11 @@ def theano_dtype(expr): ...@@ -28,8 +28,11 @@ def theano_dtype(expr):
class SymPyCCode(ScalarOp): class SymPyCCode(ScalarOp):
""" An Operator that wraps SymPy's C code generation """
An Operator that wraps SymPy's C code generation.
Examples
--------
>>> from sympy.abc import x, y # SymPy Variables >>> from sympy.abc import x, y # SymPy Variables
>>> from theano.scalar.basic_sympy import SymPyCCode >>> from theano.scalar.basic_sympy import SymPyCCode
>>> op = SymPyCCode([x, y], x + y) >>> op = SymPyCCode([x, y], x + y)
...@@ -42,6 +45,7 @@ class SymPyCCode(ScalarOp): ...@@ -42,6 +45,7 @@ class SymPyCCode(ScalarOp):
>>> f = theano.function([xt, yt], zt) >>> f = theano.function([xt, yt], zt)
>>> f(1.0, 2.0) >>> f(1.0, 2.0)
3.0 3.0
""" """
def __init__(self, inputs, expr, name=None): def __init__(self, inputs, expr, name=None):
......
"""A shared variable container for true scalars - for internal use. """
A shared variable container for true scalars - for internal use.
Why does this file exist? Why does this file exist?
------------------------- -------------------------
...@@ -37,9 +38,12 @@ class ScalarSharedVariable(_scalar_py_operators, SharedVariable): ...@@ -37,9 +38,12 @@ class ScalarSharedVariable(_scalar_py_operators, SharedVariable):
def shared(value, name=None, strict=False, allow_downcast=None): def shared(value, name=None, strict=False, allow_downcast=None):
"""SharedVariable constructor for scalar values. Default: int64 or float64. """
SharedVariable constructor for scalar values. Default: int64 or float64.
:note: We implement this using 0-d tensors for now. Notes
-----
We implement this using 0-d tensors for now.
""" """
if not isinstance(value, (numpy.number, float, int, complex)): if not isinstance(value, (numpy.number, float, int, complex)):
......
...@@ -48,6 +48,7 @@ class _typed_list_py_operators: ...@@ -48,6 +48,7 @@ class _typed_list_py_operators:
class TypedListVariable(_typed_list_py_operators, Variable): class TypedListVariable(_typed_list_py_operators, Variable):
""" """
Subclass to add the typed list operators to the basic `Variable` class. Subclass to add the typed list operators to the basic `Variable` class.
""" """
TypedListType.Variable = TypedListVariable TypedListType.Variable = TypedListVariable
...@@ -104,8 +105,13 @@ getitem = GetItem() ...@@ -104,8 +105,13 @@ getitem = GetItem()
""" """
Get specified slice of a typed list. Get specified slice of a typed list.
:param x: typed list. Parameters
:param index: the index of the value to return from `x`. ----------
x
Typed list.
index
The index of the value to return from `x`.
""" """
...@@ -174,8 +180,13 @@ append = Append() ...@@ -174,8 +180,13 @@ append = Append()
""" """
Append an element at the end of another list. Append an element at the end of another list.
:param x: the base typed list. Parameters
:param y: the element to append to `x`. ----------
x
The base typed list.
y
The element to append to `x`.
""" """
...@@ -250,8 +261,13 @@ extend = Extend() ...@@ -250,8 +261,13 @@ extend = Extend()
""" """
Append all elements of a list at the end of another list. Append all elements of a list at the end of another list.
:param x: The typed list to extend. Parameters
:param toAppend: The typed list that will be added at the end of `x`. ----------
x
The typed list to extend.
toAppend
The typed list that will be added at the end of `x`.
""" """
...@@ -325,9 +341,15 @@ insert = Insert() ...@@ -325,9 +341,15 @@ insert = Insert()
""" """
Insert an element at an index in a typed list. Insert an element at an index in a typed list.
:param x: the typed list to modify. Parameters
:param index: the index where to put the new element in `x`. ----------
:param toInsert: The new element to insert. x
The typed list to modify.
index
The index where to put the new element in `x`.
toInsert
The new element to insert.
""" """
...@@ -356,9 +378,9 @@ class Remove(Op): ...@@ -356,9 +378,9 @@ class Remove(Op):
out[0] = x out[0] = x
""" """
inelegant workaround for ValueError: The truth value of an Inelegant workaround for ValueError: The truth value of an
array with more than one element is ambiguous. Use a.any() or a.all() array with more than one element is ambiguous. Use a.any() or a.all()
being thrown when trying to remove a matrix from a matrices list being thrown when trying to remove a matrix from a matrices list.
""" """
for y in range(out[0].__len__()): for y in range(out[0].__len__()):
if node.inputs[0].ttype.values_eq(out[0][y], toRemove): if node.inputs[0].ttype.values_eq(out[0][y], toRemove):
...@@ -371,13 +393,18 @@ class Remove(Op): ...@@ -371,13 +393,18 @@ class Remove(Op):
remove = Remove() remove = Remove()
"""Remove an element from a typed list. """Remove an element from a typed list.
:param x: the typed list to be changed. Parameters
:param toRemove: an element to be removed from the typed list. ----------
x
The typed list to be changed.
toRemove
An element to be removed from the typed list.
We only remove the first instance. We only remove the first instance.
:note: Python implementation of remove doesn't work when we want to Notes
remove an ndarray from a list. This implementation works in that -----
case. Python implementation of remove doesn't work when we want to remove an ndarray
from a list. This implementation works in that case.
""" """
...@@ -437,7 +464,11 @@ reverse = Reverse() ...@@ -437,7 +464,11 @@ reverse = Reverse()
""" """
Reverse the order of a typed list. Reverse the order of a typed list.
:param x: the typed list to be reversed. Parameters
----------
x
The typed list to be reversed.
""" """
...@@ -452,7 +483,7 @@ class Index(Op): ...@@ -452,7 +483,7 @@ class Index(Op):
def perform(self, node, inputs, outputs): def perform(self, node, inputs, outputs):
""" """
inelegant workaround for ValueError: The truth value of an Inelegant workaround for ValueError: The truth value of an
array with more than one element is ambiguous. Use a.any() or a.all() array with more than one element is ambiguous. Use a.any() or a.all()
being thrown when trying to remove a matrix from a matrices list being thrown when trying to remove a matrix from a matrices list
""" """
...@@ -480,7 +511,7 @@ class Count(Op): ...@@ -480,7 +511,7 @@ class Count(Op):
def perform(self, node, inputs, outputs): def perform(self, node, inputs, outputs):
""" """
inelegant workaround for ValueError: The truth value of an Inelegant workaround for ValueError: The truth value of an
array with more than one element is ambiguous. Use a.any() or a.all() array with more than one element is ambiguous. Use a.any() or a.all()
being thrown when trying to remove a matrix from a matrices list being thrown when trying to remove a matrix from a matrices list
""" """
...@@ -499,13 +530,18 @@ count = Count() ...@@ -499,13 +530,18 @@ count = Count()
""" """
Count the number of times an element is in the typed list. Count the number of times an element is in the typed list.
:param x: The typed list to look into. Parameters
:param elem: The element we want to count in list. ----------
x
The typed list to look into.
elem
The element we want to count in list.
The elements are compared with equals. The elements are compared with equals.
:note: Python implementation of count doesn't work when we want to Notes
count an ndarray from a list. This implementation works in that -----
case. Python implementation of count doesn't work when we want to count an ndarray
from a list. This implementation works in that case.
""" """
...@@ -543,7 +579,11 @@ length = Length() ...@@ -543,7 +579,11 @@ length = Length()
""" """
Returns the size of a list. Returns the size of a list.
:param x: typed list. Parameters
----------
x
Typed list.
""" """
...@@ -573,7 +613,12 @@ make_list = MakeList() ...@@ -573,7 +613,12 @@ make_list = MakeList()
""" """
Build a Python list from those Theano variable. Build a Python list from those Theano variable.
:param a: tuple/list of Theano variable Parameters
----------
a : tuple/list of Theano variable
Notes
-----
All Theano variables must have the same type.
:note: All Theano variable must have the same type.
""" """
...@@ -2,16 +2,20 @@ from theano import gof ...@@ -2,16 +2,20 @@ from theano import gof
class TypedListType(gof.Type): class TypedListType(gof.Type):
"""
Parameters
----------
ttype
Type of theano variable this list will contains, can be another list.
depth
Optionnal parameters, any value above 0 will create a nested list of
this depth. (0-based)
"""
def __init__(self, ttype, depth=0): def __init__(self, ttype, depth=0):
"""
:Parameters:
-'ttype' : Type of theano variable this list
will contains, can be another list.
-'depth' : Optionnal parameters, any value
above 0 will create a nested list of this
depth. (0-based)
"""
if depth < 0: if depth < 0:
raise ValueError('Please specify a depth superior or' raise ValueError('Please specify a depth superior or'
'equal to 0') 'equal to 0')
...@@ -25,10 +29,16 @@ class TypedListType(gof.Type): ...@@ -25,10 +29,16 @@ class TypedListType(gof.Type):
def filter(self, x, strict=False, allow_downcast=None): def filter(self, x, strict=False, allow_downcast=None):
""" """
:Parameters:
-'x' : value to filter Parameters
-'strict' : if true, only native python list will be accepted ----------
-'allow_downcast' : does not have any utility at the moment x
Value to filter.
strict
If true, only native python list will be accepted.
allow_downcast
Does not have any utility at the moment.
""" """
if strict: if strict:
if not isinstance(x, list): if not isinstance(x, list):
...@@ -45,9 +55,9 @@ class TypedListType(gof.Type): ...@@ -45,9 +55,9 @@ class TypedListType(gof.Type):
def __eq__(self, other): def __eq__(self, other):
""" """
two list are equals if they contains the same type. Two lists are equal if they contain the same type.
"""
"""
return type(self) == type(other) and self.ttype == other.ttype return type(self) == type(other) and self.ttype == other.ttype
def __hash__(self): def __hash__(self):
...@@ -58,8 +68,8 @@ class TypedListType(gof.Type): ...@@ -58,8 +68,8 @@ class TypedListType(gof.Type):
def get_depth(self): def get_depth(self):
""" """
utilitary function to get the 0 based Utilitary function to get the 0 based level of the list.
level of the list
""" """
if isinstance(self.ttype, TypedListType): if isinstance(self.ttype, TypedListType):
return self.ttype.get_depth() + 1 return self.ttype.get_depth() + 1
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论