提交 286e1df5 作者: Ian Goodfellow

fix (mostly pre-existing) pep8 violations in tensor.basic

上级 ba4ba5ef
...@@ -48,6 +48,7 @@ continuous_dtypes = map(str, scal.continuous_types) ...@@ -48,6 +48,7 @@ continuous_dtypes = map(str, scal.continuous_types)
discrete_dtypes = map(str, scal.discrete_types) discrete_dtypes = map(str, scal.discrete_types)
all_dtypes = map(str, scal.all_types) all_dtypes = map(str, scal.all_types)
class ShapeError(Exception): class ShapeError(Exception):
"""Raised when the shape cannot be computed.""" """Raised when the shape cannot be computed."""
pass pass
...@@ -395,6 +396,7 @@ def constant(x, name=None, ndim=None, dtype=None): ...@@ -395,6 +396,7 @@ def constant(x, name=None, ndim=None, dtype=None):
return constant_or_value(x, rtype=TensorConstant, name=name, ndim=ndim, return constant_or_value(x, rtype=TensorConstant, name=name, ndim=ndim,
dtype=dtype) dtype=dtype)
def _obj_is_wrappable_as_tensor(x): def _obj_is_wrappable_as_tensor(x):
try: try:
constant(x) constant(x)
...@@ -406,7 +408,7 @@ def _obj_is_wrappable_as_tensor(x): ...@@ -406,7 +408,7 @@ def _obj_is_wrappable_as_tensor(x):
def _wrap_tensor_into_member(x): def _wrap_tensor_into_member(x):
return compile.module.Member(constant(x)) return compile.module.Member(constant(x))
compile.module.register_wrapper(_obj_is_wrappable_as_tensor, compile.module.register_wrapper(_obj_is_wrappable_as_tensor,
_wrap_tensor_into_member, no_warn = True) _wrap_tensor_into_member, no_warn=True)
if int(config.tensor.cmp_sloppy) > 1: if int(config.tensor.cmp_sloppy) > 1:
...@@ -1503,10 +1505,9 @@ class _tensor_py_operators: ...@@ -1503,10 +1505,9 @@ class _tensor_py_operators:
""" """
if ndim is not None: if ndim is not None:
if not isinstance(ndim,int): if not isinstance(ndim, int):
raise ValueError("Expected ndim to be an integer, is "\ raise ValueError("Expected ndim to be an integer, is "\
+str(type(ndim))) + str(type(ndim)))
return reshape(self, shape, ndim=ndim) return reshape(self, shape, ndim=ndim)
...@@ -1803,7 +1804,6 @@ class TensorConstant(_tensor_py_operators, Constant): ...@@ -1803,7 +1804,6 @@ class TensorConstant(_tensor_py_operators, Constant):
TensorType.Constant = TensorConstant TensorType.Constant = TensorConstant
Tensor = TensorType Tensor = TensorType
...@@ -1817,6 +1817,7 @@ elemwise.TensorConstant = TensorConstant ...@@ -1817,6 +1817,7 @@ elemwise.TensorConstant = TensorConstant
# Utilities # Utilities
######################### #########################
def _redefine(real_symbol_value, module='tensor'): def _redefine(real_symbol_value, module='tensor'):
"""Replace the value associated with a function symbol. """Replace the value associated with a function symbol.
...@@ -2062,6 +2063,7 @@ def cast(x, dtype): ...@@ -2062,6 +2063,7 @@ def cast(x, dtype):
# Unary Operations # Unary Operations
########################## ##########################
class Shape(Op): class Shape(Op):
""" """
L{Op} to return the shape of a matrix. L{Op} to return the shape of a matrix.
...@@ -2333,16 +2335,16 @@ class MaxAndArgmax(Op): ...@@ -2333,16 +2335,16 @@ class MaxAndArgmax(Op):
#if the op is totally disconnected, so are its inputs #if the op is totally disconnected, so are its inputs
if g_max_disconnected and g_max_idx_disconnected: if g_max_disconnected and g_max_idx_disconnected:
return [ DisconnectedType()(), DisconnectedType()() ] return [DisconnectedType()(), DisconnectedType()()]
axis_grad = grad_undefined(self, 1, axis, axis_grad = grad_undefined(self, 1, axis,
"argmax is not defined for non-integer axes so" "argmax is not defined for non-integer axes so"
" argmax(x, axis+eps) is undefined" ) " argmax(x, axis+eps) is undefined")
#if the max is disconnected but the argmax is not, #if the max is disconnected but the argmax is not,
#the gradient on its inputs is zero #the gradient on its inputs is zero
if g_max_disconnected: if g_max_disconnected:
return [ x.zeros_like(), axis_grad ] return [x.zeros_like(), axis_grad]
xmax = max(x, axis) xmax = max(x, axis)
# Raise the g_max and xmax to the same number of dim as the input. # Raise the g_max and xmax to the same number of dim as the input.
...@@ -2875,6 +2877,7 @@ def complex_from_polar(abs, angle): ...@@ -2875,6 +2877,7 @@ def complex_from_polar(abs, angle):
# Misc # Misc
########################## ##########################
#fill, _fill_inplace = _elemwise(scal.second, 'fill', #fill, _fill_inplace = _elemwise(scal.second, 'fill',
#"""fill WRITEME (elemwise)""") #"""fill WRITEME (elemwise)""")
@_scal_elemwise @_scal_elemwise
...@@ -2943,7 +2946,7 @@ class Eye(gof.Op): ...@@ -2943,7 +2946,7 @@ class Eye(gof.Op):
return [out_shape] return [out_shape]
def grad(self, inp, grads): def grad(self, inp, grads):
return [ grad_undefined(self,i,inp[i]) for i in xrange(3) ] return [grad_undefined(self, i, inp[i]) for i in xrange(3)]
def __eq__(self, other): def __eq__(self, other):
return type(self) == type(other) and self.dtype == other.dtype return type(self) == type(other) and self.dtype == other.dtype
...@@ -3418,15 +3421,20 @@ def var(input, axis=None, keepdims=False): ...@@ -3418,15 +3421,20 @@ def var(input, axis=None, keepdims=False):
@constructor @constructor
def std(input, axis=None, keepdims=False): def std(input, axis=None, keepdims=False):
""" """
Computes the standard deviation along the given axis(es) of a tensor `input`. Computes the standard deviation along the given axis(es)
of a tensor `input`.
:param axis: Compute the standard deviation along this axis of the tensor. :param axis: Compute the standard deviation along this
axis of the tensor.
None means all axes (like numpy). None means all axes (like numpy).
:type axis: None or int or (list of int) (see `Sum`) :type axis: None or int or (list of int) (see `Sum`)
:param keepdims: If this is set to True, the axes which are reduced are :param keepdims: If this is set to True, the axes
left in the result as dimensions with size one. With this option, which are reduced are
the result will broadcast correctly against the original tensor. left in the result as dimensions with size one.
With this option,
the result will broadcast correctly against the
original tensor.
""" """
return sqrt(var(input=input, axis=axis, keepdims=keepdims)) return sqrt(var(input=input, axis=axis, keepdims=keepdims))
...@@ -5402,7 +5410,6 @@ class Reshape(Op): ...@@ -5402,7 +5410,6 @@ class Reshape(Op):
raise ValueError('Cannot reshape input of shape %s to shape %s' % raise ValueError('Cannot reshape input of shape %s to shape %s' %
(x.shape, shp)) (x.shape, shp))
def grad(self, inp, grads): def grad(self, inp, grads):
x, shp = inp x, shp = inp
g_out, = grads g_out, = grads
...@@ -5488,7 +5495,8 @@ class Reshape(Op): ...@@ -5488,7 +5495,8 @@ class Reshape(Op):
%(shp)s->data + ii * %(shp)s->strides[0]))[0]; %(shp)s->data + ii * %(shp)s->strides[0]))[0];
} }
Py_XDECREF(%(z)s); Py_XDECREF(%(z)s);
%(z)s = (PyArrayObject *) PyArray_Newshape(%(x)s, &newshape, PyArray_CORDER); %(z)s = (PyArrayObject *) PyArray_Newshape(%(x)s, &newshape,
PyArray_CORDER);
if (!%(z)s) if (!%(z)s)
{ {
PyErr_Format(PyExc_ValueError, PyErr_Format(PyExc_ValueError,
...@@ -6351,6 +6359,7 @@ advanced_inc_subtensor = AdvancedIncSubtensor() ...@@ -6351,6 +6359,7 @@ advanced_inc_subtensor = AdvancedIncSubtensor()
# #
# TODO: Dotinv should go here, Eigs, Svd, etc. # TODO: Dotinv should go here, Eigs, Svd, etc.
class Dot(Op): class Dot(Op):
"""Compute matrix-matrix, matrix-vector products and vector inner-products. """Compute matrix-matrix, matrix-vector products and vector inner-products.
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论