提交 045baff9 authored 作者: Olivier Delalleau's avatar Olivier Delalleau

PEP8

上级 ae4e9609
......@@ -6833,7 +6833,7 @@ def take(a, indices, axis=None, mode='raise'):
# Reuse advanced_subtensor1 if indices is a vector
if indices.ndim == 1:
if mode == 'clip':
indices = clip(indices, 0, a.shape[axis]-1)
indices = clip(indices, 0, a.shape[axis] - 1)
elif mode == 'wrap':
indices = indices % a.shape[axis]
if axis is None:
......@@ -6853,10 +6853,12 @@ def take(a, indices, axis=None, mode='raise'):
shape = indices.shape
ndim = indices.ndim
else:
shape = concatenate([a.shape[:axis], indices.shape, a.shape[axis+1:]])
shape = concatenate(
[a.shape[:axis], indices.shape, a.shape[axis + 1:]])
ndim = a.ndim + indices.ndim - 1
return take(a, indices.flatten(), axis, mode).reshape(shape, ndim)
#########################
# Linalg : Dot
#########################
......@@ -7283,6 +7285,7 @@ def all(x, axis=None, keepdims=False):
out = makeKeepDims(x, out, axis)
return out
class Diagonal(Op):
"""Return specified diagonals.
......@@ -7310,7 +7313,7 @@ class Diagonal(Op):
x = as_tensor_variable(x)
assert x.ndim >= 2
return Apply(self, [x], [tensor(dtype=x.dtype,
broadcastable=[False] * (x.ndim -1))])
broadcastable=[False] * (x.ndim - 1))])
def perform(self, node, (x,), (z,)):
z[0] = x.diagonal(self.offset, self.axis1, self.axis2)
......@@ -7322,7 +7325,7 @@ class Diagonal(Op):
in_shape, = shapes
dim1 = in_shape[self.axis1]
dim2 = in_shape[self.axis2]
out_shape = [d for i,d in enumerate(in_shape)
out_shape = [d for i, d in enumerate(in_shape)
if i not in (self.axis1, self.axis2)]
# The following logic is inspired by C code of PyArray_Diagonal().
offset = self.offset
......@@ -7338,12 +7341,14 @@ class Diagonal(Op):
def __str__(self):
return self.__class__.__name__
def diagonal(a, offset=0, axis1=0, axis2=1):
if (offset, axis1, axis2) == (0, 0, 1):
from theano.sandbox.linalg import extract_diag
return extract_diag(a)
return Diagonal(offset, axis1, axis2)(a)
class Diag(Op):
def __eq__(self, other):
......@@ -7371,6 +7376,7 @@ class Diag(Op):
def __str__(self):
return self.__class__.__name__
def diag(v, k=0):
if v.ndim == 1:
assert k == 0, "diagonals other than main are not implemented"
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论