提交 045baff9 authored 作者: Olivier Delalleau's avatar Olivier Delalleau

PEP8

上级 ae4e9609
...@@ -6833,7 +6833,7 @@ def take(a, indices, axis=None, mode='raise'): ...@@ -6833,7 +6833,7 @@ def take(a, indices, axis=None, mode='raise'):
# Reuse advanced_subtensor1 if indices is a vector # Reuse advanced_subtensor1 if indices is a vector
if indices.ndim == 1: if indices.ndim == 1:
if mode == 'clip': if mode == 'clip':
indices = clip(indices, 0, a.shape[axis]-1) indices = clip(indices, 0, a.shape[axis] - 1)
elif mode == 'wrap': elif mode == 'wrap':
indices = indices % a.shape[axis] indices = indices % a.shape[axis]
if axis is None: if axis is None:
...@@ -6853,10 +6853,12 @@ def take(a, indices, axis=None, mode='raise'): ...@@ -6853,10 +6853,12 @@ def take(a, indices, axis=None, mode='raise'):
shape = indices.shape shape = indices.shape
ndim = indices.ndim ndim = indices.ndim
else: else:
shape = concatenate([a.shape[:axis], indices.shape, a.shape[axis+1:]]) shape = concatenate(
[a.shape[:axis], indices.shape, a.shape[axis + 1:]])
ndim = a.ndim + indices.ndim - 1 ndim = a.ndim + indices.ndim - 1
return take(a, indices.flatten(), axis, mode).reshape(shape, ndim) return take(a, indices.flatten(), axis, mode).reshape(shape, ndim)
######################### #########################
# Linalg : Dot # Linalg : Dot
######################### #########################
...@@ -7283,6 +7285,7 @@ def all(x, axis=None, keepdims=False): ...@@ -7283,6 +7285,7 @@ def all(x, axis=None, keepdims=False):
out = makeKeepDims(x, out, axis) out = makeKeepDims(x, out, axis)
return out return out
class Diagonal(Op): class Diagonal(Op):
"""Return specified diagonals. """Return specified diagonals.
...@@ -7310,7 +7313,7 @@ class Diagonal(Op): ...@@ -7310,7 +7313,7 @@ class Diagonal(Op):
x = as_tensor_variable(x) x = as_tensor_variable(x)
assert x.ndim >= 2 assert x.ndim >= 2
return Apply(self, [x], [tensor(dtype=x.dtype, return Apply(self, [x], [tensor(dtype=x.dtype,
broadcastable=[False] * (x.ndim -1))]) broadcastable=[False] * (x.ndim - 1))])
def perform(self, node, (x,), (z,)): def perform(self, node, (x,), (z,)):
z[0] = x.diagonal(self.offset, self.axis1, self.axis2) z[0] = x.diagonal(self.offset, self.axis1, self.axis2)
...@@ -7322,7 +7325,7 @@ class Diagonal(Op): ...@@ -7322,7 +7325,7 @@ class Diagonal(Op):
in_shape, = shapes in_shape, = shapes
dim1 = in_shape[self.axis1] dim1 = in_shape[self.axis1]
dim2 = in_shape[self.axis2] dim2 = in_shape[self.axis2]
out_shape = [d for i,d in enumerate(in_shape) out_shape = [d for i, d in enumerate(in_shape)
if i not in (self.axis1, self.axis2)] if i not in (self.axis1, self.axis2)]
# The following logic is inspired by C code of PyArray_Diagonal(). # The following logic is inspired by C code of PyArray_Diagonal().
offset = self.offset offset = self.offset
...@@ -7338,12 +7341,14 @@ class Diagonal(Op): ...@@ -7338,12 +7341,14 @@ class Diagonal(Op):
def __str__(self): def __str__(self):
return self.__class__.__name__ return self.__class__.__name__
def diagonal(a, offset=0, axis1=0, axis2=1): def diagonal(a, offset=0, axis1=0, axis2=1):
if (offset, axis1, axis2) == (0, 0, 1): if (offset, axis1, axis2) == (0, 0, 1):
from theano.sandbox.linalg import extract_diag from theano.sandbox.linalg import extract_diag
return extract_diag(a) return extract_diag(a)
return Diagonal(offset, axis1, axis2)(a) return Diagonal(offset, axis1, axis2)(a)
class Diag(Op): class Diag(Op):
def __eq__(self, other): def __eq__(self, other):
...@@ -7371,6 +7376,7 @@ class Diag(Op): ...@@ -7371,6 +7376,7 @@ class Diag(Op):
def __str__(self): def __str__(self):
return self.__class__.__name__ return self.__class__.__name__
def diag(v, k=0): def diag(v, k=0):
if v.ndim == 1: if v.ndim == 1:
assert k == 0, "diagonals other than main are not implemented" assert k == 0, "diagonals other than main are not implemented"
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论