提交 befc44f8 authored 作者: Frederic's avatar Frederic

pep8

上级 c78d8db1
...@@ -4,7 +4,8 @@ from theano import config ...@@ -4,7 +4,8 @@ from theano import config
from theano.tensor.opt import in2out from theano.tensor.opt import in2out
from theano.tensor.blas import ldflags, blas_header_text, blas_header_version from theano.tensor.blas import ldflags, blas_header_text, blas_header_version
from theano.tensor.blas import blas_optdb, optdb, local_optimizer, EquilibriumOptimizer from theano.tensor.blas import (
blas_optdb, optdb, local_optimizer, EquilibriumOptimizer)
from theano.tensor.blas import Ger, ger, ger_destructive from theano.tensor.blas import Ger, ger, ger_destructive
from theano.tensor.blas import Gemv, gemv_inplace, gemv_no_inplace from theano.tensor.blas import Gemv, gemv_inplace, gemv_no_inplace
from theano.tensor import basic as T from theano.tensor import basic as T
...@@ -28,9 +29,9 @@ class BaseBLAS(object): ...@@ -28,9 +29,9 @@ class BaseBLAS(object):
return blas_header_text() return blas_header_text()
####### ####### ####### # ##### ####### #######
# GER # GER
####### ####### ####### # ##### ####### #######
def ger_c_code(A, a, x, y, Z, destructive, fail): def ger_c_code(A, a, x, y, Z, destructive, fail):
return """ return """
...@@ -279,12 +280,13 @@ def make_c_ger_destructive(node): ...@@ -279,12 +280,13 @@ def make_c_ger_destructive(node):
return [cger_inplace(*node.inputs)] return [cger_inplace(*node.inputs)]
####### ####### ####### # ##### ####### #######
# GEMV # GEMV
####### ####### ####### # ##### ####### #######
def gemv_c_code(aa, xx, yy, zz, alpha, beta, destructive, fail, force_init_beta=False): def gemv_c_code(aa, xx, yy, zz, alpha, beta, destructive, fail,
force_init_beta=False):
""" """
zz <- beta * aa + alpha * dot(xx, yy) zz <- beta * aa + alpha * dot(xx, yy)
...@@ -630,6 +632,7 @@ class CGemv(BaseBLAS, Gemv): ...@@ -630,6 +632,7 @@ class CGemv(BaseBLAS, Gemv):
cgemv_inplace = CGemv(inplace=True) cgemv_inplace = CGemv(inplace=True)
cgemv_no_inplace = CGemv(inplace=False) cgemv_no_inplace = CGemv(inplace=False)
def check_force_gemv_init(): def check_force_gemv_init():
if check_force_gemv_init._force_init_beta is None: if check_force_gemv_init._force_init_beta is None:
""" """
...@@ -680,6 +683,7 @@ def check_force_gemv_init(): ...@@ -680,6 +683,7 @@ def check_force_gemv_init():
check_force_gemv_init._force_init_beta = None check_force_gemv_init._force_init_beta = None
@local_optimizer([gemv_inplace, gemv_no_inplace]) @local_optimizer([gemv_inplace, gemv_no_inplace])
def use_c_gemv(node): def use_c_gemv(node):
if not config.blas.ldflags: if not config.blas.ldflags:
...@@ -709,7 +713,8 @@ def use_c_gemv(node): ...@@ -709,7 +713,8 @@ def use_c_gemv(node):
""" """
force_init_beta = check_force_gemv_init() force_init_beta = check_force_gemv_init()
return [CGemv(inplace=False, force_init_beta=force_init_beta)(*node.inputs)] return [CGemv(inplace=False,
force_init_beta=force_init_beta)(*node.inputs)]
if (node.op == gemv_inplace and if (node.op == gemv_inplace and
node.outputs[0].dtype in ['float32', 'float64']): node.outputs[0].dtype in ['float32', 'float64']):
return [CGemv(inplace=True)(*node.inputs)] return [CGemv(inplace=True)(*node.inputs)]
...@@ -721,15 +726,13 @@ def make_c_gemv_destructive(node): ...@@ -721,15 +726,13 @@ def make_c_gemv_destructive(node):
return [cgemv_inplace(*node.inputs)] return [cgemv_inplace(*node.inputs)]
####### ####### ####### # ##### ####### #######
# Optimizers # Optimizers
####### ####### ####### # ##### ####### #######
blas_optdb.register('use_c_blas', blas_optdb.register('use_c_blas',
in2out(use_c_ger, use_c_gemv), in2out(use_c_ger, use_c_gemv),
20, 'fast_run', 'c_blas') 20, 'fast_run', 'c_blas')
#print 'BLAS_OPTDB'
#print blas_optdb
# this matches the InplaceBlasOpt defined in blas.py # this matches the InplaceBlasOpt defined in blas.py
optdb.register('c_blas_destructive', optdb.register('c_blas_destructive',
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论