提交 77a2fd51 authored 作者: Frederic Bastien's avatar Frederic Bastien

Add params to Gemm op.

上级 9db9d791
......@@ -145,9 +145,11 @@ from theano.gof import (utils, Op, view_roots,
InconsistencyError, toolbox, SequenceDB,
EquilibriumOptimizer, Apply,
ReplacementDidntRemovedError)
from theano.gof.params_type import ParamsType
from theano.printing import pprint, FunctionPrinter, debugprint
from theano.compile.mode import optdb
import theano.scalar
from theano.scalar import bool as bool_t
from theano.tensor import basic as T
from theano.tensor.blas_headers import blas_header_text
from theano.tensor.blas_headers import blas_header_version
......@@ -522,7 +524,11 @@ class GemmRelated(Op):
int sx_0, sx_1, sy_0, sy_1, sz_0, sz_1;
"""
# implement if you don't have an inplace props
# setup_z_Nz_Sz = None
# otherwise implement
# setup_z_Nz_Sz_inplace = None
# setup_z_Nz_Sz_outplace = None
check_xyz_rank2 = """
if (PyArray_NDIM(%(_x)s) != 2) {
......@@ -755,11 +761,16 @@ class GemmRelated(Op):
"""
def build_gemm_call(self):
if hasattr(self, 'inplace'):
setup_z_Nz_Sz = "if(%%(params)s->inplace){%s}else{%s}" % (
self.setup_z_Nz_Sz_inplace, self.setup_z_Nz_Sz_outplace)
else:
setup_z_Nz_Sz = self.setup_z_Nz_Sz
return reduce(str.__add__, (
self.declare_NS,
self.check_xyz_rank2,
self.setup_z_Nz_Sz,
setup_z_Nz_Sz,
self.check_xyz_double_or_float,
self.check_ab_double_or_float,
self.check_dims,
......@@ -809,14 +820,12 @@ class Gemm(GemmRelated):
E_float = 'gemm requires floating-point dtypes'
__props__ = ('inplace',)
params_type = ParamsType(inplace=bool_t,)
def __init__(self, inplace):
self.inplace = inplace
if self.inplace:
self.destroy_map = {0: [0]}
self.setup_z_Nz_Sz = self.setup_z_Nz_Sz_inplace
else:
self.setup_z_Nz_Sz = self.setup_z_Nz_Sz_outplace
def __str__(self):
if self.inplace:
......@@ -827,10 +836,6 @@ class Gemm(GemmRelated):
def __setstate__(self, dct):
self.__dict__.update(dct)
if self.inplace:
self.setup_z_Nz_Sz = self.setup_z_Nz_Sz_inplace
else:
self.setup_z_Nz_Sz = self.setup_z_Nz_Sz_outplace
# Correctly reload older pickles where destroy_map were not
# saved
......@@ -841,7 +846,7 @@ class Gemm(GemmRelated):
rval = self.__dict__.copy()
# Do not serialize the setup code, it will be restored in __setstate__
# depending on the value of 'inplace'
rval.pop('setup_z_Nz_Sz')
rval.pop('setup_z_Nz_Sz', None)
return rval
def make_node(self, *inputs):
......@@ -892,12 +897,12 @@ class Gemm(GemmRelated):
output = z.type()
return Apply(self, inputs, [output])
def perform(self, node, inp, out):
def perform(self, node, inp, out, params):
z, a, x, y, b = inp
zout, = out
assert a.shape == ()
assert b.shape == ()
if not self.inplace:
if not params.inplace:
z = z.copy() # the original z will not be changed
if z.shape == ():
z.itemset(z * a + b * np.dot(x, y))
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论