提交 1e831dbe authored 作者: Arnaud Bergeron's avatar Arnaud Bergeron

Fix the Print lifter.

上级 5f75ecdc
......@@ -451,7 +451,7 @@ def gpu_print_wrapper(op, cnda):
@op_lifter([tensor.printing.Print])
def local_gpu_print_op(node, context_name):
x, = node.inputs
gpu_x, = x.owner.inputs
gpu_x = as_gpuarray_variable(x, context_name=context_name)
new_op = node.op.__class__(global_fn=gpu_print_wrapper)
new_op.old_op = node.op
return new_op(gpu_x)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论