提交 d51c9900 authored 作者: Arnaud Bergeron's avatar Arnaud Bergeron

Fix the GPU tests for CAReduce.

上级 710c193a
......@@ -212,32 +212,40 @@ class test_GpuCAReduceCPY(test_elemwise.test_CAReduce):
def test_perform(self):
for dtype in self.dtypes + self.bin_dtypes:
for op in self.reds:
self.with_linker(gof.PerformLinker(), op, dtype=dtype,
pre_scalar_op=self.pre_scalar_op)
self.with_mode(Mode(linker='py',
optimizer=mode_with_gpu.optimizer),
op, dtype=dtype,
pre_scalar_op=self.pre_scalar_op)
def test_perform_nan(self):
for dtype in self.dtypes:
if not dtype.startswith('float'):
continue
for op in self.reds:
self.with_linker(gof.PerformLinker(), op, dtype=dtype,
test_nan=True,
pre_scalar_op=self.pre_scalar_op)
self.with_mode(Mode(linker='py',
optimizer=mode_with_gpu.optimizer)
op, dtype=dtype,
test_nan=True,
pre_scalar_op=self.pre_scalar_op)
def test_c(self):
for dtype in self.dtypes + self.bin_dtypes:
for op in self.reds:
self.with_linker(gof.CLinker(), op, dtype=dtype,
pre_scalar_op=self.pre_scalar_op)
self.with_mode(Mode(linker='c',
optimizer=mode_with_gpu.optimizer)
op, dtype=dtype,
pre_scalar_op=self.pre_scalar_op)
def test_c_nan(self):
for dtype in self.dtypes:
if not dtype.startswith('float'):
continue
for op in self.reds:
self.with_linker(gof.CLinker(), op, dtype=dtype,
test_nan=True,
pre_scalar_op=self.pre_scalar_op)
self.with_mode(Mode(linker='c',
optimizer=mode_with_gpu.optimizer)
op, dtype=dtype,
test_nan=True,
pre_scalar_op=self.pre_scalar_op)
def test_infer_shape(self):
for dtype in self.dtypes:
......@@ -334,6 +342,9 @@ class test_GpuCAReduceCuda(test_GpuCAReduceCPY):
scalar.maximum, scalar.minimum]
pre_scalar_op = None
def test_perform_noopt(self):
return
def test_perform(self):
return
......
......@@ -546,7 +546,8 @@ class test_CAReduce(unittest_tools.InferShapeTester):
@attr('slow')
def test_c(self):
if not theano.config.cxx:
raise SkipTest("G++ not available, so we need to skip this test.")
for dtype in ["bool", "floatX", "complex64", "complex128", "int8", "uint8"]:
self.with_mode(Mode(linker='c'), scalar.add, dtype=dtype)
self.with_mode(Mode(linker='c'), scalar.mul, dtype=dtype)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论