提交 9b4238a8 authored 作者: Arnaud Bergeron's avatar Arnaud Bergeron

Fix the tests for GpuMaxAndArgmax.

上级 aa636171
...@@ -83,7 +83,7 @@ int APPLY_SPECIFIC(dnn_redux)(PyGpuArrayObject *input, ...@@ -83,7 +83,7 @@ int APPLY_SPECIFIC(dnn_redux)(PyGpuArrayObject *input,
if (indices != NULL) { if (indices != NULL) {
if (theano_prep_output(indices, p, dims, GA_UINT, GA_C_ORDER, c) != 0) if (theano_prep_output(indices, p, dims, GA_UINT, GA_C_ORDER, c) != 0)
return 1; return 1;
indsize = PyGpuArray_SIZE(*indices); indsize = PyGpuArray_SIZE(*indices) * 4;
} }
if (p == input->ga.nd || rsz == 1) { if (p == input->ga.nd || rsz == 1) {
......
...@@ -10,6 +10,8 @@ from theano.tests.unittest_tools import SkipTest ...@@ -10,6 +10,8 @@ from theano.tests.unittest_tools import SkipTest
from .config import mode_with_gpu, mode_without_gpu from .config import mode_with_gpu, mode_without_gpu
from .test_basic_ops import rand_gpuarray from .test_basic_ops import rand_gpuarray
from .. import GpuArrayType from .. import GpuArrayType
from ..reduction import GpuMaxAndArgmax
from ..dnn import GpuDnnReduction
import math import math
...@@ -54,13 +56,13 @@ def numpy_maxandargmax(X, axis=None): ...@@ -54,13 +56,13 @@ def numpy_maxandargmax(X, axis=None):
def check_if_gpu_maxandargmax_in_graph(theano_function): def check_if_gpu_maxandargmax_in_graph(theano_function):
assert len([node for node in theano_function.maker.fgraph.apply_nodes assert any(isinstance(node.op, (GpuMaxAndArgmax, GpuDnnReduction))
if isinstance(node.op, theano.gpuarray.reduction.GpuMaxAndArgmax)]) > 0 for node in theano_function.maker.fgraph.apply_nodes)
def check_if_gpu_maxandargmax_not_in_graph(theano_function): def check_if_gpu_maxandargmax_not_in_graph(theano_function):
assert len([node for node in theano_function.maker.fgraph.apply_nodes assert all(not isinstance(node.op, (GpuMaxAndArgmax, GpuDnnReduction))
if isinstance(node.op, theano.gpuarray.reduction.GpuMaxAndArgmax)]) == 0 for node in theano_function.maker.fgraph.apply_nodes)
class BaseTest: class BaseTest:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论