提交 9aaa972a authored 作者: Arnaud Bergeron's avatar Arnaud Bergeron

Add special-case optimization to move Allocs that are inputs to Join to the Gpu.

上级 0cf6ba43
...@@ -152,9 +152,27 @@ optdb['canonicalize'].register('local_cut_gpua_host_gpua', ...@@ -152,9 +152,27 @@ optdb['canonicalize'].register('local_cut_gpua_host_gpua',
local_cut_gpu_host_gpu, 'fast_run', 'gpuarray') local_cut_gpu_host_gpu, 'fast_run', 'gpuarray')
@register_opt()
@local_optimizer([tensor.Alloc])
def local_gpuaalloc2(node):
"""
Join(axis, Alloc, Alloc, ...) -> Join(axis, GpuAlloc, Alloc, ...)
Moves an alloc that is an input to join to the gpu.
"""
if (isinstance(node.op, tensor.Alloc) and
all(c != 'output' and
c.op == tensor.join and
all(i.owner and
i.owner.op in [host_from_gpu, tensor.alloc]
for i in c.inputs[1:])
for c, idx in node.outputs[0].clients)):
return [host_from_gpu(gpu_alloc(*node.inputs))]
@register_opt() @register_opt()
@op_lifter([tensor.Alloc]) @op_lifter([tensor.Alloc])
def local_gpualloc(node): def local_gpuaalloc(node):
new_out = gpu_alloc(*node.inputs) new_out = gpu_alloc(*node.inputs)
# We need to hide new broadcastable dimensions because # We need to hide new broadcastable dimensions because
# ReplaceValidate doesn't like when they change. # ReplaceValidate doesn't like when they change.
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论