提交 65019192 authored 作者: Iban Harlouchet's avatar Iban Harlouchet

flake8 of theano/scalar/sharedvar.py

上级 11a78c73
...@@ -14,17 +14,18 @@ default when calling theano.shared(value) then users must really go out of their ...@@ -14,17 +14,18 @@ default when calling theano.shared(value) then users must really go out of their
way (as scan does) to create a shared variable of this kind. way (as scan does) to create a shared variable of this kind.
""" """
__authors__ = "James Bergstra"
__copyright__ = "(c) 2010, Universite de Montreal"
__license__ = "3-clause BSD License"
__contact__ = "theano-dev <theano-dev@googlegroups.com>"
__docformat__ = "restructuredtext en"
import numpy import numpy
from theano.compile import SharedVariable from theano.compile import SharedVariable
from .basic import Scalar, _scalar_py_operators from .basic import Scalar, _scalar_py_operators
__authors__ = "James Bergstra"
__copyright__ = "(c) 2010, Universite de Montreal"
__license__ = "3-clause BSD License"
__contact__ = "theano-dev <theano-dev@googlegroups.com>"
__docformat__ = "restructuredtext en"
class ScalarSharedVariable(_scalar_py_operators, SharedVariable): class ScalarSharedVariable(_scalar_py_operators, SharedVariable):
pass pass
...@@ -41,7 +42,7 @@ def shared(value, name=None, strict=False, allow_downcast=None): ...@@ -41,7 +42,7 @@ def shared(value, name=None, strict=False, allow_downcast=None):
:note: We implement this using 0-d tensors for now. :note: We implement this using 0-d tensors for now.
""" """
if not isinstance (value, (numpy.number, float, int, complex)): if not isinstance(value, (numpy.number, float, int, complex)):
raise TypeError() raise TypeError()
try: try:
dtype = value.dtype dtype = value.dtype
...@@ -52,7 +53,9 @@ def shared(value, name=None, strict=False, allow_downcast=None): ...@@ -52,7 +53,9 @@ def shared(value, name=None, strict=False, allow_downcast=None):
value = getattr(numpy, dtype)(value) value = getattr(numpy, dtype)(value)
scalar_type = Scalar(dtype=dtype) scalar_type = Scalar(dtype=dtype)
rval = ScalarSharedVariable( rval = ScalarSharedVariable(
type=scalar_type, type=scalar_type,
value=value, value=value,
name=name, strict=strict, allow_downcast=allow_downcast) name=name,
strict=strict,
allow_downcast=allow_downcast)
return rval return rval
...@@ -114,7 +114,6 @@ whitelist_flake8 = [ ...@@ -114,7 +114,6 @@ whitelist_flake8 = [
"tensor/nnet/tests/test_conv3d.py", "tensor/nnet/tests/test_conv3d.py",
"tensor/nnet/tests/speed_test_conv.py", "tensor/nnet/tests/speed_test_conv.py",
"tensor/nnet/tests/test_sigm.py", "tensor/nnet/tests/test_sigm.py",
"scalar/sharedvar.py",
"scalar/basic_scipy.py", "scalar/basic_scipy.py",
"scalar/basic_sympy.py", "scalar/basic_sympy.py",
"scalar/__init__.py", "scalar/__init__.py",
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论