提交 08762910 authored 作者: Reyhane Askari's avatar Reyhane Askari

minor changes

上级 f80fbcea
...@@ -265,10 +265,13 @@ def op_lifter(OP, cuda_only=False): ...@@ -265,10 +265,13 @@ def op_lifter(OP, cuda_only=False):
return x.transfer('cpu') return x.transfer('cpu')
# copy stack traces onto gpu outputs # copy stack traces onto gpu outputs
# also copy the stack traces onto HostFromGpu outputs # also copy the stack traces onto HostFromGpu outputs
on_cpu = []
for old_output, new_output in zip(node.outputs, new_outputs): for old_output, new_output in zip(node.outputs, new_outputs):
copy_stack_trace(old_output, new_output) copy_stack_trace(old_output, new_output)
copy_stack_trace(old_output, to_cpu_fn(new_output)) cpu = to_cpu_fn(new_output)
return new_outputs on_cpu.append(cpu)
copy_stack_trace(old_output, cpu)
return on_cpu
return False return False
local_opt.__name__ = maker.__name__ local_opt.__name__ = maker.__name__
return local_optimizer(OP)(local_opt) return local_optimizer(OP)(local_opt)
...@@ -1355,28 +1358,24 @@ def local_gpua_gemmbatch(op, context_name, inputs, outputs): ...@@ -1355,28 +1358,24 @@ def local_gpua_gemmbatch(op, context_name, inputs, outputs):
@register_opt() @register_opt()
@alpha_merge(GpuGemm, alpha_in=1, beta_in=4) @alpha_merge(GpuGemm, alpha_in=1, beta_in=4)
def local_gpua_gemm_alpha_merge(node, *inputs): def local_gpua_gemm_alpha_merge(node, *inputs):
with inherit_stack_trace(node.outputs):
return [gpugemm_no_inplace(*inputs)] return [gpugemm_no_inplace(*inputs)]
@register_opt() @register_opt()
@output_merge(GpuGemm, alpha_in=1, beta_in=4, out_in=0) @output_merge(GpuGemm, alpha_in=1, beta_in=4, out_in=0)
def local_gpua_gemm_output_merge(node, *inputs): def local_gpua_gemm_output_merge(node, *inputs):
with inherit_stack_trace(node.outputs):
return [gpugemm_no_inplace(*inputs)] return [gpugemm_no_inplace(*inputs)]
@register_opt() @register_opt()
@alpha_merge(GpuGemmBatch, alpha_in=1, beta_in=4) @alpha_merge(GpuGemmBatch, alpha_in=1, beta_in=4)
def local_gpua_gemmbatch_alpha_merge(node, *inputs): def local_gpua_gemmbatch_alpha_merge(node, *inputs):
with inherit_stack_trace(node.outputs):
return [gpugemmbatch_no_inplace(*inputs)] return [gpugemmbatch_no_inplace(*inputs)]
@register_opt() @register_opt()
@output_merge(GpuGemmBatch, alpha_in=1, beta_in=4, out_in=0) @output_merge(GpuGemmBatch, alpha_in=1, beta_in=4, out_in=0)
def local_gpua_gemmbatch_output_merge(node, *inputs): def local_gpua_gemmbatch_output_merge(node, *inputs):
with inherit_stack_trace(node.outputs):
return [gpugemmbatch_no_inplace(*inputs)] return [gpugemmbatch_no_inplace(*inputs)]
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论