提交 7cf174c8 authored 作者: notoraptor's avatar notoraptor

flake8

上级 c4c19aac
...@@ -104,7 +104,7 @@ class BaseTest: ...@@ -104,7 +104,7 @@ class BaseTest:
def compute_host(self, test_tensor, axis): def compute_host(self, test_tensor, axis):
M = self.get_host_tensor() M = self.get_host_tensor()
f = theano.function([M], [T.max(M, axis=axis), T.argmax(M, axis=axis)], f = theano.function([M], [T.max(M, axis=axis), T.argmax(M, axis=axis)],
name='HOST/shape:'+str(test_tensor.shape)+'/axis:'+str(axis), mode=mode_without_gpu) name='HOST/shape:' + str(test_tensor.shape) + '/axis:' + str(axis), mode=mode_without_gpu)
check_if_gpu_maxandargmax_not_in_graph(f) check_if_gpu_maxandargmax_not_in_graph(f)
f(test_tensor) f(test_tensor)
theano_max, theano_argmax = f(test_tensor) theano_max, theano_argmax = f(test_tensor)
...@@ -115,7 +115,7 @@ class BaseTest: ...@@ -115,7 +115,7 @@ class BaseTest:
def compute_gpu(self, test_gpu_tensor, test_host_tensor, axis): def compute_gpu(self, test_gpu_tensor, test_host_tensor, axis):
M = self.get_gpu_tensor() M = self.get_gpu_tensor()
f = theano.function([M], [T.max(M, axis=axis), T.argmax(M, axis=axis)], f = theano.function([M], [T.max(M, axis=axis), T.argmax(M, axis=axis)],
name='GPU/shape:'+str(test_gpu_tensor.shape)+'/axis:'+str(axis), mode=mode_with_gpu) name='GPU/shape:' + str(test_gpu_tensor.shape) + '/axis:' + str(axis), mode=mode_with_gpu)
check_if_gpu_maxandargmax_in_graph(f) check_if_gpu_maxandargmax_in_graph(f)
f(test_gpu_tensor) f(test_gpu_tensor)
theano_max, theano_argmax = f(test_gpu_tensor) theano_max, theano_argmax = f(test_gpu_tensor)
...@@ -176,10 +176,12 @@ class TestScalar(BaseTest, TestCase): ...@@ -176,10 +176,12 @@ class TestScalar(BaseTest, TestCase):
class TestVector(BaseTest, TestCase): class TestVector(BaseTest, TestCase):
tensor_size = 1 tensor_size = 1
# Special case # Special case
class TestRow(BaseTest, TestCase): class TestRow(BaseTest, TestCase):
tensor_size = 2 tensor_size = 2
shape = [1,test_size] shape = [1, test_size]
# Special case # Special case
class TestColumn(BaseTest, TestCase): class TestColumn(BaseTest, TestCase):
...@@ -193,4 +195,3 @@ class TestMatrix(BaseTest, TestCase): ...@@ -193,4 +195,3 @@ class TestMatrix(BaseTest, TestCase):
class TestTensor5(BaseTest, TestCase): class TestTensor5(BaseTest, TestCase):
tensor_size = 5 tensor_size = 5
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论