提交 215bdcd2 authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Introduce two functions: 'gemm_inplace' and 'gemm_no_inplace' instead of 'gemm'.

Add parameter 'inplace' to Gemm Op. Alias gemm_inplace to Gemm(inplace=True) and gemm_no_inplace to Gemm(inplace=False). Replace calls to 'gemm' in GemmOptimizer by calls to gemm_no_inplace.
上级 611d631b
......@@ -300,7 +300,26 @@ class Gemm(GemmRelated):
E_rank = 'gemm only works for rank 2'
E_scalar = 'gemm requires scalar argument'
E_z_uniq = 'argument z aliased to x or y'
destroy_map = {0: [0]}
def __init__(self, inplace):
if inplace:
self.destroy_map = {0: [0]}
self.setup_z_Nz_Sz = self.setup_z_Nz_Sz_inplace
else:
self.setup_z_Nz_Sz = self.setup_z_Nz_Sz_outplace
self.inplace = inplace
def __eq__(self, other):
return (type(self) == type(other)\
and self.inplace == other.inplace)
def __hash__(self):
return hash(type(self)) ^ hash(self.inplace)
def __str__(self):
if self.inplace: inplace_str = 'inplace'
else: inplace_str = 'no_inplace'
return '%s{%s}' % (self.__class__.__name__, inplace_str)
def make_node(self, *inputs):
inputs = map(T.as_tensor_variable, inputs)
if len(inputs) != 5:
......@@ -322,6 +341,8 @@ class Gemm(GemmRelated):
def perform(self, node, (z, a, x, y, b), (zout, )):
assert a.shape == ()
assert b.shape == ()
if not self.inplace:
z = z.copy() # the original z will not be changed
if z.shape == ():
z.itemset(z*a + b*numpy.dot(x,y))
zout[0] = z
......@@ -345,7 +366,7 @@ class Gemm(GemmRelated):
z += a * numpy.dot(x,y)
zout[0] = z
setup_z_Nz_Sz = """
setup_z_Nz_Sz_inplace = """
if (%(_zout)s != %(_z)s)
{
if (%(_zout)s)
......@@ -359,6 +380,25 @@ class Gemm(GemmRelated):
Sz = %(_z)s->strides;
"""
setup_z_Nz_Sz_outplace = """
if ((NULL == %(_zout)s)
|| (%(_zout)s->dimensions[0] != %(_z)s->dimensions[0])
|| (%(_zout)s->dimensions[1] != %(_z)s->dimensions[1]))
{
if (NULL != %(_zout)s) Py_XDECREF(%(_zout)s);
npy_intp dims[2];
dims[0] = %(_z)s->dimensions[0];
dims[1] = %(_z)s->dimensions[1];
%(_zout)s = (PyArrayObject*)PyArray_SimpleNew(2, dims, type_num_%(_z)s);
if(!%(_zout)s) {
PyErr_SetString(PyExc_MemoryError, "failed to alloc gemm_no_inplace output");
%(fail)s
}
}
Nz = %(_zout)s->dimensions;
Sz = %(_zout)s->strides;
"""
case_float_ab_constants = """
#define REAL float
float a = (%(_a)s->descr->type_num == PyArray_FLOAT)
......@@ -387,23 +427,12 @@ class Gemm(GemmRelated):
return full_code
def c_code_cache_version(self):
return (1,) + self.build_gemm_version()
return (2,) + self.build_gemm_version()
class PseudoGemm(Op):
# should be replaced by Gemm
def __eq__(self, other):
return type(self) == type(other)
def __hash__(self):
return hash(type(self))
def make_node(self, *args):
inputs = [T.as_tensor_variable(i) for i in args]
return Apply(self, inputs, [inputs[0].type()])
def perform(self, node, (z, a, x, y, b), (zout, )):
zout[0] = a * numpy.dot(x,y) + b * z
gemm = PseudoGemm()
gemm_inplace = Gemm()
pprint.assign(gemm, FunctionPrinter('gemm'))
gemm_inplace = Gemm(inplace=True)
gemm_no_inplace = Gemm(inplace=False)
pprint.assign(gemm_inplace, FunctionPrinter('gemm_inplace'))
pprint.assign(gemm_no_inplace, FunctionPrinter('gemm_no_inplace'))
def res_is_a(node, op, maxclients=None):
if maxclients is not None:
......@@ -464,35 +493,35 @@ def _beta_L_plus_alpha_M(beta, L, alpha, M, recurse_flip = True):
if res_is_a(M, _dot22, 1):
Ml, Mr = M.owner.inputs
rval = [gemm(L, alpha, Ml, Mr, beta)]
rval = [gemm_no_inplace(L, alpha, Ml, Mr, beta)]
#print 'GEMM 0', rval, beta, L, alpha, M
return rval
# this is False'd out because of inadequate testing.
# TODO see ticket #237
if False and res_is_a(M, gemm, 1):
#EXPRESSION: (beta * L) + (alpha * (gemm(G, a, u, v, b)))
if False and res_is_a(M, gemm_no_inplace, 1):
#EXPRESSION: (beta * L) + (alpha * (gemm_no_inplace(G, a, u, v, b)))
#EXPRESSION: (beta * L) + alpha * (b * G) + alpha * a * dot(u, v)
G, a, u, v, b = M.owner.inputs
#print 'GEMM', G, L
if res_is_a(G, _dot22, 1):
#EXPRESSION: (beta * L) + (alpha * (gemm(dot(x,y), a, u, v, b)))
#EXPRESSION: (beta * L) + (alpha * (gemm_no_inplace(dot(x,y), a, u, v, b)))
x, y = G.owner.inputs
#EXPRESSION: (beta * L) + (alpha * ((b*dot(x,y) + (a * dot(u, v)))))
#EXPRESSION: (beta * L) + (alpha*b*dot(x,y)) + (alpha * a * dot(u, v))
rval = [gemm(gemm(L, alpha * b, x, y, beta), alpha * a, u, v, 1.0)]
rval = [gemm_no_inplace(gemm_no_inplace(L, alpha * b, x, y, beta), alpha * a, u, v, 1.0)]
print 'GEMM 1', rval
return rval
if (G is L):
#EXPRESSION: (beta * L) + (alpha*b*L) + (alpha * a * dot(u, v))
rval = [gemm(L, alpha*a, u, v, alpha * b + beta)]
rval = [gemm_no_inplace(L, alpha*a, u, v, alpha * b + beta)]
print 'GEMM 2', rval
return rval
if (1.0 != alpha):
#at the very least, move the alpha inside the gemm
rval = [beta * L + gemm(G, alpha * a, u, v, alpha * b)]
#at the very least, move the alpha inside the gemm_no_inplace
rval = [beta * L + gemm_no_inplace(G, alpha * a, u, v, alpha * b)]
print 'GEMM 3', rval
return rval
......@@ -695,9 +724,9 @@ def local_dot_to_dot22(node):
else:
return False
@local_optimizer([gemm])
@local_optimizer([gemm_no_inplace])
def local_inplace_gemm(node):
if node.op == gemm:
if node.op == gemm_no_inplace:
return [gemm_inplace(*node.inputs)]
#################################
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论