提交 e6914181 authored 作者: Arnaud Bergeron's avatar Arnaud Bergeron

Flake8 fixes.

上级 7e03d1a9
import os.path
from theano import Op, Apply, config
from theano import Apply, config
from theano.compile import optdb
from theano.gof import local_optimizer, LocalOptGroup
......@@ -52,7 +52,7 @@ PyGpuArrayObject *gpublas_try_copy(PyGpuArrayObject *out,
class GpuGemv(BlasOp, Gemv):
def make_node(self, y, alpha, A, x, beta):
res = Gemv.make_node(self, y, alpha, A, x, beta)
Gemv.make_node(self, y, alpha, A, x, beta)
A = as_gpuarray_variable(A)
x = as_gpuarray_variable(x)
y = as_gpuarray_variable(y)
......@@ -180,7 +180,7 @@ gpugemm_inplace = GpuGemm(inplace=True)
class GpuGer(BlasOp, Ger):
def make_node(self, A, alpha, x, y):
res = Ger.make_node(self, A, alpha, x, y)
Ger.make_node(self, A, alpha, x, y)
A = as_gpuarray_variable(A)
x = as_gpuarray_variable(x)
y = as_gpuarray_variable(y)
......@@ -240,7 +240,7 @@ gpuger_inplace = GpuGer(destructive=True)
class GpuDot22(BlasOp, Dot22):
def make_node(self, x, y):
res = Dot22.make_node(self, x, y)
Dot22.make_node(self, x, y)
x = as_gpuarray_variable(x)
y = as_gpuarray_variable(y)
assert x.dtype == y.dtype
......@@ -291,6 +291,7 @@ class GpuDot22(BlasOp, Dot22):
gpu_dot22 = GpuDot22()
@local_optimizer([gpugemv_no_inplace], inplace=True)
def local_inplace_gpuagemv(node):
if node.op == gpugemv_no_inplace:
......@@ -313,9 +314,11 @@ def local_inplace_gpuager(node):
if node.op == gpuger_no_inplace:
return [gpuger_inplace(*node.inputs)]
gpuablas_opt_inplace = in2out(LocalOptGroup(
local_inplace_gpuagemv, local_inplace_gpuagemm, local_inplace_gpuager),
gpuablas_opt_inplace = in2out(LocalOptGroup(local_inplace_gpuagemv,
local_inplace_gpuagemm,
local_inplace_gpuager),
name='gpuablas_opt_inplace')
optdb.register('InplaceGpuaBlasOpt',
gpuablas_opt_inplace,
70.0, 'fast_run', 'inplace', 'gpuarray')
......@@ -621,6 +621,7 @@ def local_gpua_hgemm(node):
shape_i(B, 1, fgraph))
return gpugemm_no_inplace(C, 1.0, A, B, 0.0)
@register_opt()
@alpha_merge(GpuGemm, alpha_in=1, beta_in=2, nd=2)
def local_gpuagemm_alpha_merge(node, *inputs):
......
......@@ -164,7 +164,6 @@ whitelist_flake8 = [
"sandbox/gpuarray/elemwise.py",
"sandbox/gpuarray/type.py",
"sandbox/gpuarray/__init__.py",
"sandbox/gpuarray/blas.py",
"sandbox/gpuarray/kernel_codegen.py",
"sandbox/gpuarray/conv.py",
"sandbox/gpuarray/neighbours.py",
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论