提交 184774c1 authored 作者: erakra's avatar erakra

fixing flake8 for test_blas.py

上级 2618d044
...@@ -875,7 +875,7 @@ def test_dot22scalar(): ...@@ -875,7 +875,7 @@ def test_dot22scalar():
def check_dot22scalar(func, len_topo_scalar=-1): def check_dot22scalar(func, len_topo_scalar=-1):
topo = func.maker.fgraph.toposort() topo = func.maker.fgraph.toposort()
ops = [x.op for x in topo] ops = [x.op for x in topo]
classes = [type(x.op) for x in topo] # classes = [type(x.op) for x in topo]
dtype4_upcast = theano.scalar.upcast(dtype4, dtype1, dtype4_upcast = theano.scalar.upcast(dtype4, dtype1,
dtype2) dtype2)
...@@ -914,7 +914,7 @@ def test_dot22scalar(): ...@@ -914,7 +914,7 @@ def test_dot22scalar():
if False: if False:
f = theano.function([a, b], cst * T.dot(a, b), f = theano.function([a, b], cst * T.dot(a, b),
mode=mode_blas_opt) mode=mode_blas_opt)
topo = f.maker.fgraph.toposort() f.maker.fgraph.toposort()
check_dot22scalar(f, 1) check_dot22scalar(f, 1)
f(av, bv) f(av, bv)
...@@ -923,7 +923,7 @@ def test_dot22scalar(): ...@@ -923,7 +923,7 @@ def test_dot22scalar():
f = theano.function([a, b, c], f = theano.function([a, b, c],
cst * c * T.dot(a, b), cst * c * T.dot(a, b),
mode=mode_blas_opt) mode=mode_blas_opt)
topo = f.maker.fgraph.toposort() f.maker.fgraph.toposort()
check_dot22scalar(f, 2) check_dot22scalar(f, 2)
f(av, bv, cv) f(av, bv, cv)
...@@ -931,7 +931,7 @@ def test_dot22scalar(): ...@@ -931,7 +931,7 @@ def test_dot22scalar():
f = theano.function([a, b, c], f = theano.function([a, b, c],
c * cst * T.dot(a, b), c * cst * T.dot(a, b),
mode=mode_blas_opt) mode=mode_blas_opt)
topo = f.maker.fgraph.toposort() f.maker.fgraph.toposort()
check_dot22scalar(f, 2) check_dot22scalar(f, 2)
f(av, bv, cv) f(av, bv, cv)
...@@ -941,7 +941,7 @@ def test_dot22scalar(): ...@@ -941,7 +941,7 @@ def test_dot22scalar():
f = theano.function([a, b, c], f = theano.function([a, b, c],
cst2 * c * cst * T.dot(a, b), cst2 * c * cst * T.dot(a, b),
mode=m2) mode=m2)
topo = f.maker.fgraph.toposort() f.maker.fgraph.toposort()
check_dot22scalar(f, 2) check_dot22scalar(f, 2)
f(av, bv, cv) f(av, bv, cv)
...@@ -949,14 +949,14 @@ def test_dot22scalar(): ...@@ -949,14 +949,14 @@ def test_dot22scalar():
f = theano.function([a, b, c], f = theano.function([a, b, c],
c * cst * a * T.dot(a, b), c * cst * a * T.dot(a, b),
mode=m2) mode=m2)
topo = f.maker.fgraph.toposort() f.maker.fgraph.toposort()
check_dot22scalar(f, 2) check_dot22scalar(f, 2)
f(sv, sv, sv) f(sv, sv, sv)
f = theano.function([a, b, c], f = theano.function([a, b, c],
cst * c * a * T.dot(a, b), cst * c * a * T.dot(a, b),
mode=mode_blas_opt) mode=mode_blas_opt)
topo = f.maker.fgraph.toposort() f.maker.fgraph.toposort()
# currently the canonizer don't always # currently the canonizer don't always
# merge all Mul together... dot22scalar # merge all Mul together... dot22scalar
# optimizer does not do a recursive search # optimizer does not do a recursive search
...@@ -972,7 +972,7 @@ def test_dot22scalar(): ...@@ -972,7 +972,7 @@ def test_dot22scalar():
f = theano.function([a, b, c], f = theano.function([a, b, c],
c * a * cst * T.dot(a, b), c * a * cst * T.dot(a, b),
mode=m2) mode=m2)
topo = f.maker.fgraph.toposort() f.maker.fgraph.toposort()
check_dot22scalar(f, 2) check_dot22scalar(f, 2)
f(sv, sv, sv) f(sv, sv, sv)
...@@ -1330,7 +1330,7 @@ class BaseGemv(object): ...@@ -1330,7 +1330,7 @@ class BaseGemv(object):
oy_func = theano.function([], oy, mode=self.mode) oy_func = theano.function([], oy, mode=self.mode)
topo = oy_func.maker.fgraph.toposort() oy_func.maker.fgraph.toposort()
self.assertFunctionContains1(oy_func, self.gemv) self.assertFunctionContains1(oy_func, self.gemv)
oy_val = oy_func() oy_val = oy_func()
...@@ -1393,8 +1393,8 @@ class BaseGemv(object): ...@@ -1393,8 +1393,8 @@ class BaseGemv(object):
alpha_v, beta_v, a_v, x_v, y_v = vs alpha_v, beta_v, a_v, x_v, y_v = vs
alpha, beta, a, x, y = [self.shared(v) for v in vs] alpha, beta, a, x, y = [self.shared(v) for v in vs]
desired_oy = alpha_v * matrixmultiply(transpose(a_v), x_v[::2]) + desired_oy = alpha_v * matrixmultiply(transpose(a_v), x_v[::2]) + \
beta_v * y_v beta_v * y_v
oy = alpha * T.dot(a.T, x[::2]) + beta * y oy = alpha * T.dot(a.T, x[::2]) + beta * y
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论