提交 94bdc43c authored 作者: Iban Harlouchet's avatar Iban Harlouchet

flake8 of theano/scalar/basic_sympy.py

上级 2b73732a
import numpy as np import itertools as it
from theano.scalar.basic import Apply, ScalarOp, as_scalar, float64, float32, int64 from theano.scalar.basic import Apply, ScalarOp, as_scalar, float64, float32, int64
from theano.gof.utils import remove from theano.gof.utils import remove
imported_sympy = False imported_sympy = False
try: try:
import sympy
from sympy.utilities.codegen import get_default_datatype, codegen from sympy.utilities.codegen import get_default_datatype, codegen
imported_sympy = True imported_sympy = True
except ImportError: except ImportError:
pass pass
import itertools as it names = ("sympy_func_%d" % i for i in it.count(0))
names = ("sympy_func_%d"%i for i in it.count(0))
def include_line(line): def include_line(line):
...@@ -53,8 +51,8 @@ class SymPyCCode(ScalarOp): ...@@ -53,8 +51,8 @@ class SymPyCCode(ScalarOp):
def _sympy_c_code(self): def _sympy_c_code(self):
[(c_name, c_code), (h_name, c_header)] = codegen( [(c_name, c_code), (h_name, c_header)] = codegen(
(self.name, self.expr), 'C', 'project_name', (self.name, self.expr), 'C', 'project_name',
header=False, argument_sequence=self.inputs) header=False, argument_sequence=self.inputs)
return c_code return c_code
def c_support_code(self): def c_support_code(self):
...@@ -64,8 +62,8 @@ class SymPyCCode(ScalarOp): ...@@ -64,8 +62,8 @@ class SymPyCCode(ScalarOp):
def c_headers(self): def c_headers(self):
c_code = self._sympy_c_code() c_code = self._sympy_c_code()
return [line.replace("#include", "").strip() for line in return [line.replace("#include", "").strip() for line in
c_code.split('\n') if include_line(line) c_code.split('\n') if include_line(line) and
and not 'project_name' in line] 'project_name' not in line]
def c_code(self, node, name, input_names, output_names, sub): def c_code(self, node, name, input_names, output_names, sub):
y, = output_names y, = output_names
...@@ -92,7 +90,7 @@ class SymPyCCode(ScalarOp): ...@@ -92,7 +90,7 @@ class SymPyCCode(ScalarOp):
def grad(self, inputs, output_grads): def grad(self, inputs, output_grads):
return [SymPyCCode(self.inputs, return [SymPyCCode(self.inputs,
self.expr.diff(inp), self.expr.diff(inp),
name=self.name+"_grad_%d"%i)(*inputs) name=self.name + "_grad_%d" % i)(*inputs)
for i, inp in enumerate(self.inputs)] for i, inp in enumerate(self.inputs)]
def _info(self): def _info(self):
......
...@@ -114,7 +114,6 @@ whitelist_flake8 = [ ...@@ -114,7 +114,6 @@ whitelist_flake8 = [
"tensor/nnet/tests/test_conv3d.py", "tensor/nnet/tests/test_conv3d.py",
"tensor/nnet/tests/speed_test_conv.py", "tensor/nnet/tests/speed_test_conv.py",
"tensor/nnet/tests/test_sigm.py", "tensor/nnet/tests/test_sigm.py",
"scalar/basic_sympy.py",
"scalar/__init__.py", "scalar/__init__.py",
"scalar/tests/test_basic.py", "scalar/tests/test_basic.py",
"sandbox/test_theano_object.py", "sandbox/test_theano_object.py",
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论