提交 c5dc5576 authored 作者: Arnaud Bergeron's avatar Arnaud Bergeron

Add test for Rebroadcast on GpuArrayVariables

上级 332b691e
import numpy
import theano
from theano import tensor
from theano.tests import unittest_tools as utt
import theano.sandbox.gpuarray
from theano.sandbox.gpuarray.type import GpuArrayType
from theano.sandbox.gpuarray.basic_ops import GpuAlloc, GpuReshape, gpu_alloc
from theano.sandbox.gpuarray.elemwise import GpuCAReduceCuda
import theano.sandbox.gpuarray
from theano.sandbox.gpuarray.tests.test_basic_ops import (
rand_gpuarray, mode_with_gpu, mode_without_gpu
)
from theano.tests.unittest_tools import SkipTest
if theano.sandbox.gpuarray.pygpu is None:
raise SkipTest("pygpu not installed")
import theano.sandbox.cuda as cuda_ndarray
if cuda_ndarray.cuda_available and not theano.sandbox.gpuarray.pygpu_activated:
if not cuda_ndarray.use.device_number:
cuda_ndarray.use('gpu')
theano.sandbox.gpuarray.init_dev('cuda')
if not theano.sandbox.gpuarray.pygpu_activated:
raise SkipTest("pygpu disabled")
if theano.config.mode == 'FAST_COMPILE':
mode_with_gpu = theano.compile.mode.get_mode('FAST_RUN').including('gpuarray').excluding('gpu')
mode_without_gpu = theano.compile.mode.get_mode('FAST_RUN').excluding('gpuarray')
else:
mode_with_gpu = theano.compile.mode.get_default_mode().including('gpuarray').excluding('gpu')
mode_without_gpu = theano.compile.mode.get_default_mode().excluding('gpuarray')
def test_flatten():
m = theano.tensor.fmatrix()
f = theano.function([m], m.flatten(), mode=mode_with_gpu)
......@@ -104,3 +89,20 @@ def test_local_gpualloc_memset_0():
assert isinstance(topo[0].op, GpuAlloc)
assert not topo[0].op.memset_0
assert (numpy.asarray(f(2)) == 1).all()
def test_rebroadcast():
d = numpy.random.rand(10, 10).astype('float32')
v = theano.tensor.fmatrix()
up = tensor.unbroadcast(v.sum().dimshuffle('x', 'x'), 0, 1)
f = theano.function([v], [up], mode=mode_with_gpu)
f(d)
topo = f.maker.fgraph.toposort()
rebrs = [node for node in topo if isinstance(node.op, tensor.Rebroadcast)]
assert len(rebrs) == 1
rebr = rebrs[0]
assert isinstance(rebr.inputs[0].type, GpuArrayType)
assert isinstance(rebr.outputs[0].type, GpuArrayType)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论