提交 2aaafc0a authored 作者: Frederic Bastien's avatar Frederic Bastien

xParams for cgemv and cger

上级 77a2fd51
...@@ -1044,7 +1044,7 @@ class Gemm(GemmRelated): ...@@ -1044,7 +1044,7 @@ class Gemm(GemmRelated):
def c_code_cache_version(self): def c_code_cache_version(self):
gv = self.build_gemm_version() gv = self.build_gemm_version()
if gv: if gv:
return (5,) + gv return (6,) + gv
else: else:
return gv return gv
......
from __future__ import absolute_import, print_function, division from __future__ import absolute_import, print_function, division
from theano import config from theano import config
from theano.gof.params_type import ParamsType
from theano.scalar import bool as bool_t
from theano.tensor.opt import in2out from theano.tensor.opt import in2out
from theano.tensor.blas import ldflags, blas_header_text, blas_header_version from theano.tensor.blas import ldflags, blas_header_text, blas_header_version
from theano.tensor.blas import blas_optdb, optdb, local_optimizer from theano.tensor.blas import blas_optdb, optdb, local_optimizer
...@@ -30,7 +32,7 @@ class BaseBLAS(object): ...@@ -30,7 +32,7 @@ class BaseBLAS(object):
# GER # GER
# ##### ####### ####### # ##### ####### #######
def ger_c_code(A, a, x, y, Z, destructive, fail): def ger_c_code(A, a, x, y, Z, fail, params):
return """ return """
int elemsize ; int elemsize ;
...@@ -71,7 +73,7 @@ def ger_c_code(A, a, x, y, Z, destructive, fail): ...@@ -71,7 +73,7 @@ def ger_c_code(A, a, x, y, Z, destructive, fail):
} }
// copy A if !self.destructive or A is fully strided // copy A if !self.destructive or A is fully strided
if (!%(destructive)s if (!%(params)s->destructive
|| (PyArray_STRIDES(%(A)s)[0] < 0) || (PyArray_STRIDES(%(A)s)[0] < 0)
|| (PyArray_STRIDES(%(A)s)[1] < 0) || (PyArray_STRIDES(%(A)s)[1] < 0)
|| ((PyArray_STRIDES(%(A)s)[0] != elemsize) || ((PyArray_STRIDES(%(A)s)[0] != elemsize)
...@@ -311,16 +313,18 @@ def ger_c_code(A, a, x, y, Z, destructive, fail): ...@@ -311,16 +313,18 @@ def ger_c_code(A, a, x, y, Z, destructive, fail):
class CGer(BaseBLAS, Ger): class CGer(BaseBLAS, Ger):
params_type = ParamsType(destructive=bool_t,)
def c_code(self, node, name, inp, out, sub): def c_code(self, node, name, inp, out, sub):
A, a, x, y = inp A, a, x, y = inp
Z, = out Z, = out
code = ger_c_code(A, a, x, y, Z, code = ger_c_code(A, a, x, y, Z,
destructive=int(self.destructive), fail=sub['fail'],
fail=sub['fail']) params=sub['params'])
return code return code
def c_code_cache_version(self): def c_code_cache_version(self):
return (10, blas_header_version()) return (11, blas_header_version())
cger_inplace = CGer(True) cger_inplace = CGer(True)
cger_no_inplace = CGer(False) cger_no_inplace = CGer(False)
...@@ -349,8 +353,8 @@ def make_c_ger_destructive(node): ...@@ -349,8 +353,8 @@ def make_c_ger_destructive(node):
# ##### ####### ####### # ##### ####### #######
def gemv_c_code(y, A, x, z, alpha, beta, destructive, fail, def gemv_c_code(y, A, x, z, alpha, beta, fail,
force_init_beta=False): force_init_beta=False, params=None):
""" """
z <- beta * y + alpha * dot(A, x) z <- beta * y + alpha * dot(A, x)
...@@ -385,7 +389,7 @@ def gemv_c_code(y, A, x, z, alpha, beta, destructive, fail, ...@@ -385,7 +389,7 @@ def gemv_c_code(y, A, x, z, alpha, beta, destructive, fail,
fbeta = dbeta = ((dtype_%(beta)s*)PyArray_DATA(%(beta)s))[0]; fbeta = dbeta = ((dtype_%(beta)s*)PyArray_DATA(%(beta)s))[0];
// copy y if not destructive // copy y if not destructive
if (!%(destructive)s) if (!%(params)s->inplace)
{ {
if ((NULL == %(z)s) if ((NULL == %(z)s)
|| (PyArray_DIMS(%(z)s)[0] != PyArray_DIMS(%(y)s)[0])) || (PyArray_DIMS(%(z)s)[0] != PyArray_DIMS(%(y)s)[0]))
...@@ -593,6 +597,8 @@ def gemv_c_code(y, A, x, z, alpha, beta, destructive, fail, ...@@ -593,6 +597,8 @@ def gemv_c_code(y, A, x, z, alpha, beta, destructive, fail,
class CGemv(BaseBLAS, Gemv): class CGemv(BaseBLAS, Gemv):
params_type = ParamsType(inplace=bool_t,)
def __init__(self, inplace): def __init__(self, inplace):
super(CGemv, self).__init__(inplace) super(CGemv, self).__init__(inplace)
...@@ -601,14 +607,14 @@ class CGemv(BaseBLAS, Gemv): ...@@ -601,14 +607,14 @@ class CGemv(BaseBLAS, Gemv):
z, = out z, = out
code = gemv_c_code( code = gemv_c_code(
y, A, x, z, alpha, beta, y, A, x, z, alpha, beta,
destructive=int(self.inplace),
fail=sub['fail'], fail=sub['fail'],
force_init_beta=check_force_gemv_init() force_init_beta=check_force_gemv_init(),
params=sub['params'],
) )
return code return code
def c_code_cache_version(self): def c_code_cache_version(self):
return (13, blas_header_version(), check_force_gemv_init()) return (14, blas_header_version(), check_force_gemv_init())
cgemv_inplace = CGemv(inplace=True) cgemv_inplace = CGemv(inplace=True)
cgemv_no_inplace = CGemv(inplace=False) cgemv_no_inplace = CGemv(inplace=False)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论