提交 e2a13cf5 authored 作者: Frederic's avatar Frederic

correctly do the version comparison

上级 9d9dc0ce
...@@ -746,7 +746,7 @@ class GpuDnnPoolDesc(GpuOp): ...@@ -746,7 +746,7 @@ class GpuDnnPoolDesc(GpuOp):
self.stride = stride self.stride = stride
assert len(stride) == 2 assert len(stride) == 2
self.pad = pad self.pad = pad
if (pad[0] != 0 or pad[1] != 0) and version() < 20: if (pad[0] != 0 or pad[1] != 0) and version() == -1:
raise RuntimeError("CuDNN pooling with padding requires CuDNN v2") raise RuntimeError("CuDNN pooling with padding requires CuDNN v2")
def __setstate__(self, d): def __setstate__(self, d):
...@@ -755,7 +755,7 @@ class GpuDnnPoolDesc(GpuOp): ...@@ -755,7 +755,7 @@ class GpuDnnPoolDesc(GpuOp):
self.pad = (0, 0) self.pad = (0, 0)
def make_node(self): def make_node(self):
if self.pad != (0, 0) and version() < 20: if self.pad != (0, 0) and version() == -1:
raise RuntimeError("CuDNN pooling with padding requires CuDNN v2") raise RuntimeError("CuDNN pooling with padding requires CuDNN v2")
return Apply(self, [], return Apply(self, [],
......
...@@ -70,7 +70,7 @@ def test_pooling(): ...@@ -70,7 +70,7 @@ def test_pooling():
x = T.ftensor4() x = T.ftensor4()
for func, pad in product((T.max, T.mean), for func, pad in product((T.max, T.mean),
((0, 0), (1, 0), (1, 0), (2, 3), (3, 2))): ((0, 0), (1, 0), (1, 0), (2, 3), (3, 2))):
if pad != (0, 0) and cuda.dnn.version() < 20: if pad != (0, 0) and cuda.dnn.version() == -1:
continue continue
if pad != (0, 0) and func is T.mean: if pad != (0, 0) and func is T.mean:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论