提交 6b0fe3fb authored 作者: notoraptor's avatar notoraptor

Re-add test_local_gpu_elemwise_careduce with cudnn opts excluded.

上级 9d5789bd
......@@ -359,6 +359,29 @@ def test_pdbbreakpoint_op():
assert _check_stack_trace(f)
def test_local_gpu_elemwise_careduce():
mode_with_gpu_no_cudnn = mode_with_gpu.excluding('cudnn')
x = theano.tensor.matrix()
o = (x * x).sum()
f = theano.function([x], o, mode=mode_with_gpu_no_cudnn)
topo = f.maker.fgraph.toposort()
assert len(topo) == 3
assert isinstance(topo[1].op, GpuCAReduceCuda)
assert topo[1].op.pre_scalar_op == theano.scalar.sqr
assert _check_stack_trace(f)
data = np.random.rand(3, 4).astype(theano.config.floatX)
utt.assert_allclose(f(data), (data * data).sum())
o = (x * x).sum(axis=1)
f = theano.function([x], o, mode=mode_with_gpu_no_cudnn)
topo = f.maker.fgraph.toposort()
assert len(topo) == 3
assert isinstance(topo[1].op, GpuCAReduceCuda)
assert topo[1].op.pre_scalar_op == theano.scalar.sqr
assert _check_stack_trace(f)
utt.assert_allclose(f(data), (data * data).sum(axis=1))
def test_dnn_reduction_sum_squares():
if not dnn.dnn_available(test_ctx_name) or dnn.version(raises=False) < 6000:
raise SkipTest(dnn.dnn_available.msg)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论