提交 215bdcd2 · 作者: Pascal Lamblin

Introduce two functions: 'gemm_inplace' and 'gemm_no_inplace' instead of 'gemm'.

Add an 'inplace' parameter to the Gemm Op. Alias gemm_inplace to Gemm(inplace=True) and gemm_no_inplace to Gemm(inplace=False). Replace calls to 'gemm' in GemmOptimizer with calls to gemm_no_inplace.
上级 611d631b
...@@ -300,7 +300,26 @@ class Gemm(GemmRelated): ...@@ -300,7 +300,26 @@ class Gemm(GemmRelated):
E_rank = 'gemm only works for rank 2' E_rank = 'gemm only works for rank 2'
E_scalar = 'gemm requires scalar argument' E_scalar = 'gemm requires scalar argument'
E_z_uniq = 'argument z aliased to x or y' E_z_uniq = 'argument z aliased to x or y'
destroy_map = {0: [0]} def __init__(self, inplace):
if inplace:
self.destroy_map = {0: [0]}
self.setup_z_Nz_Sz = self.setup_z_Nz_Sz_inplace
else:
self.setup_z_Nz_Sz = self.setup_z_Nz_Sz_outplace
self.inplace = inplace
def __eq__(self, other):
return (type(self) == type(other)\
and self.inplace == other.inplace)
def __hash__(self):
return hash(type(self)) ^ hash(self.inplace)
def __str__(self):
if self.inplace: inplace_str = 'inplace'
else: inplace_str = 'no_inplace'
return '%s{%s}' % (self.__class__.__name__, inplace_str)
def make_node(self, *inputs): def make_node(self, *inputs):
inputs = map(T.as_tensor_variable, inputs) inputs = map(T.as_tensor_variable, inputs)
if len(inputs) != 5: if len(inputs) != 5:
...@@ -322,6 +341,8 @@ class Gemm(GemmRelated): ...@@ -322,6 +341,8 @@ class Gemm(GemmRelated):
def perform(self, node, (z, a, x, y, b), (zout, )): def perform(self, node, (z, a, x, y, b), (zout, )):
assert a.shape == () assert a.shape == ()
assert b.shape == () assert b.shape == ()
if not self.inplace:
z = z.copy() # the original z will not be changed
if z.shape == (): if z.shape == ():
z.itemset(z*a + b*numpy.dot(x,y)) z.itemset(z*a + b*numpy.dot(x,y))
zout[0] = z zout[0] = z
...@@ -345,7 +366,7 @@ class Gemm(GemmRelated): ...@@ -345,7 +366,7 @@ class Gemm(GemmRelated):
z += a * numpy.dot(x,y) z += a * numpy.dot(x,y)
zout[0] = z zout[0] = z
setup_z_Nz_Sz = """ setup_z_Nz_Sz_inplace = """
if (%(_zout)s != %(_z)s) if (%(_zout)s != %(_z)s)
{ {
if (%(_zout)s) if (%(_zout)s)
...@@ -359,6 +380,25 @@ class Gemm(GemmRelated): ...@@ -359,6 +380,25 @@ class Gemm(GemmRelated):
Sz = %(_z)s->strides; Sz = %(_z)s->strides;
""" """
setup_z_Nz_Sz_outplace = """
if ((NULL == %(_zout)s)
|| (%(_zout)s->dimensions[0] != %(_z)s->dimensions[0])
|| (%(_zout)s->dimensions[1] != %(_z)s->dimensions[1]))
{
if (NULL != %(_zout)s) Py_XDECREF(%(_zout)s);
npy_intp dims[2];
dims[0] = %(_z)s->dimensions[0];
dims[1] = %(_z)s->dimensions[1];
%(_zout)s = (PyArrayObject*)PyArray_SimpleNew(2, dims, type_num_%(_z)s);
if(!%(_zout)s) {
PyErr_SetString(PyExc_MemoryError, "failed to alloc gemm_no_inplace output");
%(fail)s
}
}
Nz = %(_zout)s->dimensions;
Sz = %(_zout)s->strides;
"""
case_float_ab_constants = """ case_float_ab_constants = """
#define REAL float #define REAL float
float a = (%(_a)s->descr->type_num == PyArray_FLOAT) float a = (%(_a)s->descr->type_num == PyArray_FLOAT)
...@@ -387,23 +427,12 @@ class Gemm(GemmRelated): ...@@ -387,23 +427,12 @@ class Gemm(GemmRelated):
return full_code return full_code
def c_code_cache_version(self): def c_code_cache_version(self):
return (1,) + self.build_gemm_version() return (2,) + self.build_gemm_version()
class PseudoGemm(Op): gemm_inplace = Gemm(inplace=True)
# should be replaced by Gemm gemm_no_inplace = Gemm(inplace=False)
def __eq__(self, other):
return type(self) == type(other)
def __hash__(self):
return hash(type(self))
def make_node(self, *args):
inputs = [T.as_tensor_variable(i) for i in args]
return Apply(self, inputs, [inputs[0].type()])
def perform(self, node, (z, a, x, y, b), (zout, )):
zout[0] = a * numpy.dot(x,y) + b * z
gemm = PseudoGemm()
gemm_inplace = Gemm()
pprint.assign(gemm, FunctionPrinter('gemm'))
pprint.assign(gemm_inplace, FunctionPrinter('gemm_inplace')) pprint.assign(gemm_inplace, FunctionPrinter('gemm_inplace'))
pprint.assign(gemm_no_inplace, FunctionPrinter('gemm_no_inplace'))
def res_is_a(node, op, maxclients=None): def res_is_a(node, op, maxclients=None):
if maxclients is not None: if maxclients is not None:
...@@ -464,35 +493,35 @@ def _beta_L_plus_alpha_M(beta, L, alpha, M, recurse_flip = True): ...@@ -464,35 +493,35 @@ def _beta_L_plus_alpha_M(beta, L, alpha, M, recurse_flip = True):
if res_is_a(M, _dot22, 1): if res_is_a(M, _dot22, 1):
Ml, Mr = M.owner.inputs Ml, Mr = M.owner.inputs
rval = [gemm(L, alpha, Ml, Mr, beta)] rval = [gemm_no_inplace(L, alpha, Ml, Mr, beta)]
#print 'GEMM 0', rval, beta, L, alpha, M #print 'GEMM 0', rval, beta, L, alpha, M
return rval return rval
# this is False'd out because of inadequate testing. # this is False'd out because of inadequate testing.
# TODO see ticket #237 # TODO see ticket #237
if False and res_is_a(M, gemm, 1): if False and res_is_a(M, gemm_no_inplace, 1):
#EXPRESSION: (beta * L) + (alpha * (gemm(G, a, u, v, b))) #EXPRESSION: (beta * L) + (alpha * (gemm_no_inplace(G, a, u, v, b)))
#EXPRESSION: (beta * L) + alpha * (b * G) + alpha * a * dot(u, v) #EXPRESSION: (beta * L) + alpha * (b * G) + alpha * a * dot(u, v)
G, a, u, v, b = M.owner.inputs G, a, u, v, b = M.owner.inputs
#print 'GEMM', G, L #print 'GEMM', G, L
if res_is_a(G, _dot22, 1): if res_is_a(G, _dot22, 1):
#EXPRESSION: (beta * L) + (alpha * (gemm(dot(x,y), a, u, v, b))) #EXPRESSION: (beta * L) + (alpha * (gemm_no_inplace(dot(x,y), a, u, v, b)))
x, y = G.owner.inputs x, y = G.owner.inputs
#EXPRESSION: (beta * L) + (alpha * ((b*dot(x,y) + (a * dot(u, v))))) #EXPRESSION: (beta * L) + (alpha * ((b*dot(x,y) + (a * dot(u, v)))))
#EXPRESSION: (beta * L) + (alpha*b*dot(x,y)) + (alpha * a * dot(u, v)) #EXPRESSION: (beta * L) + (alpha*b*dot(x,y)) + (alpha * a * dot(u, v))
rval = [gemm(gemm(L, alpha * b, x, y, beta), alpha * a, u, v, 1.0)] rval = [gemm_no_inplace(gemm_no_inplace(L, alpha * b, x, y, beta), alpha * a, u, v, 1.0)]
print 'GEMM 1', rval print 'GEMM 1', rval
return rval return rval
if (G is L): if (G is L):
#EXPRESSION: (beta * L) + (alpha*b*L) + (alpha * a * dot(u, v)) #EXPRESSION: (beta * L) + (alpha*b*L) + (alpha * a * dot(u, v))
rval = [gemm(L, alpha*a, u, v, alpha * b + beta)] rval = [gemm_no_inplace(L, alpha*a, u, v, alpha * b + beta)]
print 'GEMM 2', rval print 'GEMM 2', rval
return rval return rval
if (1.0 != alpha): if (1.0 != alpha):
#at the very least, move the alpha inside the gemm #at the very least, move the alpha inside the gemm_no_inplace
rval = [beta * L + gemm(G, alpha * a, u, v, alpha * b)] rval = [beta * L + gemm_no_inplace(G, alpha * a, u, v, alpha * b)]
print 'GEMM 3', rval print 'GEMM 3', rval
return rval return rval
...@@ -695,9 +724,9 @@ def local_dot_to_dot22(node): ...@@ -695,9 +724,9 @@ def local_dot_to_dot22(node):
else: else:
return False return False
@local_optimizer([gemm]) @local_optimizer([gemm_no_inplace])
def local_inplace_gemm(node): def local_inplace_gemm(node):
if node.op == gemm: if node.op == gemm_no_inplace:
return [gemm_inplace(*node.inputs)] return [gemm_inplace(*node.inputs)]
################################# #################################
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论