提交 87a5ccc0 authored 作者: Frederic's avatar Frederic

Small fix when receiving a list of opt in in2out and out2in.

上级 15b005cf
...@@ -1198,9 +1198,10 @@ def local_inplace_ger(node): ...@@ -1198,9 +1198,10 @@ def local_inplace_ger(node):
# Also, need to make the gemm optimisation(step 70) happen before the fusion of # Also, need to make the gemm optimisation(step 70) happen before the fusion of
# elemwise(step 71) # elemwise(step 71)
optdb.register('InplaceGpuBlasOpt', optdb.register('InplaceGpuBlasOpt',
tensor.opt.in2out(gof.LocalOptGroup(local_inplace_gemm, tensor.opt.in2out(local_inplace_gemm,
local_inplace_gemv, local_inplace_gemv,
local_inplace_ger)), local_inplace_ger,
name="InplaceGpuBlasOpt"),
70.0, 'fast_run', 'inplace', 'gpu') 70.0, 'fast_run', 'inplace', 'gpu')
......
...@@ -52,7 +52,7 @@ def out2in(*local_opts, **kwargs): ...@@ -52,7 +52,7 @@ def out2in(*local_opts, **kwargs):
name = (kwargs and kwargs.pop('name', None)) name = (kwargs and kwargs.pop('name', None))
if len(local_opts) > 1: if len(local_opts) > 1:
# Don't wrap it uselessly if their is only 1 optimization. # Don't wrap it uselessly if their is only 1 optimization.
local_opts = opt.LocalOptGroup(*local_opts), local_opts = opt.LocalOptGroup(*local_opts)
else: else:
local_opts, = local_opts local_opts, = local_opts
if not name: if not name:
...@@ -71,7 +71,7 @@ def in2out(*local_opts, **kwargs): ...@@ -71,7 +71,7 @@ def in2out(*local_opts, **kwargs):
name = (kwargs and kwargs.pop('name', None)) name = (kwargs and kwargs.pop('name', None))
if len(local_opts) > 1: if len(local_opts) > 1:
# Don't wrap it uselessly if their is only 1 optimization. # Don't wrap it uselessly if their is only 1 optimization.
local_opts = opt.LocalOptGroup(*local_opts), local_opts = opt.LocalOptGroup(*local_opts)
else: else:
local_opts, = local_opts local_opts, = local_opts
if not name: if not name:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论