提交 dfd1c614 authored 作者: James Bergstra's avatar James Bergstra 提交者: Frederic

added late dot22->gemv optimization

上级 32fa6f2e
...@@ -43,7 +43,7 @@ GEMV: Gemv ...@@ -43,7 +43,7 @@ GEMV: Gemv
---------- ----------
The BLAS GEMV operation implements Z <- a X Y + b Z, The BLAS GEMV operation implements Z <- a X Y + b Z,
where Z is a matrix, Y, and Z are vectors, and a and b are scalars. where X is a matrix, Y, and Z are vectors, and a and b are scalars.
Gemv implements the GEMV call in all its generality. Gemv implements the GEMV call in all its generality.
...@@ -1148,7 +1148,7 @@ def _gemm_from_factored_list(lst): ...@@ -1148,7 +1148,7 @@ def _gemm_from_factored_list(lst):
# Try every pair in the sM_list, trying to turn it into a gemm operation # Try every pair in the sM_list, trying to turn it into a gemm operation
for i in xrange(len(lst) - 1): for i in xrange(len(lst) - 1):
s_i,M_i = lst[i] s_i, M_i = lst[i]
for j in xrange(i+1, len(lst)): for j in xrange(i+1, len(lst)):
s_j, M_j = lst[j] s_j, M_j = lst[j]
...@@ -1414,20 +1414,36 @@ def local_gemm_to_ger(node): ...@@ -1414,20 +1414,36 @@ def local_gemm_to_ger(node):
#TODO: delete this optimization when we have the proper dot->gemm->ger pipeline #TODO: delete this optimization when we have the proper dot->gemm->ger pipeline
# working # working
@local_optimizer([_dot22]) @local_optimizer([_dot22])
def local_dot22_to_ger(node): def local_dot22_to_ger_or_gemv(node):
"""GEMM computing an outer-product -> GER """dot22 computing an outer-product -> GER
""" """
if node.op == _dot22: if node.op == _dot22:
x, y = node.inputs x, y = node.inputs
if x.broadcastable[1] and y.broadcastable[0]: xb = x.broadcastable
yb = y.broadcastable
one = T.as_tensor_variable(numpy.asarray(1, dtype=x.dtype))
zero = T.as_tensor_variable(numpy.asarray(0, dtype=x.dtype))
if xb[1] and yb[0]:
# x and y are both vectors so this might qualifies for a GER # x and y are both vectors so this might qualifies for a GER
xv = x.dimshuffle(0) xv = x.dimshuffle(0)
yv = y.dimshuffle(1) yv = y.dimshuffle(1)
one = T.as_tensor_variable(numpy.asarray(1, dtype=x.dtype))
zeros = T.alloc(numpy.asarray(0, dtype=x.dtype), x.shape[0], y.shape[1]) zeros = T.alloc(numpy.asarray(0, dtype=x.dtype), x.shape[0], y.shape[1])
rval = ger(zeros, one, xv, yv) rval = ger(zeros, one, xv, yv)
return [rval] return [rval]
if xb[0] and not yb[0] and not yb[1]:
# x is vector, y is matrix so try gemv
xv = x.dimshuffle(1)
zeros = T.alloc(numpy.asarray(0, dtype=x.dtype), y.shape[1])
rval = gemv_inplace(zeros, one, y.T, xv, one)
return [rval.dimshuffle('x', 0)]
if not xb[0] and not xb[1] and yb[1]:
# x is matrix, y is vector, try gemv
yv = y.dimshuffle(0)
zeros = T.alloc(numpy.asarray(0, dtype=x.dtype), x.shape[0])
rval = gemv_inplace(zeros, one, x, yv, one)
return [rval.dimshuffle(0, 'x')]
################################# #################################
# #
...@@ -1445,14 +1461,14 @@ optdb.register('BlasOpt', blas_optdb, 1.7, 'fast_run') ...@@ -1445,14 +1461,14 @@ optdb.register('BlasOpt', blas_optdb, 1.7, 'fast_run')
blas_optdb.register('local_dot_to_dot22', blas_optdb.register('local_dot_to_dot22',
EquilibriumOptimizer([local_dot_to_dot22], max_use_ratio=5), EquilibriumOptimizer([local_dot_to_dot22], max_use_ratio=5),
0, 'fast_run') 0, 'fast_run')
blas_optdb.register('local_dot_to_gemm', blas_optdb.register('gemm_optimizer',
GemmOptimizer(), GemmOptimizer(),
10, 'fast_run') 10, 'fast_run')
blas_optdb.register('local_gemm_to_gemv', blas_optdb.register('local_gemm_to_gemv',
EquilibriumOptimizer([ EquilibriumOptimizer([
local_gemm_to_gemv, local_gemm_to_gemv,
local_gemm_to_ger, local_gemm_to_ger,
local_dot22_to_ger, local_dot22_to_ger_or_gemv,
local_dimshuffle_lift], local_dimshuffle_lift],
max_use_ratio=5), max_use_ratio=5),
15, 'fast_run') 15, 'fast_run')
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论