提交 3fb13c75 authored 作者: erakra's avatar erakra

fixing flake8 for test_blas.py

上级 184774c1
...@@ -208,7 +208,7 @@ class t_gemm(TestCase): ...@@ -208,7 +208,7 @@ class t_gemm(TestCase):
assert f[0].op == gemm_inplace assert f[0].op == gemm_inplace
def test_destroy_map0(self): def test_destroy_map0(self):
"""test that only first input can be overwritten""" # test that only first input can be overwritten.
Z = as_tensor_variable(self.rand(2, 2)) Z = as_tensor_variable(self.rand(2, 2))
try: try:
gemm_inplace(Z, 1.0, Z, Z, 1.0) gemm_inplace(Z, 1.0, Z, Z, 1.0)
...@@ -218,7 +218,7 @@ class t_gemm(TestCase): ...@@ -218,7 +218,7 @@ class t_gemm(TestCase):
self.fail() self.fail()
def test_destroy_map1(self): def test_destroy_map1(self):
"""test that only first input can be overwritten""" # test that only first input can be overwritten.
Z = as_tensor_variable(self.rand(2, 2)) Z = as_tensor_variable(self.rand(2, 2))
A = as_tensor_variable(self.rand(2, 2)) A = as_tensor_variable(self.rand(2, 2))
try: try:
...@@ -229,7 +229,7 @@ class t_gemm(TestCase): ...@@ -229,7 +229,7 @@ class t_gemm(TestCase):
self.fail() self.fail()
def test_destroy_map2(self): def test_destroy_map2(self):
"""test that only first input can be overwritten""" # test that only first input can be overwritten.
Z = as_tensor_variable(self.rand(2, 2)) Z = as_tensor_variable(self.rand(2, 2))
A = as_tensor_variable(self.rand(2, 2)) A = as_tensor_variable(self.rand(2, 2))
try: try:
...@@ -240,7 +240,7 @@ class t_gemm(TestCase): ...@@ -240,7 +240,7 @@ class t_gemm(TestCase):
self.fail() self.fail()
def test_destroy_map3(self): def test_destroy_map3(self):
"""test that only first input can be overwritten""" # test that only first input can be overwritten
Z = as_tensor_variable(self.rand(2, 2)) Z = as_tensor_variable(self.rand(2, 2))
A = as_tensor_variable(self.rand(2, 2)) A = as_tensor_variable(self.rand(2, 2))
try: try:
...@@ -251,7 +251,7 @@ class t_gemm(TestCase): ...@@ -251,7 +251,7 @@ class t_gemm(TestCase):
self.fail() self.fail()
def test_destroy_map4(self): def test_destroy_map4(self):
"""test that dot args can be aliased""" # test that dot args can be aliased
Z = shared(self.rand(2, 2), name='Z') Z = shared(self.rand(2, 2), name='Z')
A = shared(self.rand(2, 2), name='A') A = shared(self.rand(2, 2), name='A')
one = T.constant(1.0).astype(Z.dtype) one = T.constant(1.0).astype(Z.dtype)
...@@ -386,7 +386,7 @@ def test_res_is_a(): ...@@ -386,7 +386,7 @@ def test_res_is_a():
class t_as_scalar(TestCase): class t_as_scalar(TestCase):
def test0(self): def test0(self):
"""Test that it works on scalar constants""" # Test that it works on scalar constants
a = T.constant(2.5) a = T.constant(2.5)
b = T.constant(np.asarray([[[0.5]]])) b = T.constant(np.asarray([[[0.5]]]))
b2 = b.dimshuffle() b2 = b.dimshuffle()
...@@ -402,13 +402,13 @@ class t_as_scalar(TestCase): ...@@ -402,13 +402,13 @@ class t_as_scalar(TestCase):
self.assertTrue(_as_scalar(d_a2) != d_a2) self.assertTrue(_as_scalar(d_a2) != d_a2)
def test1(self): def test1(self):
"""Test that it fails on nonscalar constants""" # Test that it fails on nonscalar constants
a = T.constant(np.ones(5)) a = T.constant(np.ones(5))
self.assertTrue(_as_scalar(a) is None) self.assertTrue(_as_scalar(a) is None)
self.assertTrue(_as_scalar(T.DimShuffle([False], [0, 'x'])(a)) is None) self.assertTrue(_as_scalar(T.DimShuffle([False], [0, 'x'])(a)) is None)
def test2(self): def test2(self):
"""Test that it works on scalar variables""" # Test that it works on scalar variables
a = T.dscalar() a = T.dscalar()
d_a = T.DimShuffle([], [])(a) d_a = T.DimShuffle([], [])(a)
d_a2 = T.DimShuffle([], ['x', 'x'])(a) d_a2 = T.DimShuffle([], ['x', 'x'])(a)
...@@ -418,7 +418,7 @@ class t_as_scalar(TestCase): ...@@ -418,7 +418,7 @@ class t_as_scalar(TestCase):
self.assertTrue(_as_scalar(d_a2) is a) self.assertTrue(_as_scalar(d_a2) is a)
def test3(self): def test3(self):
"""Test that it fails on nonscalar variables""" # Test that it fails on nonscalar variables
a = T.matrix() a = T.matrix()
self.assertTrue(_as_scalar(a) is None) self.assertTrue(_as_scalar(a) is None)
self.assertTrue(_as_scalar(T.DimShuffle([False, False], self.assertTrue(_as_scalar(T.DimShuffle([False, False],
...@@ -502,7 +502,7 @@ def just_gemm(i, o, ishapes=[(4, 3), (3, 5), (4, 5), (), ()], ...@@ -502,7 +502,7 @@ def just_gemm(i, o, ishapes=[(4, 3), (3, 5), (4, 5), (), ()],
def test_gemm_opt0(): def test_gemm_opt0():
"""Many subgraphs whose dots can be eliminated""" # Many subgraphs whose dots can be eliminated
X, Y, Z, a, b = XYZab() X, Y, Z, a, b = XYZab()
just_gemm([X, Y, Z, a, b], [T.dot(X, Y) * a + Z * b]) just_gemm([X, Y, Z, a, b], [T.dot(X, Y) * a + Z * b])
...@@ -530,7 +530,7 @@ def test_gemm_opt0(): ...@@ -530,7 +530,7 @@ def test_gemm_opt0():
def test_gemm_opt_double_gemm(): def test_gemm_opt_double_gemm():
"""This is the pattern that shows up in the autoencoder""" # This is the pattern that shows up in the autoencoder
X, Y, Z, a, b = T.matrix(), T.matrix(), T.matrix(), T.scalar(), T.scalar() X, Y, Z, a, b = T.matrix(), T.matrix(), T.matrix(), T.scalar(), T.scalar()
R, S, c = T.matrix(), T.matrix(), T.scalar() R, S, c = T.matrix(), T.matrix(), T.scalar()
...@@ -705,11 +705,10 @@ def test_gemm_opt_wishlist(): ...@@ -705,11 +705,10 @@ def test_gemm_opt_wishlist():
def test_gemm_with_vector(): def test_gemm_with_vector():
"""Many subgraphs whose dots can be eliminated. This adds a # Many subgraphs whose dots can be eliminated. This adds a
vector two the previous test, which triggers the long-sought GEMM # vector two the previous test, which triggers the long-sought GEMM
bug. # bug.
"""
X, Y, Z, a, b = XYZab() X, Y, Z, a, b = XYZab()
v = T.vector() v = T.vector()
...@@ -750,13 +749,12 @@ def test_gemm_opt_vector_stuff(): ...@@ -750,13 +749,12 @@ def test_gemm_opt_vector_stuff():
def test_gemm_unrolled(): def test_gemm_unrolled():
"""This test that the gemm optimizer remove the dot22 that was # This test that the gemm optimizer remove the dot22 that was
present in the graph. Otherwise, this add a gemm, but still # present in the graph. Otherwise, this add a gemm, but still
compute the dot22. # compute the dot22.
This was not always the case in the with this the following code. # This was not always the case in the with this the following code.
"""
batch_size = 100 batch_size = 100
rep_size = 40 rep_size = 40
rng = np.random.RandomState([1, 2, 3]) rng = np.random.RandomState([1, 2, 3])
...@@ -985,9 +983,7 @@ def test_dot22scalar(): ...@@ -985,9 +983,7 @@ def test_dot22scalar():
def test_dot22scalar_cast(): def test_dot22scalar_cast():
""" # Test that in `dot22_to_dot22scalar` we properly cast integers to floats.
Test that in `dot22_to_dot22scalar` we properly cast integers to floats.
"""
# Note that this test was failing before d5ff6904. # Note that this test was failing before d5ff6904.
A = T.dmatrix() A = T.dmatrix()
for scalar_int_type in T.int_dtypes: for scalar_int_type in T.int_dtypes:
...@@ -1005,9 +1001,7 @@ def test_dot22scalar_cast(): ...@@ -1005,9 +1001,7 @@ def test_dot22scalar_cast():
def test_local_dot22_to_dot22scalar(): def test_local_dot22_to_dot22scalar():
""" # This test that the bug in gh-1507 is really fixed
This test that the bug in gh-1507 is really fixed
"""
A = T.dmatrix() A = T.dmatrix()
mode = theano.compile.mode.get_default_mode() mode = theano.compile.mode.get_default_mode()
opt = theano.tensor.opt.in2out( opt = theano.tensor.opt.in2out(
...@@ -1643,25 +1637,25 @@ class TestGer(TestCase, unittest_tools.TestOptimizationMixin): ...@@ -1643,25 +1637,25 @@ class TestGer(TestCase, unittest_tools.TestOptimizationMixin):
return T.as_tensor_variable(np.asarray(bval, dtype=self.dtype)) return T.as_tensor_variable(np.asarray(bval, dtype=self.dtype))
def test_b_0_triggers_ger(self): def test_b_0_triggers_ger(self):
""" test local_gemm_to_ger opt""" # test local_gemm_to_ger opt
assert T.blas.local_gemm_to_ger.transform( assert T.blas.local_gemm_to_ger.transform(
gemm_no_inplace(self.A, self.a, self.x.dimshuffle(0, 'x'), gemm_no_inplace(self.A, self.a, self.x.dimshuffle(0, 'x'),
self.y.dimshuffle('x', 0), self.b(0)).owner) self.y.dimshuffle('x', 0), self.b(0)).owner)
def test_b_1_triggers_ger(self): def test_b_1_triggers_ger(self):
""" test local_gemm_to_ger opt""" # test local_gemm_to_ger opt
assert T.blas.local_gemm_to_ger.transform( assert T.blas.local_gemm_to_ger.transform(
gemm_no_inplace(self.A, self.a, self.x.dimshuffle(0, 'x'), gemm_no_inplace(self.A, self.a, self.x.dimshuffle(0, 'x'),
self.y.dimshuffle('x', 0), self.b(1)).owner) self.y.dimshuffle('x', 0), self.b(1)).owner)
def test_b_other_does_not_triggers_ger(self): def test_b_other_does_not_triggers_ger(self):
""" test local_gemm_to_ger opt""" # test local_gemm_to_ger opt
assert not T.blas.local_gemm_to_ger.transform( assert not T.blas.local_gemm_to_ger.transform(
gemm_no_inplace(self.A, self.a, self.x.dimshuffle(0, 'x'), gemm_no_inplace(self.A, self.a, self.x.dimshuffle(0, 'x'),
self.y.dimshuffle('x', 0), self.b(1.5)).owner) self.y.dimshuffle('x', 0), self.b(1.5)).owner)
def test_b_nonconst_does_not_triggers_ger(self): def test_b_nonconst_does_not_triggers_ger(self):
""" test local_gemm_to_ger opt""" # test local_gemm_to_ger opt
assert not T.blas.local_gemm_to_ger.transform( assert not T.blas.local_gemm_to_ger.transform(
gemm_no_inplace(self.A, self.a, self.x.dimshuffle(0, 'x'), gemm_no_inplace(self.A, self.a, self.x.dimshuffle(0, 'x'),
self.y.dimshuffle('x', 0), self.a).owner) self.y.dimshuffle('x', 0), self.a).owner)
...@@ -1710,7 +1704,7 @@ class TestGer(TestCase, unittest_tools.TestOptimizationMixin): ...@@ -1710,7 +1704,7 @@ class TestGer(TestCase, unittest_tools.TestOptimizationMixin):
np.random.rand(4).astype(self.dtype)) np.random.rand(4).astype(self.dtype))
def given_dtype(self, dtype, M, N): def given_dtype(self, dtype, M, N):
""" test corner case shape and dtype""" # test corner case shape and dtype
f = self.function([self.A, self.x, self.y], f = self.function([self.A, self.x, self.y],
self.A + 0.1 * T.outer(self.x, self.y)) self.A + 0.1 * T.outer(self.x, self.y))
...@@ -2155,7 +2149,7 @@ class TestBlasStrides(TestCase): ...@@ -2155,7 +2149,7 @@ class TestBlasStrides(TestCase):
self.cmp_ger((0, 0), 0, 0) self.cmp_ger((0, 0), 0, 0)
def test_gemm_non_contiguous(self): def test_gemm_non_contiguous(self):
"""test_gemm_non_contiguous: Test if GEMM works well with non-contiguous matrices.""" # test_gemm_non_contiguous: Test if GEMM works well with non-contiguous matrices.
aval = np.ones((6, 2)) aval = np.ones((6, 2))
bval = np.ones((2, 7)) bval = np.ones((2, 7))
cval = np.arange(7) + np.arange(0, .6, .1)[:, np.newaxis] cval = np.arange(7) + np.arange(0, .6, .1)[:, np.newaxis]
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论