提交 e95b4eba authored 作者: Frederic's avatar Frederic

Move print op to the new gpu back-end.

上级 81124da3
......@@ -288,6 +288,22 @@ def local_gpua_specifyShape(node):
return tensor.specify_shape(*inp)
def gpu_print_wrapper(op, cnda):
op.old_op.global_fn(op.old_op, numpy.asarray(cnda))
@register_opt()
@op_lifter([tensor.printing.Print])
def local_gpu_print_op(node):
x, = node.inputs
if x.owner and isinstance(x.owner.op, HostFromGpu):
gpu_x, = x.owner.inputs
new_op = node.op.__class__(global_fn=gpu_print_wrapper)
new_op.old_op = node.op
return new_op(gpu_x)
return False
@register_opt()
@op_lifter([tensor.Join])
def local_gpua_join(node):
......
......@@ -5,8 +5,9 @@ from theano import tensor
from theano.tests import unittest_tools as utt
import theano.sandbox.gpuarray
from theano.sandbox.gpuarray.type import GpuArrayType
from theano.sandbox.gpuarray.basic_ops import GpuAlloc, GpuReshape, gpu_alloc
from theano.sandbox.gpuarray.elemwise import GpuCAReduceCuda
from theano.sandbox.gpuarray.basic_ops import (
GpuAlloc, GpuReshape, gpu_alloc, gpu_from_host, host_from_gpu)
from theano.sandbox.gpuarray.elemwise import GpuCAReduceCuda, GpuElemwise
from theano.sandbox.gpuarray.tests.test_basic_ops import (
rand_gpuarray, mode_with_gpu, mode_without_gpu
)
......@@ -116,3 +117,19 @@ class TestSpecifyShape(TestSpecifyShape):
mode = mode_with_gpu
input_type = GpuArrayType
pass
def test_print_op():
""" Test that print ops don't block gpu optimization"""
b = tensor.fmatrix()
f = theano.function([b], theano.printing.Print()(b) * 2,
mode=mode_with_gpu)
theano.printing.debugprint(f)
#print f.maker.fgraph.toposort()
#[GpuFromHost(<TensorType(float32, matrix)>), <theano.printing.Print object at 0x3581210>(GpuFromHost.0), GpuElemwise{mul}(CudaNdarray{[[ 2.]]}, <theano.printing.Print object at 0x3581210>.0), HostFromGpu(GpuElemwise{mul}.0)]
topo = f.maker.fgraph.toposort()
assert topo[0].op == gpu_from_host
assert isinstance(topo[1].op, theano.printing.Print)
assert isinstance(topo[2].op, GpuElemwise)
assert topo[3].op == host_from_gpu
f(numpy.random.random((5, 5)).astype('float32'))
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论