提交 52dcfe33 authored 作者: Arnaud Bergeron's avatar Arnaud Bergeron

Fix tests for GpuCAReduce and a type error in some nan cases.

上级 618b2d94
...@@ -319,13 +319,14 @@ class GpuCAReduce(HideC, CAReduceDtype): ...@@ -319,13 +319,14 @@ class GpuCAReduce(HideC, CAReduceDtype):
acc_dtype = getattr(self, 'acc_dtype', None) acc_dtype = getattr(self, 'acc_dtype', None)
if acc_dtype is None: if acc_dtype is None:
acc_dtype = node.output[0].type.dtype acc_dtype = node.outputs[0].type.dtype
if any(redux): if any(redux):
if not hasattr(node, '_cache_reduction_k'): if not hasattr(node, '_cache_reduction_k'):
node._cache_reduction_k = self.generate_kernel(node, acc_dtype, node._cache_reduction_k = self.generate_kernel(node, acc_dtype,
redux) redux)
output[0] = node._cache_reduction_k(input) output[0] = node._cache_reduction_k(input).astype(copy=False,
dtype=node.outputs[0].type.dtype)
else: else:
output[0] = pygpu.gpuarray.array(input, copy=True, output[0] = pygpu.gpuarray.array(input, copy=True,
dtype=node.outputs[0].type.dtype) dtype=node.outputs[0].type.dtype)
...@@ -46,7 +46,6 @@ class test_GpuCAReduce(test_CAReduce): ...@@ -46,7 +46,6 @@ class test_GpuCAReduce(test_CAReduce):
self.with_linker(gof.PerformLinker(), op, dtype=dtype) self.with_linker(gof.PerformLinker(), op, dtype=dtype)
def test_perform_nan(self): def test_perform_nan(self):
raise SkipTest("for now")
for dtype in self.dtypes: for dtype in self.dtypes:
for op in self.reds: for op in self.reds:
self.with_linker(gof.PerformLinker(), op, dtype=dtype, self.with_linker(gof.PerformLinker(), op, dtype=dtype,
......
...@@ -11,7 +11,7 @@ from theano.gof.python25 import all, any ...@@ -11,7 +11,7 @@ from theano.gof.python25 import all, any
from theano import gof, scalar, config from theano import gof, scalar, config
from theano import tensor from theano import tensor
from theano.tensor import TensorType from theano.tensor import TensorType, as_tensor_variable
from theano.compile.mode import get_default_mode from theano.compile.mode import get_default_mode
from theano.tensor.elemwise import (CAReduce, Elemwise, DimShuffle, from theano.tensor.elemwise import (CAReduce, Elemwise, DimShuffle,
Prod, ProdWithoutZeros) Prod, ProdWithoutZeros)
...@@ -306,9 +306,9 @@ class test_CAReduce(unittest_tools.InferShapeTester): ...@@ -306,9 +306,9 @@ class test_CAReduce(unittest_tools.InferShapeTester):
dtype = theano.config.floatX dtype = theano.config.floatX
x = TensorType(dtype, [(entry == 1) for entry in xsh])('x') x = TensorType(dtype, [(entry == 1) for entry in xsh])('x')
if tensor_op is None: if tensor_op is None:
e = self.op(scalar_op, axis=tosum)(x) e = as_tensor_variable(self.op(scalar_op, axis=tosum)(x))
else: else:
e = tensor_op(x, axis=tosum) e = as_tensor_variable(tensor_op(x, axis=tosum))
if tosum is None: if tosum is None:
tosum = range(len(xsh)) tosum = range(len(xsh))
...@@ -413,7 +413,7 @@ class test_CAReduce(unittest_tools.InferShapeTester): ...@@ -413,7 +413,7 @@ class test_CAReduce(unittest_tools.InferShapeTester):
if isinstance(linker, gof.PerformLinker): if isinstance(linker, gof.PerformLinker):
x = TensorType(dtype, [(entry == 1) for entry in xsh])('x') x = TensorType(dtype, [(entry == 1) for entry in xsh])('x')
if tensor_op is None: if tensor_op is None:
e = CAReduce(scalar_op, axis=tosum)(x) e = self.op(scalar_op, axis=tosum)(x)
else: else:
e = tensor_op(x, axis=tosum) e = tensor_op(x, axis=tosum)
if tosum is None: if tosum is None:
...@@ -509,8 +509,8 @@ class test_CAReduce(unittest_tools.InferShapeTester): ...@@ -509,8 +509,8 @@ class test_CAReduce(unittest_tools.InferShapeTester):
tosum = range(len(xsh)) tosum = range(len(xsh))
xv = numpy.asarray(numpy.random.rand(*xsh), dtype=dtype) xv = numpy.asarray(numpy.random.rand(*xsh), dtype=dtype)
self._compile_and_check([x], self._compile_and_check([x],
[CAReduce(scalar.add, axis=tosum)(x)], [self.op(scalar.add, axis=tosum)(x)],
[xv], CAReduce, ["local_cut_useless_reduce"]) [xv], self.op, ["local_cut_useless_reduce"])
class test_Prod(unittest.TestCase): class test_Prod(unittest.TestCase):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论