提交 f9202f68 authored 作者: Shawn Tan's avatar Shawn Tan

Made theano/sandbox/linalg/ops.py flake8 compliant

上级 29c39347
from __future__ import absolute_import, print_function, division
import logging
logger = logging.getLogger(__name__)
import numpy
from six import iteritems, integer_types
from six.moves import xrange
from theano.gof import Op, Apply
from theano.tensor import as_tensor_variable, dot, DimShuffle, Dot
from theano.tensor import DimShuffle, Dot
from theano.tensor.blas import Dot22
from theano import tensor
import theano.tensor
from theano.tensor.opt import (register_stabilize,
register_specialize, register_canonicalize)
register_specialize,
register_canonicalize)
from theano.gof import local_optimizer
from theano.gof.opt import Optimizer
from theano.gradient import DisconnectedType
from theano.tensor.nlinalg import ( MatrixInverse,
matrix_inverse,
MatrixPinv,
pinv,
AllocDiag,
alloc_diag,
ExtractDiag,
extract_diag,
diag,
trace,
Det,
det,
Eig,
eig,
Eigh,
EighGrad,
eigh,
matrix_dot,
_zero_disconnected,
qr,
svd,
lstsq,
matrix_power,
norm
)
from theano.tensor.slinalg import ( Cholesky,
cholesky,
CholeskyGrad,
Solve,
solve,
Eigvalsh,
EigvalshGrad,
eigvalsh
)
try:
import scipy.linalg
imported_scipy = True
except ImportError:
# some ops (e.g. Cholesky, Solve, A_Xinv_b) won't work
imported_scipy = False
from theano.tensor.nlinalg import (MatrixInverse,
matrix_inverse,
extract_diag,
trace,
det)
from theano.tensor.slinalg import (Cholesky,
cholesky,
Solve,
solve)
logger = logging.getLogger(__name__)
class Hint(Op):
......@@ -212,8 +178,6 @@ class HintsFeature(object):
class HintsOptimizer(Optimizer):
"""
Optimizer that serves to add HintsFeature as an fgraph feature.
"""
def __init__(self):
......@@ -280,8 +244,6 @@ def transinv_to_invtrans(node):
@register_stabilize
@local_optimizer([Dot, Dot22])
def inv_as_solve(node):
if not imported_scipy:
return False
if isinstance(node.op, (Dot, Dot22)):
l, r = node.inputs
if l.owner and l.owner.op == matrix_inverse:
......@@ -310,8 +272,8 @@ def tag_solve_triangular(node):
return [Solve('lower_triangular')(A, b)]
else:
return [Solve('upper_triangular')(A, b)]
if (A.owner and isinstance(A.owner.op, DimShuffle)
and A.owner.op.new_order == (1, 0)):
if (A.owner and isinstance(A.owner.op, DimShuffle) and
A.owner.op.new_order == (1, 0)):
A_T, = A.owner.inputs
if A_T.owner and isinstance(A_T.owner.op, type(cholesky)):
if A_T.owner.op.lower:
......@@ -423,6 +385,5 @@ def spectral_radius_bound(X, log2_exponent):
XX = X
for i in xrange(log2_exponent):
XX = tensor.dot(XX, XX)
return tensor.pow(
trace(XX),
2 ** (-log2_exponent))
return tensor.pow(trace(XX),
2 ** (-log2_exponent))
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论