提交 4d040292 authored 作者: James Bergstra's avatar James Bergstra

Added Gemv to tensor.blas

上级 ce7f9a78
...@@ -13,6 +13,7 @@ from theano.gof.python25 import any ...@@ -13,6 +13,7 @@ from theano.gof.python25 import any
import theano.scalar import theano.scalar
import basic as T import basic as T
from theano.tensor.tsor_apply import Apply from theano.tensor.tsor_apply import Apply
#NB: this clobbers the builtin 'compile' symbol #NB: this clobbers the builtin 'compile' symbol
...@@ -28,6 +29,74 @@ def warn(*msg): _logger.warn(' '.join(str(m) for m in msg)) ...@@ -28,6 +29,74 @@ def warn(*msg): _logger.warn(' '.join(str(m) for m in msg))
def warning(*msg): _logger.warning(' '.join(str(m) for m in msg)) def warning(*msg): _logger.warning(' '.join(str(m) for m in msg))
def error(*msg): _logger.error(' '.join(str(m) for m in msg)) def error(*msg): _logger.error(' '.join(str(m) for m in msg))
try:
import scipy.linalg.blas
_have_fblas = True
_blas_gemv_fns = {
numpy.dtype('float32'):scipy.linalg.blas.fblas.sgemv,
numpy.dtype('float64'):scipy.linalg.blas.fblas.dgemv,
numpy.dtype('complex64'):scipy.linalg.blas.fblas.cgemv,
numpy.dtype('complex128'):scipy.linalg.blas.fblas.zgemv,
}
except ImportError, e:
_have_fblas = False
warning('Failed to import scipy.linalg.blas.fblas. Falling back on slower implementations (%s)' % str(e))
class Gemv(Op):
"""
expression is beta * y + alpha * A x
A is matrix
x, y are vectors
alpha, beta are scalars
"""
def __init__(self, inplace):
self.inplace=inplace
if inplace:
self.destroy_map={0:[0]}
def __eq__(self, other):
return type(self)==type(other) and self.inplace == other.inplace
def __str__(self):
if self.inplace:
return 'Gemv{inplace}'
else:
return 'Gemv{no_inplace}'
def __hash__(self):
return hash(type(self)) ^ hash(self.inplace)
def make_node(self, y, alpha, A, x, beta):
y = T.as_tensor_variable(y)
x = T.as_tensor_variable(x)
A = T.as_tensor_variable(A)
alpha = T.as_tensor_variable(alpha)
beta = T.as_tensor_variable(beta)
if y.dtype != A.dtype or y.dtype != x.dtype:
raise TypeError('Gemv requires matching dtypes', (y.dtype, A.dtype, x.dtype))
if A.ndim != 2: raise TypeError('gemv requires matrix for A', A.type)
if x.ndim != 1: raise TypeError('gemv requires vector for x', x.type)
if y.ndim != 1: raise TypeError('gemv requires vector for y', y.type)
if y.broadcastable[0] != A.broadcastable[0]:
raise TypeError('broadcastable mismatch between y and A', (y.type, A.type))
# The following is not grounds for error
# because as long as sizes are 1 at time of perform() there is no problem
#if x.broadcastable[0] != A.broadcastable[1]:
#raise TypeError('broadcastable mismatch between x and A', (x.type, A.type))
return Apply(self, [y, alpha, A, x, beta], [y.type()])
def perform(self, node, inputs, out_storage):
y, alpha, A, x, beta = inputs
if _have_fblas:
if not self.inplace:
y = y.copy()
gemv = _blas_gemv_fns[y.dtype]
out_storage[0][0] = gemv(alpha, A, x, beta, y, overwrite_y=self.inplace)
else:
out_storage[0][0] = numpy.asarray(
beta * y + alpha * numpy.dot(A, x)
, dtype=y.dtype)
gemv_no_inplace = Gemv(inplace=False)
gemv_inplace = Gemv(inplace=True)
def default_blas_ldflags(): def default_blas_ldflags():
try: try:
return ' '.join( return ' '.join(
...@@ -583,11 +652,43 @@ def _beta_L_plus_alpha_M(beta, L, alpha, M, recurse_flip = True): ...@@ -583,11 +652,43 @@ def _beta_L_plus_alpha_M(beta, L, alpha, M, recurse_flip = True):
# we've already checked the client counts, now just make the type check. # we've already checked the client counts, now just make the type check.
####if res_is_a(M, _dot22, 1): ####if res_is_a(M, _dot22, 1):
if M.owner and M.owner.op == _dot22: if M.owner and M.owner.op == _dot22:
if M.broadcastable == L.broadcastable:
Ml, Mr = M.owner.inputs Ml, Mr = M.owner.inputs
rval = [gemm_no_inplace(L, alpha, Ml, Mr, beta)] rval = [gemm_no_inplace(L, alpha, Ml, Mr, beta)]
#print 'GEMM 0', rval, beta, L, alpha, M #print 'GEMM 0', rval, beta, L, alpha, M
return rval return rval
if M.owner and M.owner.op == T.dot\
and L.broadcastable==(False,) \
and M.broadcastable==(False,):
Ml, Mr = M.owner.inputs
rval = None
if Ml.ndim == 1:
if Mr.ndim == 1:
#TODO: insert a BLAS ddot Op
pass
if Mr.ndim == 2:
print "RETURNING GEMV (case 2)"
if Mr.dtype == Ml.dtype:
rval = [gemv_no_inplace(L, alpha, Mr.T, Ml, beta)]
assert L.type == rval[0].type, (L.type, rval[0].type)
else:
# TODO
pass
if Ml.ndim == 2:
if Mr.ndim == 1:
print "RETURNING GEMV (case 3)"
if Mr.dtype == Ml.dtype:
rval = [gemv_no_inplace(L, alpha, Ml, Mr, beta)]
assert L.type == rval[0].type, (L.type, rval[0].type)
else:
# TODO
pass
if Mr.ndim == 2:
# should have already got this case with a _dot22
pass
return rval
# this is False'd out because of inadequate testing. # this is False'd out because of inadequate testing.
# TODO see ticket #237 # TODO see ticket #237
if False and res_is_a(M, gemm_no_inplace, 1): if False and res_is_a(M, gemm_no_inplace, 1):
...@@ -620,7 +721,7 @@ def _beta_L_plus_alpha_M(beta, L, alpha, M, recurse_flip = True): ...@@ -620,7 +721,7 @@ def _beta_L_plus_alpha_M(beta, L, alpha, M, recurse_flip = True):
def _gemm_canonicalize(r, scale, rval, maxclients): def _gemm_canonicalize(r, scale, rval, maxclients):
# Tries to interpret node as a sum of scalars * matrices # Tries to interpret node as a sum of scalars * (vectors or matrices)
def scaled(thing): def scaled(thing):
if scale == 1: if scale == 1:
return thing return thing
...@@ -633,7 +734,7 @@ def _gemm_canonicalize(r, scale, rval, maxclients): ...@@ -633,7 +734,7 @@ def _gemm_canonicalize(r, scale, rval, maxclients):
except: except:
return None return None
if (tuple(r.type.broadcastable) != (False, False) or if ((r.type.ndim not in (1, 2)) or
r.type.dtype not in ('float32', 'float64', 'complex64', 'complex128')): r.type.dtype not in ('float32', 'float64', 'complex64', 'complex128')):
rval.append(scaled(r)) rval.append(scaled(r))
return rval return rval
...@@ -655,6 +756,7 @@ def _gemm_canonicalize(r, scale, rval, maxclients): ...@@ -655,6 +756,7 @@ def _gemm_canonicalize(r, scale, rval, maxclients):
elif r.owner and r.owner.op == T.mul: elif r.owner and r.owner.op == T.mul:
scalars = [] scalars = []
vectors = []
matrices = [] matrices = []
for i in r.owner.inputs: for i in r.owner.inputs:
if numpy.all(i.type.broadcastable): if numpy.all(i.type.broadcastable):
...@@ -664,6 +766,8 @@ def _gemm_canonicalize(r, scale, rval, maxclients): ...@@ -664,6 +766,8 @@ def _gemm_canonicalize(r, scale, rval, maxclients):
scalars.append(i.dimshuffle()) scalars.append(i.dimshuffle())
else: else:
scalars.append(i) scalars.append(i)
elif _is_real_vector(i):
vectors.append(i)
elif _is_real_matrix(i): elif _is_real_matrix(i):
matrices.append(i) matrices.append(i)
else: else:
...@@ -671,6 +775,7 @@ def _gemm_canonicalize(r, scale, rval, maxclients): ...@@ -671,6 +775,7 @@ def _gemm_canonicalize(r, scale, rval, maxclients):
rval.append((scale,r)) rval.append((scale,r))
return rval return rval
if len(matrices)==1: if len(matrices)==1:
assert len(vectors)==0
m = matrices[0] m = matrices[0]
if len(scalars) == 0: if len(scalars) == 0:
_gemm_canonicalize(m, scale, rval, 1) _gemm_canonicalize(m, scale, rval, 1)
...@@ -678,7 +783,16 @@ def _gemm_canonicalize(r, scale, rval, maxclients): ...@@ -678,7 +783,16 @@ def _gemm_canonicalize(r, scale, rval, maxclients):
_gemm_canonicalize(m, scaled(scalars[0]), rval, 1) _gemm_canonicalize(m, scaled(scalars[0]), rval, 1)
else: else:
_gemm_canonicalize(m, T.mul(scaled(scalars[0]), *scalars[1:]), rval, 1) _gemm_canonicalize(m, T.mul(scaled(scalars[0]), *scalars[1:]), rval, 1)
else: #there are many matrices... lets not open this up elif len(vectors)==1:
assert len(matrices)==0
v = vectors[0]
if len(scalars) == 0:
_gemm_canonicalize(v, scale, rval, 1)
elif len(scalars) == 1:
_gemm_canonicalize(v, scaled(scalars[0]), rval, 1)
else:
_gemm_canonicalize(v, T.mul(scaled(scalars[0]), *scalars[1:]), rval, 1)
else: #lets not open this up
rval.append((scale,r)) rval.append((scale,r))
else: else:
rval.append((scale,r)) rval.append((scale,r))
...@@ -739,8 +853,8 @@ def _gemm_from_factored_list(lst): ...@@ -739,8 +853,8 @@ def _gemm_from_factored_list(lst):
#print 'TRYING', (s_i, M_i, s_j, M_j) #print 'TRYING', (s_i, M_i, s_j, M_j)
gemm_of_sM_list = _beta_L_plus_alpha_M(s_i, M_i, s_j, M_j) gemm_of_sM_list = _beta_L_plus_alpha_M(s_i, M_i, s_j, M_j)
if gemm_of_sM_list:
#print 'GOT IT', gemm_of_sM_list #print 'GOT IT', gemm_of_sM_list
if gemm_of_sM_list:
def item_to_var(t): def item_to_var(t):
try: s,M = t try: s,M = t
except: return t except: return t
...@@ -753,9 +867,11 @@ def _gemm_from_factored_list(lst): ...@@ -753,9 +867,11 @@ def _gemm_from_factored_list(lst):
for k, input in enumerate(lst) if k not in (i,j)] for k, input in enumerate(lst) if k not in (i,j)]
add_inputs.extend(gemm_of_sM_list) add_inputs.extend(gemm_of_sM_list)
if len(add_inputs) > 1: if len(add_inputs) > 1:
return [T.add(*add_inputs)] rval = [T.add(*add_inputs)]
else: else:
return add_inputs rval = add_inputs
#print "RETURNING GEMM THIGN", rval
return rval
def _gemm_from_node2(node): def _gemm_from_node2(node):
""" """
...@@ -766,9 +882,12 @@ def _gemm_from_node2(node): ...@@ -766,9 +882,12 @@ def _gemm_from_node2(node):
""" """
lst = [] lst = []
_gemm_canonicalize(node.outputs[0], 1.0, lst, 0) _gemm_canonicalize(node.outputs[0], 1.0, lst, 0)
#print "GEMM CANON", lst
if len(lst) > 1: if len(lst) > 1:
lst = _factor_canonicalized(lst) lst = _factor_canonicalized(lst)
rval = _gemm_from_factored_list(lst) rval = _gemm_from_factored_list(lst)
if rval:
assert rval[0].type == node.outputs[0].type, (rval[0].type, node.outputs[0].type)
return rval return rval
class GemmOptimizer(Optimizer): class GemmOptimizer(Optimizer):
...@@ -885,6 +1004,7 @@ def local_dot_to_dot22(node): ...@@ -885,6 +1004,7 @@ def local_dot_to_dot22(node):
if y.type.dtype.startswith('float'): if y.type.dtype.startswith('float'):
if _is_real_matrix(x) and _is_real_matrix(y): if _is_real_matrix(x) and _is_real_matrix(y):
return [_dot22(*node.inputs)] return [_dot22(*node.inputs)]
if 0:
if _is_real_matrix(x) and _is_real_vector(y): if _is_real_matrix(x) and _is_real_vector(y):
return [_dot22(x, y.dimshuffle(0,'x')).dimshuffle(0)] return [_dot22(x, y.dimshuffle(0,'x')).dimshuffle(0)]
if _is_real_vector(x) and _is_real_matrix(y): if _is_real_vector(x) and _is_real_matrix(y):
...@@ -898,6 +1018,10 @@ def local_dot_to_dot22(node): ...@@ -898,6 +1018,10 @@ def local_dot_to_dot22(node):
def local_inplace_gemm(node): def local_inplace_gemm(node):
if node.op == gemm_no_inplace: if node.op == gemm_no_inplace:
return [gemm_inplace(*node.inputs)] return [gemm_inplace(*node.inputs)]
@local_optimizer([gemv_no_inplace])
def local_inplace_gemv(node):
if node.op == gemv_no_inplace:
return [gemv_inplace(*node.inputs)]
################################# #################################
# #
...@@ -921,7 +1045,7 @@ blas_optdb.register('local_dot_to_gemm', GemmOptimizer(), 10, 'fast_run') ...@@ -921,7 +1045,7 @@ blas_optdb.register('local_dot_to_gemm', GemmOptimizer(), 10, 'fast_run')
# Try to make gemm inplace # Try to make gemm inplace
# Also, need to make the gemm optimisation(step 70) happen before the fusion of elemwise(step 71) # Also, need to make the gemm optimisation(step 70) happen before the fusion of elemwise(step 71)
optdb.register('InplaceBlasOpt', optdb.register('InplaceBlasOpt',
EquilibriumOptimizer([local_inplace_gemm], failure_callback=EquilibriumOptimizer.warn_inplace, EquilibriumOptimizer([local_inplace_gemm, local_inplace_gemv], failure_callback=EquilibriumOptimizer.warn_inplace,
max_use_ratio=5), max_use_ratio=5),
70.0, 'fast_run', 'inplace') 70.0, 'fast_run', 'inplace')
...@@ -1063,8 +1187,9 @@ blas_optdb.register('local_dot22_to_dot22scalar', ...@@ -1063,8 +1187,9 @@ blas_optdb.register('local_dot22_to_dot22scalar',
11, 'fast_run') 11, 'fast_run')
#@opt.register_stabilize from opt import register_specialize
@gof.local_optimizer([]) #@register_specialize
@local_optimizer([])
def local_print_as_we_go_along(node): def local_print_as_we_go_along(node):
if node.op == tensor.true_div: if node.op in (T.sub, T.add):
debugprint(node) debugprint(node)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论