提交 befc44f8 authored 作者: Frederic's avatar Frederic

pep8

上级 c78d8db1
......@@ -4,7 +4,8 @@ from theano import config
from theano.tensor.opt import in2out
from theano.tensor.blas import ldflags, blas_header_text, blas_header_version
from theano.tensor.blas import blas_optdb, optdb, local_optimizer, EquilibriumOptimizer
from theano.tensor.blas import (
blas_optdb, optdb, local_optimizer, EquilibriumOptimizer)
from theano.tensor.blas import Ger, ger, ger_destructive
from theano.tensor.blas import Gemv, gemv_inplace, gemv_no_inplace
from theano.tensor import basic as T
......@@ -28,9 +29,9 @@ class BaseBLAS(object):
return blas_header_text()
####### ####### #######
# ##### ####### #######
# GER
####### ####### #######
# ##### ####### #######
def ger_c_code(A, a, x, y, Z, destructive, fail):
return """
......@@ -250,8 +251,8 @@ class CGer(BaseBLAS, Ger):
A, a, x, y = inp
Z, = out
code = ger_c_code(A, a, x, y, Z,
destructive=int(self.destructive),
fail=sub['fail'])
destructive=int(self.destructive),
fail=sub['fail'])
return code
def c_code_cache_version(self):
......@@ -279,12 +280,13 @@ def make_c_ger_destructive(node):
return [cger_inplace(*node.inputs)]
####### ####### #######
# ##### ####### #######
# GEMV
####### ####### #######
# ##### ####### #######
def gemv_c_code(aa, xx, yy, zz, alpha, beta, destructive, fail, force_init_beta=False):
def gemv_c_code(aa, xx, yy, zz, alpha, beta, destructive, fail,
force_init_beta=False):
"""
zz <- beta * aa + alpha * dot(xx, yy)
......@@ -618,11 +620,11 @@ class CGemv(BaseBLAS, Gemv):
aa, alpha, xx, yy, beta = inp
zz, = out
code = gemv_c_code(
aa, xx, yy, zz, alpha, beta,
destructive=int(self.inplace),
fail=sub['fail'],
force_init_beta=self.force_init_beta
)
aa, xx, yy, zz, alpha, beta,
destructive=int(self.inplace),
fail=sub['fail'],
force_init_beta=self.force_init_beta
)
return code
def c_code_cache_version(self):
......@@ -630,6 +632,7 @@ class CGemv(BaseBLAS, Gemv):
cgemv_inplace = CGemv(inplace=True)
cgemv_no_inplace = CGemv(inplace=False)
def check_force_gemv_init():
if check_force_gemv_init._force_init_beta is None:
"""
......@@ -680,6 +683,7 @@ def check_force_gemv_init():
check_force_gemv_init._force_init_beta = None
@local_optimizer([gemv_inplace, gemv_no_inplace])
def use_c_gemv(node):
if not config.blas.ldflags:
......@@ -709,7 +713,8 @@ def use_c_gemv(node):
"""
force_init_beta = check_force_gemv_init()
return [CGemv(inplace=False, force_init_beta=force_init_beta)(*node.inputs)]
return [CGemv(inplace=False,
force_init_beta=force_init_beta)(*node.inputs)]
if (node.op == gemv_inplace and
node.outputs[0].dtype in ['float32', 'float64']):
return [CGemv(inplace=True)(*node.inputs)]
......@@ -721,15 +726,13 @@ def make_c_gemv_destructive(node):
return [cgemv_inplace(*node.inputs)]
####### ####### #######
# ##### ####### #######
# Optimizers
####### ####### #######
# ##### ####### #######
blas_optdb.register('use_c_blas',
in2out(use_c_ger, use_c_gemv),
20, 'fast_run', 'c_blas')
#print 'BLAS_OPTDB'
#print blas_optdb
# this matches the InplaceBlasOpt defined in blas.py
optdb.register('c_blas_destructive',
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论