提交 2aaafc0a authored 作者: Frederic Bastien's avatar Frederic Bastien

xParams for cgemv and cger

上级 77a2fd51
......@@ -1044,7 +1044,7 @@ class Gemm(GemmRelated):
def c_code_cache_version(self):
gv = self.build_gemm_version()
if gv:
return (5,) + gv
return (6,) + gv
else:
return gv
......
from __future__ import absolute_import, print_function, division
from theano import config
from theano.gof.params_type import ParamsType
from theano.scalar import bool as bool_t
from theano.tensor.opt import in2out
from theano.tensor.blas import ldflags, blas_header_text, blas_header_version
from theano.tensor.blas import blas_optdb, optdb, local_optimizer
......@@ -30,7 +32,7 @@ class BaseBLAS(object):
# GER
# ##### ####### #######
def ger_c_code(A, a, x, y, Z, destructive, fail):
def ger_c_code(A, a, x, y, Z, fail, params):
return """
int elemsize ;
......@@ -71,7 +73,7 @@ def ger_c_code(A, a, x, y, Z, destructive, fail):
}
// copy A if !self.destructive or A is fully strided
if (!%(destructive)s
if (!%(params)s->destructive
|| (PyArray_STRIDES(%(A)s)[0] < 0)
|| (PyArray_STRIDES(%(A)s)[1] < 0)
|| ((PyArray_STRIDES(%(A)s)[0] != elemsize)
......@@ -311,16 +313,18 @@ def ger_c_code(A, a, x, y, Z, destructive, fail):
class CGer(BaseBLAS, Ger):
params_type = ParamsType(destructive=bool_t,)
def c_code(self, node, name, inp, out, sub):
A, a, x, y = inp
Z, = out
code = ger_c_code(A, a, x, y, Z,
destructive=int(self.destructive),
fail=sub['fail'])
fail=sub['fail'],
params=sub['params'])
return code
def c_code_cache_version(self):
return (10, blas_header_version())
return (11, blas_header_version())
cger_inplace = CGer(True)
cger_no_inplace = CGer(False)
......@@ -349,8 +353,8 @@ def make_c_ger_destructive(node):
# ##### ####### #######
def gemv_c_code(y, A, x, z, alpha, beta, destructive, fail,
force_init_beta=False):
def gemv_c_code(y, A, x, z, alpha, beta, fail,
force_init_beta=False, params=None):
"""
z <- beta * y + alpha * dot(A, x)
......@@ -385,7 +389,7 @@ def gemv_c_code(y, A, x, z, alpha, beta, destructive, fail,
fbeta = dbeta = ((dtype_%(beta)s*)PyArray_DATA(%(beta)s))[0];
// copy y if not destructive
if (!%(destructive)s)
if (!%(params)s->inplace)
{
if ((NULL == %(z)s)
|| (PyArray_DIMS(%(z)s)[0] != PyArray_DIMS(%(y)s)[0]))
......@@ -593,6 +597,8 @@ def gemv_c_code(y, A, x, z, alpha, beta, destructive, fail,
class CGemv(BaseBLAS, Gemv):
params_type = ParamsType(inplace=bool_t,)
def __init__(self, inplace):
super(CGemv, self).__init__(inplace)
......@@ -601,14 +607,14 @@ class CGemv(BaseBLAS, Gemv):
z, = out
code = gemv_c_code(
y, A, x, z, alpha, beta,
destructive=int(self.inplace),
fail=sub['fail'],
force_init_beta=check_force_gemv_init()
force_init_beta=check_force_gemv_init(),
params=sub['params'],
)
return code
def c_code_cache_version(self):
return (13, blas_header_version(), check_force_gemv_init())
return (14, blas_header_version(), check_force_gemv_init())
cgemv_inplace = CGemv(inplace=True)
cgemv_no_inplace = CGemv(inplace=False)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论