提交 befc44f8 authored 作者: Frederic's avatar Frederic

pep8

上级 c78d8db1
...@@ -4,7 +4,8 @@ from theano import config ...@@ -4,7 +4,8 @@ from theano import config
from theano.tensor.opt import in2out from theano.tensor.opt import in2out
from theano.tensor.blas import ldflags, blas_header_text, blas_header_version from theano.tensor.blas import ldflags, blas_header_text, blas_header_version
from theano.tensor.blas import blas_optdb, optdb, local_optimizer, EquilibriumOptimizer from theano.tensor.blas import (
blas_optdb, optdb, local_optimizer, EquilibriumOptimizer)
from theano.tensor.blas import Ger, ger, ger_destructive from theano.tensor.blas import Ger, ger, ger_destructive
from theano.tensor.blas import Gemv, gemv_inplace, gemv_no_inplace from theano.tensor.blas import Gemv, gemv_inplace, gemv_no_inplace
from theano.tensor import basic as T from theano.tensor import basic as T
...@@ -28,9 +29,9 @@ class BaseBLAS(object): ...@@ -28,9 +29,9 @@ class BaseBLAS(object):
return blas_header_text() return blas_header_text()
####### ####### ####### # ##### ####### #######
# GER # GER
####### ####### ####### # ##### ####### #######
def ger_c_code(A, a, x, y, Z, destructive, fail): def ger_c_code(A, a, x, y, Z, destructive, fail):
return """ return """
...@@ -250,8 +251,8 @@ class CGer(BaseBLAS, Ger): ...@@ -250,8 +251,8 @@ class CGer(BaseBLAS, Ger):
A, a, x, y = inp A, a, x, y = inp
Z, = out Z, = out
code = ger_c_code(A, a, x, y, Z, code = ger_c_code(A, a, x, y, Z,
destructive=int(self.destructive), destructive=int(self.destructive),
fail=sub['fail']) fail=sub['fail'])
return code return code
def c_code_cache_version(self): def c_code_cache_version(self):
...@@ -279,12 +280,13 @@ def make_c_ger_destructive(node): ...@@ -279,12 +280,13 @@ def make_c_ger_destructive(node):
return [cger_inplace(*node.inputs)] return [cger_inplace(*node.inputs)]
####### ####### ####### # ##### ####### #######
# GEMV # GEMV
####### ####### ####### # ##### ####### #######
def gemv_c_code(aa, xx, yy, zz, alpha, beta, destructive, fail, force_init_beta=False): def gemv_c_code(aa, xx, yy, zz, alpha, beta, destructive, fail,
force_init_beta=False):
""" """
zz <- beta * aa + alpha * dot(xx, yy) zz <- beta * aa + alpha * dot(xx, yy)
...@@ -618,11 +620,11 @@ class CGemv(BaseBLAS, Gemv): ...@@ -618,11 +620,11 @@ class CGemv(BaseBLAS, Gemv):
aa, alpha, xx, yy, beta = inp aa, alpha, xx, yy, beta = inp
zz, = out zz, = out
code = gemv_c_code( code = gemv_c_code(
aa, xx, yy, zz, alpha, beta, aa, xx, yy, zz, alpha, beta,
destructive=int(self.inplace), destructive=int(self.inplace),
fail=sub['fail'], fail=sub['fail'],
force_init_beta=self.force_init_beta force_init_beta=self.force_init_beta
) )
return code return code
def c_code_cache_version(self): def c_code_cache_version(self):
...@@ -630,6 +632,7 @@ class CGemv(BaseBLAS, Gemv): ...@@ -630,6 +632,7 @@ class CGemv(BaseBLAS, Gemv):
cgemv_inplace = CGemv(inplace=True) cgemv_inplace = CGemv(inplace=True)
cgemv_no_inplace = CGemv(inplace=False) cgemv_no_inplace = CGemv(inplace=False)
def check_force_gemv_init(): def check_force_gemv_init():
if check_force_gemv_init._force_init_beta is None: if check_force_gemv_init._force_init_beta is None:
""" """
...@@ -680,6 +683,7 @@ def check_force_gemv_init(): ...@@ -680,6 +683,7 @@ def check_force_gemv_init():
check_force_gemv_init._force_init_beta = None check_force_gemv_init._force_init_beta = None
@local_optimizer([gemv_inplace, gemv_no_inplace]) @local_optimizer([gemv_inplace, gemv_no_inplace])
def use_c_gemv(node): def use_c_gemv(node):
if not config.blas.ldflags: if not config.blas.ldflags:
...@@ -709,7 +713,8 @@ def use_c_gemv(node): ...@@ -709,7 +713,8 @@ def use_c_gemv(node):
""" """
force_init_beta = check_force_gemv_init() force_init_beta = check_force_gemv_init()
return [CGemv(inplace=False, force_init_beta=force_init_beta)(*node.inputs)] return [CGemv(inplace=False,
force_init_beta=force_init_beta)(*node.inputs)]
if (node.op == gemv_inplace and if (node.op == gemv_inplace and
node.outputs[0].dtype in ['float32', 'float64']): node.outputs[0].dtype in ['float32', 'float64']):
return [CGemv(inplace=True)(*node.inputs)] return [CGemv(inplace=True)(*node.inputs)]
...@@ -721,15 +726,13 @@ def make_c_gemv_destructive(node): ...@@ -721,15 +726,13 @@ def make_c_gemv_destructive(node):
return [cgemv_inplace(*node.inputs)] return [cgemv_inplace(*node.inputs)]
####### ####### ####### # ##### ####### #######
# Optimizers # Optimizers
####### ####### ####### # ##### ####### #######
blas_optdb.register('use_c_blas', blas_optdb.register('use_c_blas',
in2out(use_c_ger, use_c_gemv), in2out(use_c_ger, use_c_gemv),
20, 'fast_run', 'c_blas') 20, 'fast_run', 'c_blas')
#print 'BLAS_OPTDB'
#print blas_optdb
# this matches the InplaceBlasOpt defined in blas.py # this matches the InplaceBlasOpt defined in blas.py
optdb.register('c_blas_destructive', optdb.register('c_blas_destructive',
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论