提交 cc8822d3 authored 作者: James Bergstra's avatar James Bergstra

several changes to tensor/blas.py

上级 dfdea75f
...@@ -387,6 +387,12 @@ class GemmRelated(Op): ...@@ -387,6 +387,12 @@ class GemmRelated(Op):
(long int)Ny[1], (long int)Nz[1]); (long int)Ny[1], (long int)Nz[1]);
%(fail)s; %(fail)s;
} }
if (Nx[1] == 0)
{
PyErr_Format(PyExc_ValueError,
"Undefined semantics: x has 0 cols");
%(fail)s;
}
""" """
check_strides = """ check_strides = """
...@@ -497,6 +503,12 @@ class GemmRelated(Op): ...@@ -497,6 +503,12 @@ class GemmRelated(Op):
int Nz0 = Nz[0], Nz1 = Nz[1], Nx1 = Nx[1]; int Nz0 = Nz[0], Nz1 = Nz[1], Nx1 = Nx[1];
//std::cerr << (unit/256) MOD 16 << (unit / 16) MOD 16 << unit MOD 16<< '\\n'; //std::cerr << (unit/256) MOD 16 << (unit / 16) MOD 16 << unit MOD 16<< '\\n';
//double t0 = time_time(); //double t0 = time_time();
//fprintf(stderr, "unit=%%x N= %%i %%i %%i S = %%i %%i %%i %%i %%i %%i\\n", unit,
//Nz1, Nz0, Nx1,
//sy_0, sy_1,
//sx_0, sx_1,
//sz_0, sz_1
//);
switch(unit) switch(unit)
{ {
case 0x000: dgemm_(&N, &N, &Nz1, &Nz0, &Nx1, &a, y, &sy_0, x, &sx_0, &b, z, &sz_0); break; case 0x000: dgemm_(&N, &N, &Nz1, &Nz0, &Nx1, &a, y, &sy_0, x, &sx_0, &b, z, &sz_0); break;
...@@ -540,7 +552,7 @@ class GemmRelated(Op): ...@@ -540,7 +552,7 @@ class GemmRelated(Op):
self.end_switch_typenum), '') self.end_switch_typenum), '')
def build_gemm_version(self): def build_gemm_version(self):
return (6,) return (7,)
class Gemm(GemmRelated): class Gemm(GemmRelated):
"""In-place version of matrix-matrix multiplication (with accumulation): """In-place version of matrix-matrix multiplication (with accumulation):
...@@ -818,13 +830,6 @@ def _beta_L_plus_alpha_M(beta, L, alpha, M, recurse_flip = True): ...@@ -818,13 +830,6 @@ def _beta_L_plus_alpha_M(beta, L, alpha, M, recurse_flip = True):
#print 'BETA L + ALPHA M', beta, L, alpha, M, recurse_flip #print 'BETA L + ALPHA M', beta, L, alpha, M, recurse_flip
#EXPRESSION: (beta * L) + (alpha * M) #EXPRESSION: (beta * L) + (alpha * M)
if M.type.broadcastable != L.type.broadcastable:
# GEMM cannot do the broadcasting that add used to be doing
# so abort.
return
assert L.type.dtype == M.type.dtype # because of local_dot_to_dot22
# we've already checked the client counts, now just make the type check. # we've already checked the client counts, now just make the type check.
####if res_is_a(M, _dot22, 1): ####if res_is_a(M, _dot22, 1):
if M.owner and M.owner.op == _dot22: if M.owner and M.owner.op == _dot22:
...@@ -1017,20 +1022,23 @@ def _gemm_from_factored_list(lst): ...@@ -1017,20 +1022,23 @@ def _gemm_from_factored_list(lst):
""" """
# Make every pair in list have matching dtypes # Make every pair in list have matching dtypes
lst = [(T.cast(si,Mi.type.dtype), Mi) for si,Mi in lst] def is_pair(sM):
try:
s, M = sM
return True
except:
return False
lst = [(T.cast(sM[0],sM[1].type.dtype), sM[1])
for sM in lst if is_pair(sM)]
# Try every pair in the sM_list, trying to turn it into a gemm operation # Try every pair in the sM_list, trying to turn it into a gemm operation
for i in xrange(len(lst) - 1): for i in xrange(len(lst) - 1):
try: s_i,M_i = lst[i]
s_i,M_i = lst[i]
except:
continue
for j in xrange(i+1, len(lst)): for j in xrange(i+1, len(lst)):
s_j, M_j = lst[j]
try: if M_i.type != M_j.type:
s_j, M_j = lst[j]
except:
continue continue
#print 'TRYING', (s_i, M_i, s_j, M_j) #print 'TRYING', (s_i, M_i, s_j, M_j)
...@@ -1281,24 +1289,32 @@ class Dot22Scalar(GemmRelated): ...@@ -1281,24 +1289,32 @@ class Dot22Scalar(GemmRelated):
Also used to generate a gemm later. Also used to generate a gemm later.
compute scalar*dot(x,y) compute scalar*dot(x,y)
""" """
def make_node(self, x, y, scalar): def make_node(self, x, y, a):
if not _is_real_matrix(x): if a.ndim != 0:
raise TypeError(x) raise TypeError(Gemm.E_scalar, a)
if not _is_real_matrix(x): if x.ndim != 2:
raise TypeError(y) raise TypeError(Gemm.E_rank, x)
if not _as_scalar(scalar): if y.ndim != 2:
raise TypeError(scalar) raise TypeError(Gemm.E_rank, y)
if y.type.dtype != x.type.dtype and y.type.dtype != scalar.type.dtype:
raise TypeError('dtype mismatch to Dot22Scalar') if not (a.dtype == x.dtype == y.dtype):
bz = [False, False] raise TypeError('Dot22Scalar requires matching dtypes',
(a.dtype, x.dtype, y.dtype))
if (not a.dtype.startswith('float')
and not a.dtype.startswith('complex')):
raise TypeError('Dot22Scalar requires float or complex args',
a.dtype)
bz = [x.type.broadcastable[0], y.type.broadcastable[1]]
outputs = [T.tensor(x.type.dtype, bz)] outputs = [T.tensor(x.type.dtype, bz)]
return Apply(self, [x,y,scalar], outputs) return Apply(self, [x,y,a], outputs)
def perform(self, node, inp, out): def perform(self, node, inp, out):
x, y, scalar = inp x, y, scalar = inp
z, = out z, = out
try: try:
z[0] = scalar * numpy.asarray(numpy.dot(x, y)) z[0] = numpy.asarray(scalar * numpy.dot(x, y))
except ValueError, e: except ValueError, e:
# The error raised by numpy has no shape information, we mean to add that # The error raised by numpy has no shape information, we mean to add that
e.args = e.args + (x.shape, y.shape) e.args = e.args + (x.shape, y.shape)
...@@ -1360,21 +1376,23 @@ def local_dot22_to_dot22scalar(node): ...@@ -1360,21 +1376,23 @@ def local_dot22_to_dot22scalar(node):
return False return False
i_dot22 = [x.owner and x.owner.op==_dot22 for x in node.inputs] i_dot22 = [x.owner and x.owner.op==_dot22 for x in node.inputs]
if not any(i_dot22): return False # no dot22 if not any(i_dot22): return False # no dot22
if i_dot22.count(True)>1: return False #TODO fix if i_dot22.count(True)>1:
#we take the first _dot22 found. TODO check others! #TODO: try each of them.
pass
#return False #TODO fix
dot22_idx = i_dot22.index(True) dot22_idx = i_dot22.index(True)
d = node.inputs[dot22_idx] d = node.inputs[dot22_idx]
i_scalar = [_as_scalar(x) for x in node.inputs] i_scalar = [_as_scalar(x) for x in node.inputs]
if not any(i_scalar) and not any([x.owner and x.owner.op ==T.mul for x in node.inputs]):
#no scalar in input and no multiplication
#if their was a multiplication we couls reorder the graph by the associativity of the graph.
return False
if not any(i_scalar): if not any(i_scalar):
i_mul = [x.owner and x.owner.op ==T.mul for x in node.inputs]
if not any(i_mul):
#no scalar in input and no multiplication
#if their was a multiplication we couls reorder the graph by the associativity of the graph.
return False
#maybe we can reorder the graph as this mul have a mul in input. #maybe we can reorder the graph as this mul have a mul in input.
#The canonizer should have merged those mul together. #The canonizer should have merged those mul together.
#We support only 1 additional level of mul. #We support only 1 additional level of mul.
i_mul = [x.owner and x.owner.op ==T.mul for x in node.inputs]
mul_idx = i_mul.index(True)#we take the first mul! mul_idx = i_mul.index(True)#we take the first mul!
m = node.inputs[mul_idx] m = node.inputs[mul_idx]
...@@ -1384,7 +1402,17 @@ def local_dot22_to_dot22scalar(node): ...@@ -1384,7 +1402,17 @@ def local_dot22_to_dot22scalar(node):
if _as_scalar(x): if _as_scalar(x):
scalar_idx=i scalar_idx=i
break break
dot=_dot22scalar(d.owner.inputs[0], d.owner.inputs[1],m.owner.inputs[scalar_idx])
a = T.cast(_as_scalar(m.owner.inputs[scalar_idx]), d.type.dtype)
assert not a.type.ndim
dot=_dot22scalar(d.owner.inputs[0], d.owner.inputs[1], a)
# What about the other inputs to the original node that were
# neither part of the dot22 or this mul?
# I'm asserting there are no such inputs here:
assert dot22_idx != mul_idx
assert all((i in (dot22_idx, mul_idx))
for i in range(len(node.inputs)))
return [T.mul(m.owner.inputs[1-i],dot)] return [T.mul(m.owner.inputs[1-i],dot)]
elif m.owner and m.owner.op == T.mul: elif m.owner and m.owner.op == T.mul:
...@@ -1397,7 +1425,9 @@ def local_dot22_to_dot22scalar(node): ...@@ -1397,7 +1425,9 @@ def local_dot22_to_dot22scalar(node):
scalar_idx = -1 scalar_idx = -1
for i,x in enumerate(node.inputs): for i,x in enumerate(node.inputs):
if i_scalar[i] and theano.scalar.upcast(x.type.dtype,d.type.dtype) == d.type.dtype: if (i_scalar[i] is not None
and (theano.scalar.upcast(x.type.dtype,d.type.dtype)
== d.type.dtype)):
scalar_idx = i scalar_idx = i
break break
if scalar_idx<0: if scalar_idx<0:
...@@ -1405,15 +1435,18 @@ def local_dot22_to_dot22scalar(node): ...@@ -1405,15 +1435,18 @@ def local_dot22_to_dot22scalar(node):
'of the scalar cannot be upcasted to the matrix type', 'of the scalar cannot be upcasted to the matrix type',
node.inputs, [x.type for x in node.inputs]) node.inputs, [x.type for x in node.inputs])
return False return False
assert scalar_idx<len(node.inputs) assert scalar_idx < len(node.inputs)
s = node.inputs[scalar_idx] s = node.inputs[scalar_idx]
o = copy.copy(node.inputs) o = copy.copy(node.inputs)
o.remove(d) o.remove(d)
o.remove(s) o.remove(s)
if len(o)==0:
return [_dot22scalar(d.owner.inputs[0], d.owner.inputs[1], s)] a = T.cast(i_scalar[scalar_idx], d.type.dtype)
assert not a.type.ndim
if len(o) == 0:
return [_dot22scalar(d.owner.inputs[0], d.owner.inputs[1], a)]
else: else:
return [T.mul(_dot22scalar(d.owner.inputs[0], d.owner.inputs[1], s), *o)] return [T.mul(_dot22scalar(d.owner.inputs[0], d.owner.inputs[1], a), *o)]
#must happen after gemm as the gemm optimizer don't understant dot22scalar and gemm give more speed up then dot22scalar #must happen after gemm as the gemm optimizer don't understant dot22scalar and gemm give more speed up then dot22scalar
blas_optdb.register('local_dot22_to_dot22scalar', blas_optdb.register('local_dot22_to_dot22scalar',
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论