提交 f8a19f5b authored 作者: Frederic's avatar Frederic

Make sure dot(matrix, matrix) get moved to the gpu in fast_compile

上级 8c7991b3
...@@ -1815,13 +1815,14 @@ def local_dot22_to_ger_or_gemv(node): ...@@ -1815,13 +1815,14 @@ def local_dot22_to_ger_or_gemv(node):
blas_optdb = SequenceDB() blas_optdb = SequenceDB()
# run after numerical stability optimizations (1.5) # run after numerical stability optimizations (1.5)
optdb.register('BlasOpt', blas_optdb, 1.7, 'fast_run') optdb.register('BlasOpt', blas_optdb, 1.7, 'fast_run', 'fast_compile')
# run before specialize (2.0) because specialize is basically a # run before specialize (2.0) because specialize is basically a
# free-for-all that makes the graph crazy. # free-for-all that makes the graph crazy.
#fast_compile is needed to have GpuDot22 created.
blas_optdb.register('local_dot_to_dot22', blas_optdb.register('local_dot_to_dot22',
in2out(local_dot_to_dot22), in2out(local_dot_to_dot22),
0, 'fast_run') 0, 'fast_run', 'fast_compile')
blas_optdb.register('gemm_optimizer', blas_optdb.register('gemm_optimizer',
GemmOptimizer(), GemmOptimizer(),
10, 'fast_run') 10, 'fast_run')
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论