提交 56a54835 authored 作者: Frederic's avatar Frederic

pep8

上级 9da3e43d
import traceback import traceback
import numpy import numpy
import theano.tensor.basic import theano.tensor.basic
from basic import TensorType, _tensor_py_operators, autocast_int, autocast_float from basic import TensorType, _tensor_py_operators
from theano.compile import shared_constructor, SharedVariable from theano.compile import shared_constructor, SharedVariable
from theano import config
def load_shared_variable(val): def load_shared_variable(val):
"""This function is only here to keep some pickles loading """This function is only here to keep some pickles loading
...@@ -11,35 +13,40 @@ def load_shared_variable(val): ...@@ -11,35 +13,40 @@ def load_shared_variable(val):
It can be removed after sufficient time has passed.""" It can be removed after sufficient time has passed."""
return tensor_constructor(val) return tensor_constructor(val)
# _tensor_py_operators is first to have its version of __{gt,ge,lt,le}__ # _tensor_py_operators is first to have its version of __{gt,ge,lt,le}__
class TensorSharedVariable(_tensor_py_operators, SharedVariable): class TensorSharedVariable(_tensor_py_operators, SharedVariable):
pass pass
@shared_constructor @shared_constructor
def tensor_constructor(value, name=None, strict=False, allow_downcast=None, borrow=False, broadcastable=None): def tensor_constructor(value, name=None, strict=False, allow_downcast=None,
borrow=False, broadcastable=None):
"""SharedVariable Constructor for TensorType """SharedVariable Constructor for TensorType
:note: Regarding the inference of the broadcastable pattern... :note: Regarding the inference of the broadcastable pattern...
The default is to assume that the value might be resized in any dimension, so the default The default is to assume that the value might be resized in any
broadcastable is ``(False,)*len(value.shape)``. The optional `broadcastable` argument will dimension, so the default broadcastable is
override this default. ``(False,)*len(value.shape)``. The optional `broadcastable`
argument will override this default.
""" """
if not isinstance(value, numpy.ndarray): if not isinstance(value, numpy.ndarray):
raise TypeError() raise TypeError()
# if no broadcastable is given, then the default is to assume that the value might be # if no broadcastable is given, then the default is to assume that
# resized in any dimension in the future. # the value might be resized in any dimension in the future.
# #
if broadcastable is None: if broadcastable is None:
broadcastable = (False,)*len(value.shape) broadcastable = (False,) * len(value.shape)
type = TensorType(value.dtype, broadcastable=broadcastable) type = TensorType(value.dtype, broadcastable=broadcastable)
return TensorSharedVariable(type=type, return TensorSharedVariable(type=type,
value=numpy.array(value,copy=(not borrow)), value=numpy.array(value, copy=(not borrow)),
name=name, name=name,
strict=strict, strict=strict,
allow_downcast=allow_downcast) allow_downcast=allow_downcast)
# TensorSharedVariable brings in the tensor operators, is not ideal, but works # TensorSharedVariable brings in the tensor operators, is not ideal, but works
# as long as we dont do purely scalar-scalar operations # as long as we dont do purely scalar-scalar operations
# _tensor_py_operators is first to have its version of __{gt,ge,lt,le}__ # _tensor_py_operators is first to have its version of __{gt,ge,lt,le}__
...@@ -50,6 +57,7 @@ def tensor_constructor(value, name=None, strict=False, allow_downcast=None, borr ...@@ -50,6 +57,7 @@ def tensor_constructor(value, name=None, strict=False, allow_downcast=None, borr
class ScalarSharedVariable(_tensor_py_operators, SharedVariable): class ScalarSharedVariable(_tensor_py_operators, SharedVariable):
pass pass
@shared_constructor @shared_constructor
def scalar_constructor(value, name=None, strict=False, allow_downcast=None): def scalar_constructor(value, name=None, strict=False, allow_downcast=None):
"""SharedVariable constructor for scalar values. Default: int64 or float64. """SharedVariable constructor for scalar values. Default: int64 or float64.
...@@ -57,14 +65,14 @@ def scalar_constructor(value, name=None, strict=False, allow_downcast=None): ...@@ -57,14 +65,14 @@ def scalar_constructor(value, name=None, strict=False, allow_downcast=None):
:note: We implement this using 0-d tensors for now. :note: We implement this using 0-d tensors for now.
""" """
if not isinstance (value, (numpy.number, float, int, complex)): if not isinstance(value, (numpy.number, float, int, complex)):
raise TypeError() raise TypeError()
try: try:
dtype=value.dtype dtype = value.dtype
except Exception: except Exception:
dtype=numpy.asarray(value).dtype dtype = numpy.asarray(value).dtype
dtype=str(dtype) dtype = str(dtype)
value = theano._asarray(value, dtype=dtype) value = theano._asarray(value, dtype=dtype)
tensor_type = TensorType(dtype=str(value.dtype), broadcastable=[]) tensor_type = TensorType(dtype=str(value.dtype), broadcastable=[])
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论