提交 77a2fd51 authored 作者: Frederic Bastien's avatar Frederic Bastien

Add params to Gemm op.

上级 9db9d791
...@@ -145,9 +145,11 @@ from theano.gof import (utils, Op, view_roots, ...@@ -145,9 +145,11 @@ from theano.gof import (utils, Op, view_roots,
InconsistencyError, toolbox, SequenceDB, InconsistencyError, toolbox, SequenceDB,
EquilibriumOptimizer, Apply, EquilibriumOptimizer, Apply,
ReplacementDidntRemovedError) ReplacementDidntRemovedError)
from theano.gof.params_type import ParamsType
from theano.printing import pprint, FunctionPrinter, debugprint from theano.printing import pprint, FunctionPrinter, debugprint
from theano.compile.mode import optdb from theano.compile.mode import optdb
import theano.scalar import theano.scalar
from theano.scalar import bool as bool_t
from theano.tensor import basic as T from theano.tensor import basic as T
from theano.tensor.blas_headers import blas_header_text from theano.tensor.blas_headers import blas_header_text
from theano.tensor.blas_headers import blas_header_version from theano.tensor.blas_headers import blas_header_version
...@@ -522,7 +524,11 @@ class GemmRelated(Op): ...@@ -522,7 +524,11 @@ class GemmRelated(Op):
int sx_0, sx_1, sy_0, sy_1, sz_0, sz_1; int sx_0, sx_1, sy_0, sy_1, sz_0, sz_1;
""" """
# implement if you don't have an inplace props
# setup_z_Nz_Sz = None # setup_z_Nz_Sz = None
# otherwise implement
# setup_z_Nz_Sz_inplace = None
# setup_z_Nz_Sz_outplace = None
check_xyz_rank2 = """ check_xyz_rank2 = """
if (PyArray_NDIM(%(_x)s) != 2) { if (PyArray_NDIM(%(_x)s) != 2) {
...@@ -755,11 +761,16 @@ class GemmRelated(Op): ...@@ -755,11 +761,16 @@ class GemmRelated(Op):
""" """
def build_gemm_call(self): def build_gemm_call(self):
if hasattr(self, 'inplace'):
setup_z_Nz_Sz = "if(%%(params)s->inplace){%s}else{%s}" % (
self.setup_z_Nz_Sz_inplace, self.setup_z_Nz_Sz_outplace)
else:
setup_z_Nz_Sz = self.setup_z_Nz_Sz
return reduce(str.__add__, ( return reduce(str.__add__, (
self.declare_NS, self.declare_NS,
self.check_xyz_rank2, self.check_xyz_rank2,
self.setup_z_Nz_Sz, setup_z_Nz_Sz,
self.check_xyz_double_or_float, self.check_xyz_double_or_float,
self.check_ab_double_or_float, self.check_ab_double_or_float,
self.check_dims, self.check_dims,
...@@ -809,14 +820,12 @@ class Gemm(GemmRelated): ...@@ -809,14 +820,12 @@ class Gemm(GemmRelated):
E_float = 'gemm requires floating-point dtypes' E_float = 'gemm requires floating-point dtypes'
__props__ = ('inplace',) __props__ = ('inplace',)
params_type = ParamsType(inplace=bool_t,)
def __init__(self, inplace): def __init__(self, inplace):
self.inplace = inplace self.inplace = inplace
if self.inplace: if self.inplace:
self.destroy_map = {0: [0]} self.destroy_map = {0: [0]}
self.setup_z_Nz_Sz = self.setup_z_Nz_Sz_inplace
else:
self.setup_z_Nz_Sz = self.setup_z_Nz_Sz_outplace
def __str__(self): def __str__(self):
if self.inplace: if self.inplace:
...@@ -827,10 +836,6 @@ class Gemm(GemmRelated): ...@@ -827,10 +836,6 @@ class Gemm(GemmRelated):
def __setstate__(self, dct): def __setstate__(self, dct):
self.__dict__.update(dct) self.__dict__.update(dct)
if self.inplace:
self.setup_z_Nz_Sz = self.setup_z_Nz_Sz_inplace
else:
self.setup_z_Nz_Sz = self.setup_z_Nz_Sz_outplace
# Correctly reload older pickles where destroy_map were not # Correctly reload older pickles where destroy_map were not
# saved # saved
...@@ -841,7 +846,7 @@ class Gemm(GemmRelated): ...@@ -841,7 +846,7 @@ class Gemm(GemmRelated):
rval = self.__dict__.copy() rval = self.__dict__.copy()
# Do not serialize the setup code, it will be restored in __setstate__ # Do not serialize the setup code, it will be restored in __setstate__
# depending on the value of 'inplace' # depending on the value of 'inplace'
rval.pop('setup_z_Nz_Sz') rval.pop('setup_z_Nz_Sz', None)
return rval return rval
def make_node(self, *inputs): def make_node(self, *inputs):
...@@ -892,12 +897,12 @@ class Gemm(GemmRelated): ...@@ -892,12 +897,12 @@ class Gemm(GemmRelated):
output = z.type() output = z.type()
return Apply(self, inputs, [output]) return Apply(self, inputs, [output])
def perform(self, node, inp, out): def perform(self, node, inp, out, params):
z, a, x, y, b = inp z, a, x, y, b = inp
zout, = out zout, = out
assert a.shape == () assert a.shape == ()
assert b.shape == () assert b.shape == ()
if not self.inplace: if not params.inplace:
z = z.copy() # the original z will not be changed z = z.copy() # the original z will not be changed
if z.shape == (): if z.shape == ():
z.itemset(z * a + b * np.dot(x, y)) z.itemset(z * a + b * np.dot(x, y))
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论