提交 99612b72 authored 作者: Ian Goodfellow's avatar Ian Goodfellow

pep8 elemwise

上级 3179ca7a
...@@ -278,7 +278,8 @@ class DimShuffle(Op): ...@@ -278,7 +278,8 @@ class DimShuffle(Op):
#get the copy / view of the input depending on whether we're doingi #get the copy / view of the input depending on whether we're doingi
# things inplace or not. # things inplace or not.
if self.inplace: if self.inplace:
get_base = ['{ PyArrayObject * %(basename)s = %(input)s', 'Py_INCREF((PyObject*)%(basename)s)'] get_base = [
'{ PyArrayObject * %(basename)s = %(input)s', 'Py_INCREF((PyObject*)%(basename)s)']
else: else:
get_base = [('{ PyArrayObject * %(basename)s = (PyArrayObject*)PyArray_FromAny((PyObject*)%(input)s, NULL,' get_base = [('{ PyArrayObject * %(basename)s = (PyArrayObject*)PyArray_FromAny((PyObject*)%(input)s, NULL,'
'0, 0, NPY_ALIGNED|NPY_ENSURECOPY, NULL)')] '0, 0, NPY_ALIGNED|NPY_ENSURECOPY, NULL)')]
...@@ -286,7 +287,8 @@ class DimShuffle(Op): ...@@ -286,7 +287,8 @@ class DimShuffle(Op):
shape_statements = ['npy_intp dimensions[%i]' % nd_out] shape_statements = ['npy_intp dimensions[%i]' % nd_out]
for i, o in enumerate(self.new_order): for i, o in enumerate(self.new_order):
if o != 'x': if o != 'x':
shape_statements += [('dimensions[' + str(i) + '] = %(basename)s->dimensions[' + str(o) + ']')] shape_statements += [('dimensions[' + str(
i) + '] = %(basename)s->dimensions[' + str(o) + ']')]
else: else:
shape_statements += [('dimensions[' + str(i) + '] = 1')] shape_statements += [('dimensions[' + str(i) + '] = 1')]
...@@ -295,7 +297,8 @@ class DimShuffle(Op): ...@@ -295,7 +297,8 @@ class DimShuffle(Op):
#set the strides of the non-broadcasted dimensions #set the strides of the non-broadcasted dimensions
for i, o in enumerate(self.new_order): for i, o in enumerate(self.new_order):
if o != 'x': if o != 'x':
strides_statements += [('strides[' + str(i) + '] = %(basename)s->strides[' + str(o) + ']')] strides_statements += [('strides[' + str(i)
+ '] = %(basename)s->strides[' + str(o) + ']')]
else: else:
strides_statements += [('strides[' + str(i) + '] = 0')] strides_statements += [('strides[' + str(i) + '] = 0')]
...@@ -311,7 +314,8 @@ class DimShuffle(Op): ...@@ -311,7 +314,8 @@ class DimShuffle(Op):
'-1] = %(basename)s->descr->elsize' '-1] = %(basename)s->descr->elsize'
) )
for i in xrange(nd_out - 2, -1, -1): for i in xrange(nd_out - 2, -1, -1):
strides_statements.append("if (strides[%(i)s] == 0) strides[%(i)s] = strides[%(i)s+1] * dimensions[%(i)s+1]" % dict(i=str(i))) strides_statements.append(
"if (strides[%(i)s] == 0) strides[%(i)s] = strides[%(i)s+1] * dimensions[%(i)s+1]" % dict(i=str(i)))
# #
# PyObject* PyArray_New(PyTypeObject* subtype, int nd, npy_intp* dims, int type_num, # PyObject* PyArray_New(PyTypeObject* subtype, int nd, npy_intp* dims, int type_num,
...@@ -619,7 +623,6 @@ class Elemwise(Op): ...@@ -619,7 +623,6 @@ class Elemwise(Op):
return rval return rval
def connection_pattern(self, node): def connection_pattern(self, node):
if hasattr(self.scalar_op, 'connection_pattern'): if hasattr(self.scalar_op, 'connection_pattern'):
...@@ -686,7 +689,7 @@ class Elemwise(Op): ...@@ -686,7 +689,7 @@ class Elemwise(Op):
theano.config.compute_test_value = prev_setting theano.config.compute_test_value = prev_setting
if not isinstance(scalar_igrads,(list,tuple)): if not isinstance(scalar_igrads, (list, tuple)):
raise TypeError('%s.grad returned %s instead of list or tuple' % raise TypeError('%s.grad returned %s instead of list or tuple' %
(str(self.scalar_op), str(type(scalar_igrads)))) (str(self.scalar_op), str(type(scalar_igrads))))
...@@ -1340,7 +1343,8 @@ class CAReduce(Op): ...@@ -1340,7 +1343,8 @@ class CAReduce(Op):
alloc += """ alloc += """
for(int i=0;i<%(iname)s->nd;i++){ for(int i=0;i<%(iname)s->nd;i++){
if(PyArray_DIMS(%(iname)s)[i]==0 && tosum[i]){ if(PyArray_DIMS(%(iname)s)[i]==0 && tosum[i]){
PyErr_Format(PyExc_ValueError, "Input of CAReduce{%(scal_name)s} has zero-size on axis %%d",i); PyErr_Format(PyExc_ValueError,
"Input of CAReduce{%(scal_name)s} has zero-size on axis %%d",i);
%(fail)s; %(fail)s;
} }
} }
...@@ -1718,7 +1722,7 @@ class Prod(CAReduceDtype): ...@@ -1718,7 +1722,7 @@ class Prod(CAReduceDtype):
out = self(*inp) out = self(*inp)
if out.dtype[0:3] in ('int', 'uin'): if out.dtype[0:3] in ('int', 'uin'):
return [ prod_in.zeros_like().astype(theano.config.floatX) ] return [prod_in.zeros_like().astype(theano.config.floatX)]
# Prepare the broadcasting that is used everywhere to broadcast # Prepare the broadcasting that is used everywhere to broadcast
# over the original groups (ie. broadcast over the elements of a given # over the original groups (ie. broadcast over the elements of a given
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论