提交 5ae5123d authored 作者: Frederic's avatar Frederic

test pep8 and better error msg

上级 ca732b67
...@@ -369,12 +369,14 @@ class CompressTester(utt.InferShapeTester): ...@@ -369,12 +369,14 @@ class CompressTester(utt.InferShapeTester):
self.op = compress self.op = compress
def test_op(self): def test_op(self):
for axis, cond, shape in zip(self.axis_list, self.cond_list, self.shape_list): for axis, cond, shape in zip(self.axis_list, self.cond_list,
self.shape_list):
cond_var = theano.tensor.ivector() cond_var = theano.tensor.ivector()
data = numpy.random.random(size=shape).astype(theano.config.floatX) data = numpy.random.random(size=shape).astype(theano.config.floatX)
data_var = theano.tensor.matrix() data_var = theano.tensor.matrix()
f = theano.function([cond_var, data_var], self.op(cond_var, data_var, axis=axis)) f = theano.function([cond_var, data_var],
self.op(cond_var, data_var, axis=axis))
expected = numpy.compress(cond, data, axis=axis) expected = numpy.compress(cond, data, axis=axis)
tested = f(cond, data) tested = f(cond, data)
......
...@@ -5053,7 +5053,7 @@ class Test_lift_transpose_through_dot(unittest.TestCase): ...@@ -5053,7 +5053,7 @@ class Test_lift_transpose_through_dot(unittest.TestCase):
a, b = matrices('ab') a, b = matrices('ab')
g = self.simple_optimize(FunctionGraph([a, b], [tensor.dot(a, b).T])) g = self.simple_optimize(FunctionGraph([a, b], [tensor.dot(a, b).T]))
sg = '[dot(DimShuffle{1,0}(b), DimShuffle{1,0}(a))]' sg = '[dot(DimShuffle{1,0}(b), DimShuffle{1,0}(a))]'
assert str(g) == sg assert str(g) == sg, (str(g), sg)
def test_row_matrix(self): def test_row_matrix(self):
a = vector('a') a = vector('a')
...@@ -5063,7 +5063,7 @@ class Test_lift_transpose_through_dot(unittest.TestCase): ...@@ -5063,7 +5063,7 @@ class Test_lift_transpose_through_dot(unittest.TestCase):
[tensor.dot(a.dimshuffle('x', 0), b).T]), [tensor.dot(a.dimshuffle('x', 0), b).T]),
level='stabilize') level='stabilize')
sg = '[dot(DimShuffle{1,0}(b), DimShuffle{0,x}(a))]' sg = '[dot(DimShuffle{1,0}(b), DimShuffle{0,x}(a))]'
assert str(g) == sg assert str(g) == sg, (str(g), sg)
def test_matrix_col(self): def test_matrix_col(self):
a = vector('a') a = vector('a')
...@@ -5073,7 +5073,7 @@ class Test_lift_transpose_through_dot(unittest.TestCase): ...@@ -5073,7 +5073,7 @@ class Test_lift_transpose_through_dot(unittest.TestCase):
[tensor.dot(b, a.dimshuffle(0, 'x')).T]), [tensor.dot(b, a.dimshuffle(0, 'x')).T]),
level='stabilize') level='stabilize')
sg = '[dot(DimShuffle{x,0}(a), DimShuffle{1,0}(b))]' sg = '[dot(DimShuffle{x,0}(a), DimShuffle{1,0}(b))]'
assert str(g) == sg assert str(g) == sg, (str(g), sg)
def test_local_upcast_elemwise_constant_inputs(): def test_local_upcast_elemwise_constant_inputs():
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论