提交 286e1df5 authored 作者: Ian Goodfellow's avatar Ian Goodfellow

fix (mostly pre-existing) pep8 violations in tensor.basic

上级 ba4ba5ef
......@@ -48,6 +48,7 @@ continuous_dtypes = map(str, scal.continuous_types)
discrete_dtypes = map(str, scal.discrete_types)
all_dtypes = map(str, scal.all_types)
class ShapeError(Exception):
"""Raised when the shape cannot be computed."""
pass
......@@ -395,6 +396,7 @@ def constant(x, name=None, ndim=None, dtype=None):
return constant_or_value(x, rtype=TensorConstant, name=name, ndim=ndim,
dtype=dtype)
def _obj_is_wrappable_as_tensor(x):
try:
constant(x)
......@@ -406,7 +408,7 @@ def _obj_is_wrappable_as_tensor(x):
def _wrap_tensor_into_member(x):
return compile.module.Member(constant(x))
compile.module.register_wrapper(_obj_is_wrappable_as_tensor,
_wrap_tensor_into_member, no_warn = True)
_wrap_tensor_into_member, no_warn=True)
if int(config.tensor.cmp_sloppy) > 1:
......@@ -1503,10 +1505,9 @@ class _tensor_py_operators:
"""
if ndim is not None:
if not isinstance(ndim,int):
if not isinstance(ndim, int):
raise ValueError("Expected ndim to be an integer, is "\
+str(type(ndim)))
+ str(type(ndim)))
return reshape(self, shape, ndim=ndim)
......@@ -1803,7 +1804,6 @@ class TensorConstant(_tensor_py_operators, Constant):
TensorType.Constant = TensorConstant
Tensor = TensorType
......@@ -1817,6 +1817,7 @@ elemwise.TensorConstant = TensorConstant
# Utilities
#########################
def _redefine(real_symbol_value, module='tensor'):
"""Replace the value associated with a function symbol.
......@@ -2062,6 +2063,7 @@ def cast(x, dtype):
# Unary Operations
##########################
class Shape(Op):
"""
L{Op} to return the shape of a matrix.
......@@ -2333,16 +2335,16 @@ class MaxAndArgmax(Op):
#if the op is totally disconnected, so are its inputs
if g_max_disconnected and g_max_idx_disconnected:
return [ DisconnectedType()(), DisconnectedType()() ]
return [DisconnectedType()(), DisconnectedType()()]
axis_grad = grad_undefined(self, 1, axis,
"argmax is not defined for non-integer axes so"
" argmax(x, axis+eps) is undefined" )
" argmax(x, axis+eps) is undefined")
#if the max is disconnected but the argmax is not,
#the gradient on its inputs is zero
if g_max_disconnected:
return [ x.zeros_like(), axis_grad ]
return [x.zeros_like(), axis_grad]
xmax = max(x, axis)
# Raise the g_max and xmax to the same number of dim as the input.
......@@ -2875,6 +2877,7 @@ def complex_from_polar(abs, angle):
# Misc
##########################
#fill, _fill_inplace = _elemwise(scal.second, 'fill',
#"""fill WRITEME (elemwise)""")
@_scal_elemwise
......@@ -2943,7 +2946,7 @@ class Eye(gof.Op):
return [out_shape]
def grad(self, inp, grads):
return [ grad_undefined(self,i,inp[i]) for i in xrange(3) ]
return [grad_undefined(self, i, inp[i]) for i in xrange(3)]
def __eq__(self, other):
return type(self) == type(other) and self.dtype == other.dtype
......@@ -3418,15 +3421,20 @@ def var(input, axis=None, keepdims=False):
@constructor
def std(input, axis=None, keepdims=False):
"""
Computes the standard deviation along the given axis(es) of a tensor `input`.
Computes the standard deviation along the given axis(es)
of a tensor `input`.
:param axis: Compute the standard deviation along this axis of the tensor.
:param axis: Compute the standard deviation along this
axis of the tensor.
None means all axes (like numpy).
:type axis: None or int or (list of int) (see `Sum`)
:param keepdims: If this is set to True, the axes which are reduced are
left in the result as dimensions with size one. With this option,
the result will broadcast correctly against the original tensor.
:param keepdims: If this is set to True, the axes
which are reduced are
left in the result as dimensions with size one.
With this option,
the result will broadcast correctly against the
original tensor.
"""
return sqrt(var(input=input, axis=axis, keepdims=keepdims))
......@@ -5402,7 +5410,6 @@ class Reshape(Op):
raise ValueError('Cannot reshape input of shape %s to shape %s' %
(x.shape, shp))
def grad(self, inp, grads):
x, shp = inp
g_out, = grads
......@@ -5488,7 +5495,8 @@ class Reshape(Op):
%(shp)s->data + ii * %(shp)s->strides[0]))[0];
}
Py_XDECREF(%(z)s);
%(z)s = (PyArrayObject *) PyArray_Newshape(%(x)s, &newshape, PyArray_CORDER);
%(z)s = (PyArrayObject *) PyArray_Newshape(%(x)s, &newshape,
PyArray_CORDER);
if (!%(z)s)
{
PyErr_Format(PyExc_ValueError,
......@@ -6351,6 +6359,7 @@ advanced_inc_subtensor = AdvancedIncSubtensor()
#
# TODO: Dotinv should go here, Eigs, Svd, etc.
class Dot(Op):
"""Compute matrix-matrix, matrix-vector products and vector inner-products.
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论