提交 75db05b6 authored 作者: Chiheb Trabelsi's avatar Chiheb Trabelsi

test_gemmcorr3d.py has been modified in order to respect the flake8 style.

上级 879b8716
from __future__ import absolute_import, print_function, division from __future__ import absolute_import, print_function, division
import unittest import unittest
import numpy import numpy
import copy
import theano import theano
from theano.tests import unittest_tools as utt from theano.tests import unittest_tools as utt
# Skip tests if cuda_ndarray is not available. # Skip tests if cuda_ndarray is not available.
from nose.plugins.skip import SkipTest from nose.plugins.skip import SkipTest
import theano.sandbox.cuda as cuda_ndarray
if not cuda_ndarray.cuda_available:
raise SkipTest('Optional package cuda not available')
from theano.sandbox.cuda import float32_shared_constructor as shared from theano.sandbox.cuda import float32_shared_constructor as shared
from theano.sandbox.cuda.blas import ( from theano.sandbox.cuda.blas import (
GpuCorr3dMM, GpuCorr3dMM_gradWeights, GpuCorr3dMM_gradInputs) GpuCorr3dMM, GpuCorr3dMM_gradWeights, GpuCorr3dMM_gradInputs)
from theano.sandbox.cuda.basic_ops import gpu_contiguous from theano.sandbox.cuda.basic_ops import gpu_contiguous
import theano.sandbox.cuda as cuda_ndarray
if not cuda_ndarray.cuda_available:
raise SkipTest('Optional package cuda not available')
if theano.config.mode == 'FAST_COMPILE': if theano.config.mode == 'FAST_COMPILE':
mode_with_gpu = theano.compile.mode.get_mode('FAST_RUN').including('gpu') mode_with_gpu = theano.compile.mode.get_mode('FAST_RUN').including('gpu')
...@@ -122,7 +121,9 @@ class TestCorr3DMM(unittest.TestCase): ...@@ -122,7 +121,9 @@ class TestCorr3DMM(unittest.TestCase):
inputs = shared(inputs_val) inputs = shared(inputs_val)
filters = shared(filters_val) filters = shared(filters_val)
bias = shared(numpy.zeros(filters_shape[4]).astype('float32')) bias = shared(numpy.zeros(filters_shape[4]).astype('float32'))
conv = theano.tensor.nnet.convTransp3D(W=filters, b=bias, d=subsample, conv = theano.tensor.nnet.convTransp3D(W=filters,
b=bias,
d=subsample,
H=inputs) H=inputs)
f_ref = theano.function([], conv) f_ref = theano.function([], conv)
res_ref = f_ref() res_ref = f_ref()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论