提交 f9202f68 authored 作者: Shawn Tan's avatar Shawn Tan

Made theano/sandbox/linalg/ops.py flake8 compliant

上级 29c39347
from __future__ import absolute_import, print_function, division from __future__ import absolute_import, print_function, division
import logging import logging
logger = logging.getLogger(__name__)
import numpy
from six import iteritems, integer_types from six import iteritems, integer_types
from six.moves import xrange from six.moves import xrange
from theano.gof import Op, Apply from theano.gof import Op, Apply
from theano.tensor import as_tensor_variable, dot, DimShuffle, Dot from theano.tensor import DimShuffle, Dot
from theano.tensor.blas import Dot22 from theano.tensor.blas import Dot22
from theano import tensor from theano import tensor
import theano.tensor import theano.tensor
from theano.tensor.opt import (register_stabilize, from theano.tensor.opt import (register_stabilize,
register_specialize, register_canonicalize) register_specialize,
register_canonicalize)
from theano.gof import local_optimizer from theano.gof import local_optimizer
from theano.gof.opt import Optimizer from theano.gof.opt import Optimizer
from theano.gradient import DisconnectedType
from theano.tensor.nlinalg import (MatrixInverse,
from theano.tensor.nlinalg import ( MatrixInverse, matrix_inverse,
matrix_inverse, extract_diag,
MatrixPinv, trace,
pinv, det)
AllocDiag,
alloc_diag, from theano.tensor.slinalg import (Cholesky,
ExtractDiag, cholesky,
extract_diag, Solve,
diag, solve)
trace,
Det, logger = logging.getLogger(__name__)
det,
Eig,
eig,
Eigh,
EighGrad,
eigh,
matrix_dot,
_zero_disconnected,
qr,
svd,
lstsq,
matrix_power,
norm
)
from theano.tensor.slinalg import ( Cholesky,
cholesky,
CholeskyGrad,
Solve,
solve,
Eigvalsh,
EigvalshGrad,
eigvalsh
)
try:
import scipy.linalg
imported_scipy = True
except ImportError:
# some ops (e.g. Cholesky, Solve, A_Xinv_b) won't work
imported_scipy = False
class Hint(Op): class Hint(Op):
...@@ -212,8 +178,6 @@ class HintsFeature(object): ...@@ -212,8 +178,6 @@ class HintsFeature(object):
class HintsOptimizer(Optimizer): class HintsOptimizer(Optimizer):
""" """
Optimizer that serves to add HintsFeature as an fgraph feature. Optimizer that serves to add HintsFeature as an fgraph feature.
""" """
def __init__(self): def __init__(self):
...@@ -280,8 +244,6 @@ def transinv_to_invtrans(node): ...@@ -280,8 +244,6 @@ def transinv_to_invtrans(node):
@register_stabilize @register_stabilize
@local_optimizer([Dot, Dot22]) @local_optimizer([Dot, Dot22])
def inv_as_solve(node): def inv_as_solve(node):
if not imported_scipy:
return False
if isinstance(node.op, (Dot, Dot22)): if isinstance(node.op, (Dot, Dot22)):
l, r = node.inputs l, r = node.inputs
if l.owner and l.owner.op == matrix_inverse: if l.owner and l.owner.op == matrix_inverse:
...@@ -310,8 +272,8 @@ def tag_solve_triangular(node): ...@@ -310,8 +272,8 @@ def tag_solve_triangular(node):
return [Solve('lower_triangular')(A, b)] return [Solve('lower_triangular')(A, b)]
else: else:
return [Solve('upper_triangular')(A, b)] return [Solve('upper_triangular')(A, b)]
if (A.owner and isinstance(A.owner.op, DimShuffle) if (A.owner and isinstance(A.owner.op, DimShuffle) and
and A.owner.op.new_order == (1, 0)): A.owner.op.new_order == (1, 0)):
A_T, = A.owner.inputs A_T, = A.owner.inputs
if A_T.owner and isinstance(A_T.owner.op, type(cholesky)): if A_T.owner and isinstance(A_T.owner.op, type(cholesky)):
if A_T.owner.op.lower: if A_T.owner.op.lower:
...@@ -423,6 +385,5 @@ def spectral_radius_bound(X, log2_exponent): ...@@ -423,6 +385,5 @@ def spectral_radius_bound(X, log2_exponent):
XX = X XX = X
for i in xrange(log2_exponent): for i in xrange(log2_exponent):
XX = tensor.dot(XX, XX) XX = tensor.dot(XX, XX)
return tensor.pow( return tensor.pow(trace(XX),
trace(XX), 2 ** (-log2_exponent))
2 ** (-log2_exponent))
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论