提交 fa143f13 authored 作者: Chiheb Trabelsi's avatar Chiheb Trabelsi

type.py has been modified in order to respect the flake8 style.

上级 4fd12c42
...@@ -130,10 +130,10 @@ class CudaNdarrayType(Type): ...@@ -130,10 +130,10 @@ class CudaNdarrayType(Type):
type(data) is float and type(data) is float and
self.dtype == theano.config.floatX): self.dtype == theano.config.floatX):
return cuda.filter(converted_data, self.broadcastable, return cuda.filter(converted_data, self.broadcastable,
strict, old_data) strict, old_data)
elif numpy.all(data == converted_data): elif numpy.all(data == converted_data):
return cuda.filter(converted_data, self.broadcastable, return cuda.filter(converted_data, self.broadcastable,
strict, old_data) strict, old_data)
else: else:
raise TypeError( raise TypeError(
'%s, with dtype %s, cannot store accurately value %s, ' '%s, with dtype %s, cannot store accurately value %s, '
...@@ -259,8 +259,8 @@ class CudaNdarrayType(Type): ...@@ -259,8 +259,8 @@ class CudaNdarrayType(Type):
'complex64': (complex, 'theano_complex64', 'complex64': (complex, 'theano_complex64',
'NPY_COMPLEX64')}[self.dtype] 'NPY_COMPLEX64')}[self.dtype]
except KeyError: except KeyError:
raise TypeError("Unsupported dtype for %s: %s" % ( raise TypeError("Unsupported dtype for %s: %s" %
self.__class__.__name__, self.dtype)) (self.__class__.__name__, self.dtype))
def __eq__(self, other): def __eq__(self, other):
""" """
...@@ -271,10 +271,11 @@ class CudaNdarrayType(Type): ...@@ -271,10 +271,11 @@ class CudaNdarrayType(Type):
other.broadcastable == self.broadcastable) other.broadcastable == self.broadcastable)
def convert_variable(self, var): def convert_variable(self, var):
if (type(self) == type(var.type) and if (isinstance(self, type(var.type)) and
self.ndim == var.type.ndim and self.ndim == var.type.ndim and
all(sb == ob or ob for sb, ob in zip(self.broadcastable, all(sb == ob or ob for sb, ob in zip(
var.type.broadcastable))): self.broadcastable,
var.type.broadcastable))):
return theano.tensor.patternbroadcast(var, self.broadcastable) return theano.tensor.patternbroadcast(var, self.broadcastable)
def __hash__(self): def __hash__(self):
...@@ -312,7 +313,7 @@ class CudaNdarrayType(Type): ...@@ -312,7 +313,7 @@ class CudaNdarrayType(Type):
return self.name return self.name
else: else:
b = self.broadcastable b = self.broadcastable
#bcast = str(self.broadcastable) # bcast = str(self.broadcastable)
if not numpy.any(b): if not numpy.any(b):
s = "%iD" % len(b) s = "%iD" % len(b)
else: else:
...@@ -327,7 +328,7 @@ class CudaNdarrayType(Type): ...@@ -327,7 +328,7 @@ class CudaNdarrayType(Type):
def __repr__(self): def __repr__(self):
return str(self) return str(self)
#"CudaNdarrayType{%s, %s}" % (str(self.dtype), str(self.broadcastable)) # "CudaNdarrayType{%s, %s}" % (str(self.dtype), str(self.broadcastable))
def c_declare(self, name, sub, check_input=True): def c_declare(self, name, sub, check_input=True):
return """ CudaNdarray * %(name)s;""" % locals() return """ CudaNdarray * %(name)s;""" % locals()
...@@ -417,7 +418,7 @@ class CudaNdarrayType(Type): ...@@ -417,7 +418,7 @@ class CudaNdarrayType(Type):
return sio.getvalue() return sio.getvalue()
def c_extract_out(self, name, sub, check_input=True, check_broadcast=True): def c_extract_out(self, name, sub, check_input=True, check_broadcast=True):
""" """
To allow the hack to skip check_broadcast. To allow the hack to skip check_broadcast.
""" """
...@@ -528,13 +529,13 @@ theano.compile.ops.expandable_types += (CudaNdarrayType,) ...@@ -528,13 +529,13 @@ theano.compile.ops.expandable_types += (CudaNdarrayType,)
# Register C code for ViewOp on CudaNdarrayType # Register C code for ViewOp on CudaNdarrayType
theano.compile.register_view_op_c_code( theano.compile.register_view_op_c_code(
CudaNdarrayType, CudaNdarrayType,
""" """
Py_XDECREF(%(oname)s); Py_XDECREF(%(oname)s);
%(oname)s = %(iname)s; %(oname)s = %(iname)s;
Py_XINCREF(%(oname)s); Py_XINCREF(%(oname)s);
""", """,
version=1) version=1)
theano.compile.register_shape_i_c_code( theano.compile.register_shape_i_c_code(
CudaNdarrayType, CudaNdarrayType,
...@@ -555,16 +556,15 @@ theano.compile.register_shape_i_c_code( ...@@ -555,16 +556,15 @@ theano.compile.register_shape_i_c_code(
# Register CudaNdarrayType to the DeepCopyOp list of types with c code. # Register CudaNdarrayType to the DeepCopyOp list of types with c code.
theano.compile.register_deep_copy_op_c_code( theano.compile.register_deep_copy_op_c_code(
CudaNdarrayType, CudaNdarrayType,
""" """
int alloc = %(oname)s == NULL; int alloc = %(oname)s == NULL;
for(int i=0; !alloc && i<CudaNdarray_NDIM(%(oname)s); i++) { for(int i=0; !alloc && i<CudaNdarray_NDIM(%(oname)s); i++) {
if(CudaNdarray_HOST_DIMS(%(iname)s)[i] != if(CudaNdarray_HOST_DIMS(%(iname)s)[i] !=
CudaNdarray_HOST_DIMS(%(oname)s)[i]) { CudaNdarray_HOST_DIMS(%(oname)s)[i]) {
alloc = true; alloc = true;
break; break;
} }}
}
if(alloc) { if(alloc) {
Py_XDECREF(%(oname)s); Py_XDECREF(%(oname)s);
%(oname)s = (CudaNdarray*)CudaNdarray_Copy(%(iname)s); %(oname)s = (CudaNdarray*)CudaNdarray_Copy(%(iname)s);
...@@ -581,8 +581,7 @@ theano.compile.register_deep_copy_op_c_code( ...@@ -581,8 +581,7 @@ theano.compile.register_deep_copy_op_c_code(
%(fail)s; %(fail)s;
} }
} }
""", """, version=3)
version=3)
# THIS WORKS But CudaNdarray instances don't compare equal to one # THIS WORKS But CudaNdarray instances don't compare equal to one
...@@ -608,5 +607,5 @@ def CudaNdarray_pickler(cnda): ...@@ -608,5 +607,5 @@ def CudaNdarray_pickler(cnda):
# In case cuda is not imported. # In case cuda is not imported.
if cuda is not None: if cuda is not None:
copyreg.pickle(cuda.CudaNdarray, CudaNdarray_pickler, copyreg.pickle(
CudaNdarray_unpickler) cuda.CudaNdarray, CudaNdarray_pickler, CudaNdarray_unpickler)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论