提交 dfd1c614 authored 作者: James Bergstra's avatar James Bergstra 提交者: Frederic

added late dot22->gemv optimization

上级 32fa6f2e
......@@ -43,7 +43,7 @@ GEMV: Gemv
----------
The BLAS GEMV operation implements Z <- a X Y + b Z,
where Z is a matrix, Y, and Z are vectors, and a and b are scalars.
where X is a matrix, Y and Z are vectors, and a and b are scalars.
Gemv implements the GEMV call in all its generality.
......@@ -1148,7 +1148,7 @@ def _gemm_from_factored_list(lst):
# Try every pair in the sM_list, trying to turn it into a gemm operation
for i in xrange(len(lst) - 1):
s_i,M_i = lst[i]
s_i, M_i = lst[i]
for j in xrange(i+1, len(lst)):
s_j, M_j = lst[j]
......@@ -1414,20 +1414,36 @@ def local_gemm_to_ger(node):
#TODO: delete this optimization when we have the proper dot->gemm->ger pipeline
# working
@local_optimizer([_dot22])
def local_dot22_to_ger(node):
"""GEMM computing an outer-product -> GER
def local_dot22_to_ger_or_gemv(node):
"""dot22 computing an outer-product -> GER; dot22 computing a matrix-vector product -> GEMV
"""
if node.op == _dot22:
x, y = node.inputs
if x.broadcastable[1] and y.broadcastable[0]:
xb = x.broadcastable
yb = y.broadcastable
one = T.as_tensor_variable(numpy.asarray(1, dtype=x.dtype))
zero = T.as_tensor_variable(numpy.asarray(0, dtype=x.dtype))
if xb[1] and yb[0]:
# x and y are both vectors so this might qualify for a GER
xv = x.dimshuffle(0)
yv = y.dimshuffle(1)
one = T.as_tensor_variable(numpy.asarray(1, dtype=x.dtype))
zeros = T.alloc(numpy.asarray(0, dtype=x.dtype), x.shape[0], y.shape[1])
rval = ger(zeros, one, xv, yv)
return [rval]
if xb[0] and not yb[0] and not yb[1]:
# x is vector, y is matrix so try gemv
xv = x.dimshuffle(1)
zeros = T.alloc(numpy.asarray(0, dtype=x.dtype), y.shape[1])
rval = gemv_inplace(zeros, one, y.T, xv, one)
return [rval.dimshuffle('x', 0)]
if not xb[0] and not xb[1] and yb[1]:
# x is matrix, y is vector, try gemv
yv = y.dimshuffle(0)
zeros = T.alloc(numpy.asarray(0, dtype=x.dtype), x.shape[0])
rval = gemv_inplace(zeros, one, x, yv, one)
return [rval.dimshuffle(0, 'x')]
#################################
#
......@@ -1445,14 +1461,14 @@ optdb.register('BlasOpt', blas_optdb, 1.7, 'fast_run')
blas_optdb.register('local_dot_to_dot22',
EquilibriumOptimizer([local_dot_to_dot22], max_use_ratio=5),
0, 'fast_run')
blas_optdb.register('local_dot_to_gemm',
blas_optdb.register('gemm_optimizer',
GemmOptimizer(),
10, 'fast_run')
blas_optdb.register('local_gemm_to_gemv',
EquilibriumOptimizer([
local_gemm_to_gemv,
local_gemm_to_ger,
local_dot22_to_ger,
local_dot22_to_ger_or_gemv,
local_dimshuffle_lift],
max_use_ratio=5),
15, 'fast_run')
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论