提交 3fb13c75 authored 作者: erakra's avatar erakra

fixing flake8 for test_blas.py

上级 184774c1
......@@ -208,7 +208,7 @@ class t_gemm(TestCase):
assert f[0].op == gemm_inplace
def test_destroy_map0(self):
"""test that only first input can be overwritten"""
# test that only first input can be overwritten.
Z = as_tensor_variable(self.rand(2, 2))
try:
gemm_inplace(Z, 1.0, Z, Z, 1.0)
......@@ -218,7 +218,7 @@ class t_gemm(TestCase):
self.fail()
def test_destroy_map1(self):
"""test that only first input can be overwritten"""
# test that only first input can be overwritten.
Z = as_tensor_variable(self.rand(2, 2))
A = as_tensor_variable(self.rand(2, 2))
try:
......@@ -229,7 +229,7 @@ class t_gemm(TestCase):
self.fail()
def test_destroy_map2(self):
"""test that only first input can be overwritten"""
# test that only first input can be overwritten.
Z = as_tensor_variable(self.rand(2, 2))
A = as_tensor_variable(self.rand(2, 2))
try:
......@@ -240,7 +240,7 @@ class t_gemm(TestCase):
self.fail()
def test_destroy_map3(self):
"""test that only first input can be overwritten"""
# test that only first input can be overwritten
Z = as_tensor_variable(self.rand(2, 2))
A = as_tensor_variable(self.rand(2, 2))
try:
......@@ -251,7 +251,7 @@ class t_gemm(TestCase):
self.fail()
def test_destroy_map4(self):
"""test that dot args can be aliased"""
# test that dot args can be aliased
Z = shared(self.rand(2, 2), name='Z')
A = shared(self.rand(2, 2), name='A')
one = T.constant(1.0).astype(Z.dtype)
......@@ -386,7 +386,7 @@ def test_res_is_a():
class t_as_scalar(TestCase):
def test0(self):
"""Test that it works on scalar constants"""
# Test that it works on scalar constants
a = T.constant(2.5)
b = T.constant(np.asarray([[[0.5]]]))
b2 = b.dimshuffle()
......@@ -402,13 +402,13 @@ class t_as_scalar(TestCase):
self.assertTrue(_as_scalar(d_a2) != d_a2)
def test1(self):
"""Test that it fails on nonscalar constants"""
# Test that it fails on nonscalar constants
a = T.constant(np.ones(5))
self.assertTrue(_as_scalar(a) is None)
self.assertTrue(_as_scalar(T.DimShuffle([False], [0, 'x'])(a)) is None)
def test2(self):
"""Test that it works on scalar variables"""
# Test that it works on scalar variables
a = T.dscalar()
d_a = T.DimShuffle([], [])(a)
d_a2 = T.DimShuffle([], ['x', 'x'])(a)
......@@ -418,7 +418,7 @@ class t_as_scalar(TestCase):
self.assertTrue(_as_scalar(d_a2) is a)
def test3(self):
"""Test that it fails on nonscalar variables"""
# Test that it fails on nonscalar variables
a = T.matrix()
self.assertTrue(_as_scalar(a) is None)
self.assertTrue(_as_scalar(T.DimShuffle([False, False],
......@@ -502,7 +502,7 @@ def just_gemm(i, o, ishapes=[(4, 3), (3, 5), (4, 5), (), ()],
def test_gemm_opt0():
"""Many subgraphs whose dots can be eliminated"""
# Many subgraphs whose dots can be eliminated
X, Y, Z, a, b = XYZab()
just_gemm([X, Y, Z, a, b], [T.dot(X, Y) * a + Z * b])
......@@ -530,7 +530,7 @@ def test_gemm_opt0():
def test_gemm_opt_double_gemm():
"""This is the pattern that shows up in the autoencoder"""
# This is the pattern that shows up in the autoencoder
X, Y, Z, a, b = T.matrix(), T.matrix(), T.matrix(), T.scalar(), T.scalar()
R, S, c = T.matrix(), T.matrix(), T.scalar()
......@@ -705,11 +705,10 @@ def test_gemm_opt_wishlist():
def test_gemm_with_vector():
"""Many subgraphs whose dots can be eliminated. This adds a
vector two the previous test, which triggers the long-sought GEMM
bug.
# Many subgraphs whose dots can be eliminated. This adds a
# vector two the previous test, which triggers the long-sought GEMM
# bug.
"""
X, Y, Z, a, b = XYZab()
v = T.vector()
......@@ -750,13 +749,12 @@ def test_gemm_opt_vector_stuff():
def test_gemm_unrolled():
"""This test that the gemm optimizer remove the dot22 that was
present in the graph. Otherwise, this add a gemm, but still
compute the dot22.
# This test that the gemm optimizer remove the dot22 that was
# present in the graph. Otherwise, this add a gemm, but still
# compute the dot22.
This was not always the case in the with this the following code.
# This was not always the case in the with this the following code.
"""
batch_size = 100
rep_size = 40
rng = np.random.RandomState([1, 2, 3])
......@@ -985,9 +983,7 @@ def test_dot22scalar():
def test_dot22scalar_cast():
"""
Test that in `dot22_to_dot22scalar` we properly cast integers to floats.
"""
# Test that in `dot22_to_dot22scalar` we properly cast integers to floats.
# Note that this test was failing before d5ff6904.
A = T.dmatrix()
for scalar_int_type in T.int_dtypes:
......@@ -1005,9 +1001,7 @@ def test_dot22scalar_cast():
def test_local_dot22_to_dot22scalar():
"""
This test that the bug in gh-1507 is really fixed
"""
# This test that the bug in gh-1507 is really fixed
A = T.dmatrix()
mode = theano.compile.mode.get_default_mode()
opt = theano.tensor.opt.in2out(
......@@ -1643,25 +1637,25 @@ class TestGer(TestCase, unittest_tools.TestOptimizationMixin):
return T.as_tensor_variable(np.asarray(bval, dtype=self.dtype))
def test_b_0_triggers_ger(self):
""" test local_gemm_to_ger opt"""
# test local_gemm_to_ger opt
assert T.blas.local_gemm_to_ger.transform(
gemm_no_inplace(self.A, self.a, self.x.dimshuffle(0, 'x'),
self.y.dimshuffle('x', 0), self.b(0)).owner)
def test_b_1_triggers_ger(self):
""" test local_gemm_to_ger opt"""
# test local_gemm_to_ger opt
assert T.blas.local_gemm_to_ger.transform(
gemm_no_inplace(self.A, self.a, self.x.dimshuffle(0, 'x'),
self.y.dimshuffle('x', 0), self.b(1)).owner)
def test_b_other_does_not_triggers_ger(self):
""" test local_gemm_to_ger opt"""
# test local_gemm_to_ger opt
assert not T.blas.local_gemm_to_ger.transform(
gemm_no_inplace(self.A, self.a, self.x.dimshuffle(0, 'x'),
self.y.dimshuffle('x', 0), self.b(1.5)).owner)
def test_b_nonconst_does_not_triggers_ger(self):
""" test local_gemm_to_ger opt"""
# test local_gemm_to_ger opt
assert not T.blas.local_gemm_to_ger.transform(
gemm_no_inplace(self.A, self.a, self.x.dimshuffle(0, 'x'),
self.y.dimshuffle('x', 0), self.a).owner)
......@@ -1710,7 +1704,7 @@ class TestGer(TestCase, unittest_tools.TestOptimizationMixin):
np.random.rand(4).astype(self.dtype))
def given_dtype(self, dtype, M, N):
""" test corner case shape and dtype"""
# test corner case shape and dtype
f = self.function([self.A, self.x, self.y],
self.A + 0.1 * T.outer(self.x, self.y))
......@@ -2155,7 +2149,7 @@ class TestBlasStrides(TestCase):
self.cmp_ger((0, 0), 0, 0)
def test_gemm_non_contiguous(self):
"""test_gemm_non_contiguous: Test if GEMM works well with non-contiguous matrices."""
# test_gemm_non_contiguous: Test if GEMM works well with non-contiguous matrices.
aval = np.ones((6, 2))
bval = np.ones((2, 7))
cval = np.arange(7) + np.arange(0, .6, .1)[:, np.newaxis]
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论