提交 cc8822d3 authored 作者: James Bergstra's avatar James Bergstra

several changes to tensor/blas.py

上级 dfdea75f
......@@ -387,6 +387,12 @@ class GemmRelated(Op):
(long int)Ny[1], (long int)Nz[1]);
%(fail)s;
}
if (Nx[1] == 0)
{
PyErr_Format(PyExc_ValueError,
"Undefined semantics: x has 0 cols");
%(fail)s;
}
"""
check_strides = """
......@@ -497,6 +503,12 @@ class GemmRelated(Op):
int Nz0 = Nz[0], Nz1 = Nz[1], Nx1 = Nx[1];
//std::cerr << (unit/256) MOD 16 << (unit / 16) MOD 16 << unit MOD 16<< '\\n';
//double t0 = time_time();
//fprintf(stderr, "unit=%%x N= %%i %%i %%i S = %%i %%i %%i %%i %%i %%i\\n", unit,
//Nz1, Nz0, Nx1,
//sy_0, sy_1,
//sx_0, sx_1,
//sz_0, sz_1
//);
switch(unit)
{
case 0x000: dgemm_(&N, &N, &Nz1, &Nz0, &Nx1, &a, y, &sy_0, x, &sx_0, &b, z, &sz_0); break;
......@@ -540,7 +552,7 @@ class GemmRelated(Op):
self.end_switch_typenum), '')
def build_gemm_version(self):
return (6,)
return (7,)
class Gemm(GemmRelated):
"""In-place version of matrix-matrix multiplication (with accumulation):
......@@ -818,13 +830,6 @@ def _beta_L_plus_alpha_M(beta, L, alpha, M, recurse_flip = True):
#print 'BETA L + ALPHA M', beta, L, alpha, M, recurse_flip
#EXPRESSION: (beta * L) + (alpha * M)
if M.type.broadcastable != L.type.broadcastable:
# GEMM cannot do the broadcasting that add used to be doing
# so abort.
return
assert L.type.dtype == M.type.dtype # because of local_dot_to_dot22
# we've already checked the client counts, now just make the type check.
####if res_is_a(M, _dot22, 1):
if M.owner and M.owner.op == _dot22:
......@@ -1017,20 +1022,23 @@ def _gemm_from_factored_list(lst):
"""
# Make every pair in list have matching dtypes
lst = [(T.cast(si,Mi.type.dtype), Mi) for si,Mi in lst]
def is_pair(sM):
try:
s, M = sM
return True
except:
return False
lst = [(T.cast(sM[0],sM[1].type.dtype), sM[1])
for sM in lst if is_pair(sM)]
# Try every pair in the sM_list, trying to turn it into a gemm operation
for i in xrange(len(lst) - 1):
try:
s_i,M_i = lst[i]
except:
continue
s_i,M_i = lst[i]
for j in xrange(i+1, len(lst)):
s_j, M_j = lst[j]
try:
s_j, M_j = lst[j]
except:
if M_i.type != M_j.type:
continue
#print 'TRYING', (s_i, M_i, s_j, M_j)
......@@ -1281,24 +1289,32 @@ class Dot22Scalar(GemmRelated):
Also used to generate a gemm later.
compute scalar*dot(x,y)
"""
def make_node(self, x, y, scalar):
if not _is_real_matrix(x):
raise TypeError(x)
if not _is_real_matrix(x):
raise TypeError(y)
if not _as_scalar(scalar):
raise TypeError(scalar)
if y.type.dtype != x.type.dtype and y.type.dtype != scalar.type.dtype:
raise TypeError('dtype mismatch to Dot22Scalar')
bz = [False, False]
def make_node(self, x, y, a):
if a.ndim != 0:
raise TypeError(Gemm.E_scalar, a)
if x.ndim != 2:
raise TypeError(Gemm.E_rank, x)
if y.ndim != 2:
raise TypeError(Gemm.E_rank, y)
if not (a.dtype == x.dtype == y.dtype):
raise TypeError('Dot22Scalar requires matching dtypes',
(a.dtype, x.dtype, y.dtype))
if (not a.dtype.startswith('float')
and not a.dtype.startswith('complex')):
raise TypeError('Dot22Scalar requires float or complex args',
a.dtype)
bz = [x.type.broadcastable[0], y.type.broadcastable[1]]
outputs = [T.tensor(x.type.dtype, bz)]
return Apply(self, [x,y,scalar], outputs)
return Apply(self, [x,y,a], outputs)
def perform(self, node, inp, out):
x, y, scalar = inp
z, = out
try:
z[0] = scalar * numpy.asarray(numpy.dot(x, y))
z[0] = numpy.asarray(scalar * numpy.dot(x, y))
except ValueError, e:
# The error raised by numpy has no shape information, we mean to add that
e.args = e.args + (x.shape, y.shape)
......@@ -1360,21 +1376,23 @@ def local_dot22_to_dot22scalar(node):
return False
i_dot22 = [x.owner and x.owner.op==_dot22 for x in node.inputs]
if not any(i_dot22): return False # no dot22
if i_dot22.count(True)>1: return False #TODO fix
#we take the first _dot22 found. TODO check others!
if i_dot22.count(True)>1:
#TODO: try each of them.
pass
#return False #TODO fix
dot22_idx = i_dot22.index(True)
d = node.inputs[dot22_idx]
i_scalar = [_as_scalar(x) for x in node.inputs]
if not any(i_scalar) and not any([x.owner and x.owner.op ==T.mul for x in node.inputs]):
#no scalar in input and no multiplication
#if their was a multiplication we couls reorder the graph by the associativity of the graph.
return False
if not any(i_scalar):
i_mul = [x.owner and x.owner.op ==T.mul for x in node.inputs]
if not any(i_mul):
#no scalar in input and no multiplication
#if their was a multiplication we couls reorder the graph by the associativity of the graph.
return False
#maybe we can reorder the graph as this mul have a mul in input.
#The canonizer should have merged those mul together.
#We support only 1 additional level of mul.
i_mul = [x.owner and x.owner.op ==T.mul for x in node.inputs]
mul_idx = i_mul.index(True)#we take the first mul!
m = node.inputs[mul_idx]
......@@ -1384,7 +1402,17 @@ def local_dot22_to_dot22scalar(node):
if _as_scalar(x):
scalar_idx=i
break
dot=_dot22scalar(d.owner.inputs[0], d.owner.inputs[1],m.owner.inputs[scalar_idx])
a = T.cast(_as_scalar(m.owner.inputs[scalar_idx]), d.type.dtype)
assert not a.type.ndim
dot=_dot22scalar(d.owner.inputs[0], d.owner.inputs[1], a)
# What about the other inputs to the original node that were
# neither part of the dot22 or this mul?
# I'm asserting there are no such inputs here:
assert dot22_idx != mul_idx
assert all((i in (dot22_idx, mul_idx))
for i in range(len(node.inputs)))
return [T.mul(m.owner.inputs[1-i],dot)]
elif m.owner and m.owner.op == T.mul:
......@@ -1397,7 +1425,9 @@ def local_dot22_to_dot22scalar(node):
scalar_idx = -1
for i,x in enumerate(node.inputs):
if i_scalar[i] and theano.scalar.upcast(x.type.dtype,d.type.dtype) == d.type.dtype:
if (i_scalar[i] is not None
and (theano.scalar.upcast(x.type.dtype,d.type.dtype)
== d.type.dtype)):
scalar_idx = i
break
if scalar_idx<0:
......@@ -1405,15 +1435,18 @@ def local_dot22_to_dot22scalar(node):
'of the scalar cannot be upcasted to the matrix type',
node.inputs, [x.type for x in node.inputs])
return False
assert scalar_idx<len(node.inputs)
assert scalar_idx < len(node.inputs)
s = node.inputs[scalar_idx]
o = copy.copy(node.inputs)
o.remove(d)
o.remove(s)
if len(o)==0:
return [_dot22scalar(d.owner.inputs[0], d.owner.inputs[1], s)]
a = T.cast(i_scalar[scalar_idx], d.type.dtype)
assert not a.type.ndim
if len(o) == 0:
return [_dot22scalar(d.owner.inputs[0], d.owner.inputs[1], a)]
else:
return [T.mul(_dot22scalar(d.owner.inputs[0], d.owner.inputs[1], s), *o)]
return [T.mul(_dot22scalar(d.owner.inputs[0], d.owner.inputs[1], a), *o)]
#must happen after gemm as the gemm optimizer don't understant dot22scalar and gemm give more speed up then dot22scalar
blas_optdb.register('local_dot22_to_dot22scalar',
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论