提交 dd31b668 authored 作者: James Bergstra's avatar James Bergstra

added dot22->ger specialization

上级 73f79883
...@@ -1319,6 +1319,10 @@ def local_inplace_gemv(node): ...@@ -1319,6 +1319,10 @@ def local_inplace_gemv(node):
if node.op == gemv_no_inplace: if node.op == gemv_no_inplace:
return [gemv_inplace(*node.inputs)] return [gemv_inplace(*node.inputs)]
@local_optimizer([ger])
def local_inplace_ger(node):
if node.op == ger:
return [ger_destructive(*node.inputs)]
@local_optimizer([gemm_no_inplace]) @local_optimizer([gemm_no_inplace])
def local_gemm_to_gemv(node): def local_gemm_to_gemv(node):
...@@ -1363,6 +1367,25 @@ def local_gemm_to_ger(node): ...@@ -1363,6 +1367,25 @@ def local_gemm_to_ger(node):
# pre-scaled and GER isn't really the right tool for the job. # pre-scaled and GER isn't really the right tool for the job.
return return
#TODO: delete this optimization when we have the proper dot->gemm->ger pipeline
# working
@local_optimizer([_dot22])
def local_dot22_to_ger(node):
"""GEMM computing an outer-product -> GER
"""
if node.op == _dot22:
x, y = node.inputs
if x.broadcastable[1] and y.broadcastable[0]:
# x and y are both vectors so this might qualifies for a GER
xv = x.dimshuffle(0)
yv = y.dimshuffle(1)
one = T.as_tensor_variable(numpy.asarray(1, dtype=x.dtype))
zeros = T.alloc(numpy.asarray(0, dtype=x.dtype), x.shape[0], y.shape[1])
rval = Ger(destructive=False)(zeros, one, xv, yv)
return [rval]
################################# #################################
# #
# Set up the BlasOpt optimizer # Set up the BlasOpt optimizer
...@@ -1383,7 +1406,10 @@ blas_optdb.register('local_dot_to_gemm', ...@@ -1383,7 +1406,10 @@ blas_optdb.register('local_dot_to_gemm',
GemmOptimizer(), GemmOptimizer(),
10, 'fast_run') 10, 'fast_run')
blas_optdb.register('local_gemm_to_gemv', blas_optdb.register('local_gemm_to_gemv',
EquilibriumOptimizer([local_gemm_to_gemv, local_gemm_to_ger, EquilibriumOptimizer([
local_gemm_to_gemv,
local_gemm_to_ger,
local_dot22_to_ger,
local_dimshuffle_lift], local_dimshuffle_lift],
max_use_ratio=5), max_use_ratio=5),
15, 'fast_run') 15, 'fast_run')
...@@ -1393,7 +1419,9 @@ blas_optdb.register('local_gemm_to_gemv', ...@@ -1393,7 +1419,9 @@ blas_optdb.register('local_gemm_to_gemv',
# Try to make gemm inplace # Try to make gemm inplace
# Also, need to make the gemm optimisation(step 70) happen before the fusion of elemwise(step 71) # Also, need to make the gemm optimisation(step 70) happen before the fusion of elemwise(step 71)
optdb.register('InplaceBlasOpt', optdb.register('InplaceBlasOpt',
EquilibriumOptimizer([local_inplace_gemm, local_inplace_gemv], failure_callback=EquilibriumOptimizer.warn_inplace, EquilibriumOptimizer(
[local_inplace_gemm, local_inplace_gemv, local_inplace_ger],
failure_callback=EquilibriumOptimizer.warn_inplace,
max_use_ratio=5), max_use_ratio=5),
70.0, 'fast_run', 'inplace') 70.0, 'fast_run', 'inplace')
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论